feat: Add a tool confirmation flow that can guard tool execution with explicit confirmation and custom input

The existing `LongRunningTool` does not define a programmatic way to provide & validate structured input, also it relies on LLM to reason and parse the user's response.

For a quick start, annotate the function with `FunctionTool(my_function, require_confirmation=True)`. A more advanced flow is shown in the `human_tool_confirmation` sample.

The new flow is similar to the existing Auth flow:
- User request a tool confirmation by calling `tool_context.request_confirmation()` in the tool or `before_tool_callback`, or just using the `require_confirmation` shortcut in FunctionTool.
- User can provide custom validation logic before tool call proceeds.
- ADK creates corresponding RequestConfirmation FunctionCall Event to ask user for confirmation
- User needs to provide the expected tool confirmation to a RequestConfirmation FunctionResponse Event.
- ADK then checks the response and continues the tool call.

PiperOrigin-RevId: 801019917
This commit is contained in:
Shangjie Chen
2025-08-29 13:56:13 -07:00
committed by Copybara-Service
parent 3ed9097983
commit a17bcbb2aa
15 changed files with 889 additions and 14 deletions
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import Agent
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def reimburse(amount: int, tool_context: ToolContext) -> str:
"""Reimburse the employee for the given amount."""
return {'status': 'ok'}
def request_time_off(days: int, tool_context: ToolContext):
"""Request day off for the employee."""
if days <= 0:
return {'status': 'Invalid days to request.'}
if days <= 2:
return {
'status': 'ok',
'approved_days': days,
}
tool_confirmation = tool_context.tool_confirmation
if not tool_confirmation:
tool_context.request_confirmation(
hint=(
'Please approve or reject the tool call request_time_off() by'
' responding with a FunctionResponse with an expected'
' ToolConfirmation payload.'
),
payload={
'approved_days': 0,
},
)
return {'status': 'Manager approval is required.'}
approved_days = tool_confirmation.payload['approved_days']
approved_days = min(approved_days, days)
if approved_days == 0:
return {'status': 'The time off request is rejected.', 'approved_days': 0}
return {
'status': 'ok',
'approved_days': approved_days,
}
root_agent = Agent(
model='gemini-2.5-flash',
name='time_off_agent',
instruction="""
You are a helpful assistant that can help employees with reimbursement and time off requests.
- Use the `reimburse` tool for reimbursement requests.
- Use the `request_time_off` tool for time off requests.
- Prioritize using tools to fulfill the user's request.
- Always respond to the user with the tool results.
""",
tools=[
# Set require_confirmation to True to require user confirmation for the
# tool call. This is an easier way to get user confirmation if the tool
# just need a boolean confirmation.
FunctionTool(reimburse, require_confirmation=True),
request_time_off,
],
generate_content_config=types.GenerateContentConfig(temperature=0.1),
)
+8
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
from typing import Any
from typing import Optional
from pydantic import alias_generators
@@ -22,6 +23,7 @@ from pydantic import ConfigDict
from pydantic import Field
from ..auth.auth_tool import AuthConfig
from ..tools.tool_confirmation import ToolConfirmation
class EventActions(BaseModel):
@@ -64,3 +66,9 @@ class EventActions(BaseModel):
identify the function call.
- Values: The requested auth config.
"""
requested_tool_confirmations: dict[str, ToolConfirmation] = Field(
default_factory=dict
)
"""A dict of tool confirmation requested by this event, keyed by
function call id."""
@@ -18,3 +18,4 @@ from . import contents
from . import functions
from . import identity
from . import instructions
from . import request_confirmation
@@ -638,6 +638,12 @@ class BaseLlmFlow(ABC):
if auth_event:
yield auth_event
tool_confirmation_event = functions.generate_request_confirmation_event(
invocation_context, function_call_event, function_response_event
)
if tool_confirmation_event:
yield tool_confirmation_event
# Always yield the function response event first
yield function_response_event
+19 -10
View File
@@ -27,6 +27,7 @@ from ...events.event import Event
from ...models.llm_request import LlmRequest
from ._base_llm_processor import BaseLlmRequestProcessor
from .functions import remove_client_function_call_id
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
@@ -238,6 +239,9 @@ def _get_contents(
if _is_auth_event(event):
# Skip auth events.
continue
if _is_request_confirmation_event(event):
# Skip request confirmation events.
continue
filtered_events.append(
_convert_foreign_event(event)
if _is_other_agent_reply(agent_name, event)
@@ -431,18 +435,23 @@ def _is_event_belongs_to_branch(
return invocation_branch.startswith(event.branch)
def _is_auth_event(event: Event) -> bool:
if not event.content.parts:
def _is_function_call_event(event: Event, function_name: str) -> bool:
"""Checks if an event is a function call/response for a given function name."""
if not event.content or not event.content.parts:
return False
for part in event.content.parts:
if (
part.function_call
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
if part.function_call and part.function_call.name == function_name:
return True
if (
part.function_response
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
if part.function_response and part.function_response.name == function_name:
return True
return False
def _is_auth_event(event: Event) -> bool:
"""Checks if the event is an authentication event."""
return _is_function_call_event(event, REQUEST_EUC_FUNCTION_CALL_NAME)
def _is_request_confirmation_event(event: Event) -> bool:
"""Checks if the event is a request confirmation event."""
return _is_function_call_event(event, REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
+74 -2
View File
@@ -39,6 +39,7 @@ from ...telemetry import trace_merged_tool_calls
from ...telemetry import trace_tool_call
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.tool_confirmation import ToolConfirmation
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'
logger = logging.getLogger('google_adk.' + __name__)
@@ -130,11 +132,76 @@ def generate_auth_event(
)
def generate_request_confirmation_event(
invocation_context: InvocationContext,
function_call_event: Event,
function_response_event: Event,
) -> Optional[Event]:
"""Generates a request confirmation event from a function response event."""
if not function_response_event.actions.requested_tool_confirmations:
return None
parts = []
long_running_tool_ids = set()
function_calls = function_call_event.get_function_calls()
for (
function_call_id,
tool_confirmation,
) in function_response_event.actions.requested_tool_confirmations.items():
original_function_call = next(
(fc for fc in function_calls if fc.id == function_call_id), None
)
if not original_function_call:
continue
request_confirmation_function_call = types.FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
'originalFunctionCall': original_function_call.model_dump(
exclude_none=True, by_alias=True
),
'toolConfirmation': tool_confirmation.model_dump(
by_alias=True, exclude_none=True
),
},
)
request_confirmation_function_call.id = generate_client_function_call_id()
long_running_tool_ids.add(request_confirmation_function_call.id)
parts.append(types.Part(function_call=request_confirmation_function_call))
return Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
content=types.Content(
parts=parts, role=function_response_event.content.role
),
long_running_tool_ids=long_running_tool_ids,
)
async def handle_function_calls_async(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
function_calls = function_call_event.get_function_calls()
return await handle_function_call_list_async(
invocation_context,
function_calls,
tools_dict,
filters,
tool_confirmation_dict,
)
async def handle_function_call_list_async(
invocation_context: InvocationContext,
function_calls: list[types.FunctionCall],
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
from ...agents.llm_agent import LlmAgent
@@ -143,8 +210,6 @@ async def handle_function_calls_async(
if not isinstance(agent, LlmAgent):
return None
function_calls = function_call_event.get_function_calls()
# Filter function calls
filtered_calls = [
fc for fc in function_calls if not filters or fc.id in filters
@@ -161,6 +226,9 @@ async def handle_function_calls_async(
function_call,
tools_dict,
agent,
tool_confirmation_dict[function_call.id]
if tool_confirmation_dict
else None,
)
)
for function_call in filtered_calls
@@ -198,12 +266,14 @@ async def _execute_single_function_call_async(
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
agent: LlmAgent,
tool_confirmation: Optional[ToolConfirmation] = None,
) -> Optional[Event]:
"""Execute a single function call with thread safety for state modifications."""
tool, tool_context = _get_tool_and_context(
invocation_context,
function_call,
tools_dict,
tool_confirmation,
)
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
@@ -567,6 +637,7 @@ def _get_tool_and_context(
invocation_context: InvocationContext,
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
tool_confirmation: Optional[ToolConfirmation] = None,
):
if function_call.name not in tools_dict:
raise ValueError(
@@ -576,6 +647,7 @@ def _get_tool_and_context(
tool_context = ToolContext(
invocation_context=invocation_context,
function_call_id=function_call.id,
tool_confirmation=tool_confirmation,
)
tool = tools_dict[function_call.name]
@@ -0,0 +1,169 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import logging
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from . import functions
from ...agents.invocation_context import InvocationContext
from ...agents.readonly_context import ReadonlyContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.tool_confirmation import ToolConfirmation
from ._base_llm_processor import BaseLlmRequestProcessor
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
if TYPE_CHECKING:
from ...agents.llm_agent import LlmAgent
logger = logging.getLogger('google_adk.' + __name__)
class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
"""Handles tool confirmation information to build the LLM request."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
events = invocation_context.session.events
if not events:
return
request_confirmation_function_responses = (
dict()
) # {function call id, tool confirmation}
confirmation_event_index = -1
for k in range(len(events) - 1, -1, -1):
event = events[k]
# Find the first event authored by user
if not event.author or event.author != 'user':
continue
responses = event.get_function_responses()
if not responses:
return
for function_response in responses:
if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
continue
# Find the FunctionResponse event that contains the user provided tool
# confirmation
if (
function_response.response
and len(function_response.response.values()) == 1
and 'response' in function_response.response.keys()
):
# ADK web client will send a request that is always encapted in a
# 'response' key.
tool_confirmation = ToolConfirmation.model_validate(
json.loads(function_response.response['response'])
)
else:
tool_confirmation = ToolConfirmation.model_validate(
function_response.response
)
request_confirmation_function_responses[function_response.id] = (
tool_confirmation
)
confirmation_event_index = k
break
if not request_confirmation_function_responses:
return
for i in range(len(events) - 2, -1, -1):
event = events[i]
# Find the system generated FunctionCall event requesting the tool
# confirmation
function_calls = event.get_function_calls()
if not function_calls:
continue
tools_to_resume_with_confirmation = (
dict()
) # {Function call id, tool confirmation}
tools_to_resume_with_args = dict() # {Function call id, function calls}
for function_call in function_calls:
if (
function_call.id
not in request_confirmation_function_responses.keys()
):
continue
args = function_call.args
if 'originalFunctionCall' not in args:
continue
original_function_call = types.FunctionCall(
**args['originalFunctionCall']
)
tools_to_resume_with_confirmation[original_function_call.id] = (
request_confirmation_function_responses[function_call.id]
)
tools_to_resume_with_args[original_function_call.id] = (
original_function_call
)
if not tools_to_resume_with_confirmation:
continue
# Remove the tools that have already been confirmed.
for i in range(len(events) - 1, confirmation_event_index, -1):
event = events[i]
function_response = event.get_function_responses()
if not function_response:
continue
for function_response in event.get_function_responses():
if function_response.id in tools_to_resume_with_confirmation:
tools_to_resume_with_confirmation.pop(function_response.id)
tools_to_resume_with_args.pop(function_response.id)
if not tools_to_resume_with_confirmation:
break
if not tools_to_resume_with_confirmation:
continue
if function_response_event := await functions.handle_function_call_list_async(
invocation_context,
tools_to_resume_with_args.values(),
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# There could be parallel function calls that require input
# response would be a dict keyed by function call id
tools_to_resume_with_confirmation.keys(),
tools_to_resume_with_confirmation,
):
yield function_response_event
return
request_processor = _RequestConfirmationLlmRequestProcessor()
@@ -25,6 +25,7 @@ from . import basic
from . import contents
from . import identity
from . import instructions
from . import request_confirmation
from ...auth import auth_preprocessor
from .base_llm_flow import BaseLlmFlow
@@ -43,6 +44,7 @@ class SingleFlow(BaseLlmFlow):
self.request_processors += [
basic.request_processor,
auth_preprocessor.request_processor,
request_confirmation.request_processor,
instructions.request_processor,
identity.request_processor,
contents.request_processor,
+39 -2
View File
@@ -15,6 +15,7 @@
from __future__ import annotations
import inspect
import logging
from typing import Any
from typing import Callable
from typing import Optional
@@ -25,8 +26,11 @@ from typing_extensions import override
from ..utils.context_utils import Aclosing
from ._automatic_function_calling_util import build_function_declaration
from .base_tool import BaseTool
from .tool_confirmation import ToolConfirmation
from .tool_context import ToolContext
logger = logging.getLogger('google_adk.' + __name__)
class FunctionTool(BaseTool):
"""A tool that wraps a user-defined Python function.
@@ -35,8 +39,15 @@ class FunctionTool(BaseTool):
func: The function to wrap.
"""
def __init__(self, func: Callable[..., Any]):
"""Extract metadata from a callable object."""
def __init__(
self, func: Callable[..., Any], *, require_confirmation: bool = False
):
"""Initializes the FunctionTool. Extracts metadata from a callable object.
Args:
func: The function to wrap.
require_confirmation: Whether the tool call requires user confirmation.
"""
name = ''
doc = ''
# Handle different types of callables
@@ -61,6 +72,7 @@ class FunctionTool(BaseTool):
super().__init__(name=name, description=doc)
self.func = func
self._ignore_params = ['tool_context', 'input_stream']
self._require_confirmation = require_confirmation
@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
@@ -106,6 +118,29 @@ class FunctionTool(BaseTool):
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}
if self._require_confirmation:
if not tool_context.tool_confirmation:
args_to_show = args_to_call.copy()
if 'tool_context' in args_to_show:
args_to_show.pop('tool_context')
tool_context.request_confirmation(
hint=(
f'Please approve or reject the tool call {self.name}() by'
' responding with a FunctionResponse with an expected'
' ToolConfirmation payload.'
),
)
return {
'error': (
'This tool call requires confirmation, please approve or'
' reject.'
)
}
else:
if not tool_context.tool_confirmation.confirmed:
return {'error': 'This tool call is rejected.'}
# Functions are callable objects, but not all callable objects are functions
# checking coroutine function is not enough. We also need to check whether
# Callable's __call__ function is a coroutine funciton
@@ -137,6 +172,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th
].stream
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
# TODO: support tool confirmation for live mode.
async with Aclosing(self.func(**args_to_call)) as agen:
async for item in agen:
yield item
+45
View File
@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Optional
from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from ..utils.feature_decorator import experimental
@experimental
class ToolConfirmation(BaseModel):
"""Represents a tool confirmation configuration."""
model_config = ConfigDict(
extra="forbid",
alias_generator=alias_generators.to_camel,
populate_by_name=True,
)
"""The pydantic model config."""
hint: str = ""
"""The hint text for why the input is needed."""
confirmed: bool = False
"""Whether the tool excution is confirmed."""
payload: Optional[Any] = None
"""The custom data payload needed from the user to continue the flow.
It should be JSON serializable."""
+26
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
@@ -21,6 +22,7 @@ from ..agents.callback_context import CallbackContext
from ..auth.auth_credential import AuthCredential
from ..auth.auth_handler import AuthHandler
from ..auth.auth_tool import AuthConfig
from .tool_confirmation import ToolConfirmation
if TYPE_CHECKING:
from ..agents.invocation_context import InvocationContext
@@ -43,6 +45,7 @@ class ToolContext(CallbackContext):
If LLM didn't return this id, ADK will assign one to it. This id is used
to map function call response to the original function call.
event_actions: The event actions of the current tool call.
tool_confirmation: The tool confirmation of the current tool call.
"""
def __init__(
@@ -51,9 +54,11 @@ class ToolContext(CallbackContext):
*,
function_call_id: Optional[str] = None,
event_actions: Optional[EventActions] = None,
tool_confirmation: Optional[ToolConfirmation] = None,
):
super().__init__(invocation_context, event_actions=event_actions)
self.function_call_id = function_call_id
self.tool_confirmation = tool_confirmation
@property
def actions(self) -> EventActions:
@@ -69,6 +74,27 @@ class ToolContext(CallbackContext):
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
return AuthHandler(auth_config).get_auth_response(self.state)
def request_confirmation(
self,
*,
hint: Optional[str] = None,
payload: Optional[Any] = None,
) -> None:
"""Requests confirmation for the given function call.
Args:
hint: A hint to the user on how to confirm the tool call.
payload: The payload used to confirm the tool call.
"""
if not self.function_call_id:
raise ValueError('function_call_id is not set.')
self._event_actions.requested_tool_confirmations[self.function_call_id] = (
ToolConfirmation(
hint=hint,
payload=payload,
)
)
async def search_memory(self, query: str) -> SearchMemoryResponse:
"""Searches the memory of the current user."""
if self._invocation_context.memory_service is None:
@@ -162,6 +162,60 @@ def test_get_contents_filters_empty_events():
assert contents_result[0].parts[0].text == "Hello"
def test_get_contents_filters_auth_and_confirmation_events():
"""Test _get_contents filters out auth and request confirmation events."""
auth_event = Event(
invocation_id="test_inv",
author="agent",
content=types.Content(
role="model",
parts=[
types.Part(
function_call=types.FunctionCall(
id="auth_func",
name=contents.REQUEST_EUC_FUNCTION_CALL_NAME,
args={},
)
)
],
),
)
confirmation_event = Event(
invocation_id="test_inv",
author="agent",
content=types.Content(
role="model",
parts=[
types.Part(
function_call=types.FunctionResponse(
id="confirm_func",
name=contents.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
response={
"confirmed": True,
},
)
)
],
),
)
valid_event = Event(
invocation_id="test_inv",
author="user",
content=types.Content(
role="user", parts=[types.Part.from_text(text="Hello")]
),
)
contents_result = _get_contents(
None, [auth_event, confirmation_event, valid_event], "test_agent"
)
assert len(contents_result) == 1
assert contents_result[0].role == "user"
assert contents_result[0].parts[0].text == "Hello"
def test_convert_foreign_event():
"""Test _convert_foreign_event function."""
agent_event = Event(
@@ -0,0 +1,302 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest.mock import patch
from google.adk.agents.llm_agent import LlmAgent
from google.adk.events.event import Event
from google.adk.flows.llm_flows import functions
from google.adk.flows.llm_flows.request_confirmation import request_processor
from google.adk.models.llm_request import LlmRequest
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.genai import types
import pytest
from ... import testing_utils
MOCK_TOOL_NAME = "mock_tool"
MOCK_FUNCTION_CALL_ID = "mock_function_call_id"
MOCK_CONFIRMATION_FUNCTION_CALL_ID = "mock_confirmation_function_call_id"
def mock_tool(param1: str):
"""Mock tool function."""
return f"Mock tool result with {param1}"
@pytest.mark.asyncio
async def test_request_confirmation_processor_no_events():
"""Test that the processor returns None when there are no events."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()
events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)
assert not events
@pytest.mark.asyncio
async def test_request_confirmation_processor_no_function_responses():
"""Test that the processor returns None when the user event has no function responses."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()
invocation_context.session.events.append(
Event(author="user", content=types.Content())
)
events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)
assert not events
@pytest.mark.asyncio
async def test_request_confirmation_processor_no_confirmation_function_response():
"""Test that the processor returns None when no confirmation function response is present."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()
invocation_context.session.events.append(
Event(
author="user",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name="other_function", response={}
)
)
]
),
)
)
events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)
assert not events
@pytest.mark.asyncio
async def test_request_confirmation_processor_success():
"""Test the successful processing of a tool confirmation."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()
original_function_call = types.FunctionCall(
name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID
)
tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint")
tool_confirmation_args = {
"originalFunctionCall": original_function_call.model_dump(
exclude_none=True, by_alias=True
),
"toolConfirmation": tool_confirmation.model_dump(
by_alias=True, exclude_none=True
),
}
# Event with the request for confirmation
invocation_context.session.events.append(
Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args=tool_confirmation_args,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
)
)
]
),
)
)
# Event with the user's confirmation
user_confirmation = ToolConfirmation(confirmed=True)
invocation_context.session.events.append(
Event(
author="user",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
response={
"response": user_confirmation.model_dump_json()
},
)
)
]
),
)
)
expected_event = Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=MOCK_TOOL_NAME,
id=MOCK_FUNCTION_CALL_ID,
response={"result": "Mock tool result with test"},
)
)
]
),
)
with patch(
"google.adk.flows.llm_flows.functions.handle_function_call_list_async"
) as mock_handle_function_call_list_async:
mock_handle_function_call_list_async.return_value = expected_event
events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)
assert len(events) == 1
assert events[0] == expected_event
mock_handle_function_call_list_async.assert_called_once()
args, _ = mock_handle_function_call_list_async.call_args
assert list(args[1]) == [original_function_call] # function_calls
assert args[3] == {MOCK_FUNCTION_CALL_ID} # tools_to_confirm
assert (
args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation
) # tool_confirmation_dict
@pytest.mark.asyncio
async def test_request_confirmation_processor_tool_not_confirmed():
"""Test when the tool execution is not confirmed by the user."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()
original_function_call = types.FunctionCall(
name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID
)
tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint")
tool_confirmation_args = {
"originalFunctionCall": original_function_call.model_dump(
exclude_none=True, by_alias=True
),
"toolConfirmation": tool_confirmation.model_dump(
by_alias=True, exclude_none=True
),
}
invocation_context.session.events.append(
Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args=tool_confirmation_args,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
)
)
]
),
)
)
user_confirmation = ToolConfirmation(confirmed=False)
invocation_context.session.events.append(
Event(
author="user",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
response={
"response": user_confirmation.model_dump_json()
},
)
)
]
),
)
)
with patch(
"google.adk.flows.llm_flows.functions.handle_function_call_list_async"
) as mock_handle_function_call_list_async:
mock_handle_function_call_list_async.return_value = Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=MOCK_TOOL_NAME,
id=MOCK_FUNCTION_CALL_ID,
response={"error": "Tool execution not confirmed"},
)
)
]
),
)
events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)
assert len(events) == 1
mock_handle_function_call_list_async.assert_called_once()
args, _ = mock_handle_function_call_list_async.call_args
assert (
args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation
) # tool_confirmation_dict
@@ -17,6 +17,7 @@ from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.adk.tools.tool_context import ToolContext
import pytest
@@ -345,3 +346,51 @@ async def test_run_async_with_tool_context_and_unexpected_argument():
"received_arg": "world",
"context_present": True,
}
@pytest.mark.asyncio
async def test_run_async_with_require_confirmation():
"""Test that run_async handles require_confirmation flag."""
def sample_func(arg1: str):
return {"received_arg": arg1}
tool = FunctionTool(sample_func, require_confirmation=True)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
mock_invocation_context.agent = MagicMock()
mock_invocation_context.agent.name = "test_agent"
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
tool_context_mock.function_call_id = "test_function_call_id"
# First call, should request confirmation
result = await tool.run_async(
args={"arg1": "hello"},
tool_context=tool_context_mock,
)
assert result == {
"error": "This tool call requires confirmation, please approve or reject."
}
assert tool_context_mock._event_actions.requested_tool_confirmations[
"test_function_call_id"
].hint == (
"Please approve or reject the tool call sample_func() by responding with"
" a FunctionResponse with an expected ToolConfirmation payload."
)
# Second call, user rejects
tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=False)
result = await tool.run_async(
args={"arg1": "hello"},
tool_context=tool_context_mock,
)
assert result == {"error": "This tool call is rejected."}
# Third call, user approves
tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=True)
result = await tool.run_async(
args={"arg1": "hello"},
tool_context=tool_context_mock,
)
assert result == {"received_arg": "hello"}