You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add a tool confirmation flow that can guard tool execution with explicit confirmation and custom input
The existing `LongRunningTool` does not define a programmatic way to provide & validate structured input, also it relies on LLM to reason and parse the user's response. For a quick start, annotate the function with `FunctionTool(my_function, require_confirmation=True)`. A more advanced flow is shown in the `human_tool_confirmation` sample. The new flow is similar to the existing Auth flow: - User request a tool confirmation by calling `tool_context.request_confirmation()` in the tool or `before_tool_callback`, or just using the `require_confirmation` shortcut in FunctionTool. - User can provide custom validation logic before tool call proceeds. - ADK creates corresponding RequestConfirmation FunctionCall Event to ask user for confirmation - User needs to provide the expected tool confirmation to a RequestConfirmation FunctionResponse Event. - ADK then checks the response and continues the tool call. PiperOrigin-RevId: 801019917
This commit is contained in:
committed by
Copybara-Service
parent
3ed9097983
commit
a17bcbb2aa
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_confirmation import ToolConfirmation
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def reimburse(amount: int, tool_context: ToolContext) -> str:
|
||||
"""Reimburse the employee for the given amount."""
|
||||
return {'status': 'ok'}
|
||||
|
||||
|
||||
def request_time_off(days: int, tool_context: ToolContext):
|
||||
"""Request day off for the employee."""
|
||||
if days <= 0:
|
||||
return {'status': 'Invalid days to request.'}
|
||||
|
||||
if days <= 2:
|
||||
return {
|
||||
'status': 'ok',
|
||||
'approved_days': days,
|
||||
}
|
||||
|
||||
tool_confirmation = tool_context.tool_confirmation
|
||||
if not tool_confirmation:
|
||||
tool_context.request_confirmation(
|
||||
hint=(
|
||||
'Please approve or reject the tool call request_time_off() by'
|
||||
' responding with a FunctionResponse with an expected'
|
||||
' ToolConfirmation payload.'
|
||||
),
|
||||
payload={
|
||||
'approved_days': 0,
|
||||
},
|
||||
)
|
||||
return {'status': 'Manager approval is required.'}
|
||||
|
||||
approved_days = tool_confirmation.payload['approved_days']
|
||||
approved_days = min(approved_days, days)
|
||||
if approved_days == 0:
|
||||
return {'status': 'The time off request is rejected.', 'approved_days': 0}
|
||||
return {
|
||||
'status': 'ok',
|
||||
'approved_days': approved_days,
|
||||
}
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.5-flash',
|
||||
name='time_off_agent',
|
||||
instruction="""
|
||||
You are a helpful assistant that can help employees with reimbursement and time off requests.
|
||||
- Use the `reimburse` tool for reimbursement requests.
|
||||
- Use the `request_time_off` tool for time off requests.
|
||||
- Prioritize using tools to fulfill the user's request.
|
||||
- Always respond to the user with the tool results.
|
||||
""",
|
||||
tools=[
|
||||
# Set require_confirmation to True to require user confirmation for the
|
||||
# tool call. This is an easier way to get user confirmation if the tool
|
||||
# just need a boolean confirmation.
|
||||
FunctionTool(reimburse, require_confirmation=True),
|
||||
request_time_off,
|
||||
],
|
||||
generate_content_config=types.GenerateContentConfig(temperature=0.1),
|
||||
)
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import alias_generators
|
||||
@@ -22,6 +23,7 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
from ..tools.tool_confirmation import ToolConfirmation
|
||||
|
||||
|
||||
class EventActions(BaseModel):
|
||||
@@ -64,3 +66,9 @@ class EventActions(BaseModel):
|
||||
identify the function call.
|
||||
- Values: The requested auth config.
|
||||
"""
|
||||
|
||||
requested_tool_confirmations: dict[str, ToolConfirmation] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
"""A dict of tool confirmation requested by this event, keyed by
|
||||
function call id."""
|
||||
|
||||
@@ -18,3 +18,4 @@ from . import contents
|
||||
from . import functions
|
||||
from . import identity
|
||||
from . import instructions
|
||||
from . import request_confirmation
|
||||
|
||||
@@ -638,6 +638,12 @@ class BaseLlmFlow(ABC):
|
||||
if auth_event:
|
||||
yield auth_event
|
||||
|
||||
tool_confirmation_event = functions.generate_request_confirmation_event(
|
||||
invocation_context, function_call_event, function_response_event
|
||||
)
|
||||
if tool_confirmation_event:
|
||||
yield tool_confirmation_event
|
||||
|
||||
# Always yield the function response event first
|
||||
yield function_response_event
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from .functions import remove_client_function_call_id
|
||||
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
|
||||
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
|
||||
|
||||
@@ -238,6 +239,9 @@ def _get_contents(
|
||||
if _is_auth_event(event):
|
||||
# Skip auth events.
|
||||
continue
|
||||
if _is_request_confirmation_event(event):
|
||||
# Skip request confirmation events.
|
||||
continue
|
||||
filtered_events.append(
|
||||
_convert_foreign_event(event)
|
||||
if _is_other_agent_reply(agent_name, event)
|
||||
@@ -431,18 +435,23 @@ def _is_event_belongs_to_branch(
|
||||
return invocation_branch.startswith(event.branch)
|
||||
|
||||
|
||||
def _is_auth_event(event: Event) -> bool:
|
||||
if not event.content.parts:
|
||||
def _is_function_call_event(event: Event, function_name: str) -> bool:
|
||||
"""Checks if an event is a function call/response for a given function name."""
|
||||
if not event.content or not event.content.parts:
|
||||
return False
|
||||
for part in event.content.parts:
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
if part.function_call and part.function_call.name == function_name:
|
||||
return True
|
||||
if (
|
||||
part.function_response
|
||||
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
if part.function_response and part.function_response.name == function_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_auth_event(event: Event) -> bool:
|
||||
"""Checks if the event is an authentication event."""
|
||||
return _is_function_call_event(event, REQUEST_EUC_FUNCTION_CALL_NAME)
|
||||
|
||||
|
||||
def _is_request_confirmation_event(event: Event) -> bool:
|
||||
"""Checks if the event is a request confirmation event."""
|
||||
return _is_function_call_event(event, REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
|
||||
|
||||
@@ -39,6 +39,7 @@ from ...telemetry import trace_merged_tool_calls
|
||||
from ...telemetry import trace_tool_call
|
||||
from ...telemetry import tracer
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_confirmation import ToolConfirmation
|
||||
from ...tools.tool_context import ToolContext
|
||||
from ...utils.context_utils import Aclosing
|
||||
|
||||
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
|
||||
|
||||
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
|
||||
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
|
||||
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
@@ -130,11 +132,76 @@ def generate_auth_event(
|
||||
)
|
||||
|
||||
|
||||
def generate_request_confirmation_event(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
function_response_event: Event,
|
||||
) -> Optional[Event]:
|
||||
"""Generates a request confirmation event from a function response event."""
|
||||
if not function_response_event.actions.requested_tool_confirmations:
|
||||
return None
|
||||
parts = []
|
||||
long_running_tool_ids = set()
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
for (
|
||||
function_call_id,
|
||||
tool_confirmation,
|
||||
) in function_response_event.actions.requested_tool_confirmations.items():
|
||||
original_function_call = next(
|
||||
(fc for fc in function_calls if fc.id == function_call_id), None
|
||||
)
|
||||
if not original_function_call:
|
||||
continue
|
||||
request_confirmation_function_call = types.FunctionCall(
|
||||
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
args={
|
||||
'originalFunctionCall': original_function_call.model_dump(
|
||||
exclude_none=True, by_alias=True
|
||||
),
|
||||
'toolConfirmation': tool_confirmation.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
},
|
||||
)
|
||||
request_confirmation_function_call.id = generate_client_function_call_id()
|
||||
long_running_tool_ids.add(request_confirmation_function_call.id)
|
||||
parts.append(types.Part(function_call=request_confirmation_function_call))
|
||||
|
||||
return Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=types.Content(
|
||||
parts=parts, role=function_response_event.content.role
|
||||
),
|
||||
long_running_tool_ids=long_running_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
async def handle_function_calls_async(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
filters: Optional[set[str]] = None,
|
||||
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
return await handle_function_call_list_async(
|
||||
invocation_context,
|
||||
function_calls,
|
||||
tools_dict,
|
||||
filters,
|
||||
tool_confirmation_dict,
|
||||
)
|
||||
|
||||
|
||||
async def handle_function_call_list_async(
|
||||
invocation_context: InvocationContext,
|
||||
function_calls: list[types.FunctionCall],
|
||||
tools_dict: dict[str, BaseTool],
|
||||
filters: Optional[set[str]] = None,
|
||||
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
@@ -143,8 +210,6 @@ async def handle_function_calls_async(
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return None
|
||||
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
|
||||
# Filter function calls
|
||||
filtered_calls = [
|
||||
fc for fc in function_calls if not filters or fc.id in filters
|
||||
@@ -161,6 +226,9 @@ async def handle_function_calls_async(
|
||||
function_call,
|
||||
tools_dict,
|
||||
agent,
|
||||
tool_confirmation_dict[function_call.id]
|
||||
if tool_confirmation_dict
|
||||
else None,
|
||||
)
|
||||
)
|
||||
for function_call in filtered_calls
|
||||
@@ -198,12 +266,14 @@ async def _execute_single_function_call_async(
|
||||
function_call: types.FunctionCall,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
agent: LlmAgent,
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Execute a single function call with thread safety for state modifications."""
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context,
|
||||
function_call,
|
||||
tools_dict,
|
||||
tool_confirmation,
|
||||
)
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
@@ -567,6 +637,7 @@ def _get_tool_and_context(
|
||||
invocation_context: InvocationContext,
|
||||
function_call: types.FunctionCall,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
):
|
||||
if function_call.name not in tools_dict:
|
||||
raise ValueError(
|
||||
@@ -576,6 +647,7 @@ def _get_tool_and_context(
|
||||
tool_context = ToolContext(
|
||||
invocation_context=invocation_context,
|
||||
function_call_id=function_call.id,
|
||||
tool_confirmation=tool_confirmation,
|
||||
)
|
||||
|
||||
tool = tools_dict[function_call.name]
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from . import functions
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...tools.tool_confirmation import ToolConfirmation
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles tool confirmation information to build the LLM request."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
events = invocation_context.session.events
|
||||
if not events:
|
||||
return
|
||||
|
||||
request_confirmation_function_responses = (
|
||||
dict()
|
||||
) # {function call id, tool confirmation}
|
||||
|
||||
confirmation_event_index = -1
|
||||
for k in range(len(events) - 1, -1, -1):
|
||||
event = events[k]
|
||||
# Find the first event authored by user
|
||||
if not event.author or event.author != 'user':
|
||||
continue
|
||||
responses = event.get_function_responses()
|
||||
if not responses:
|
||||
return
|
||||
|
||||
for function_response in responses:
|
||||
if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
|
||||
# Find the FunctionResponse event that contains the user provided tool
|
||||
# confirmation
|
||||
if (
|
||||
function_response.response
|
||||
and len(function_response.response.values()) == 1
|
||||
and 'response' in function_response.response.keys()
|
||||
):
|
||||
# ADK web client will send a request that is always encapted in a
|
||||
# 'response' key.
|
||||
tool_confirmation = ToolConfirmation.model_validate(
|
||||
json.loads(function_response.response['response'])
|
||||
)
|
||||
else:
|
||||
tool_confirmation = ToolConfirmation.model_validate(
|
||||
function_response.response
|
||||
)
|
||||
request_confirmation_function_responses[function_response.id] = (
|
||||
tool_confirmation
|
||||
)
|
||||
confirmation_event_index = k
|
||||
break
|
||||
|
||||
if not request_confirmation_function_responses:
|
||||
return
|
||||
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# Find the system generated FunctionCall event requesting the tool
|
||||
# confirmation
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
tools_to_resume_with_confirmation = (
|
||||
dict()
|
||||
) # {Function call id, tool confirmation}
|
||||
tools_to_resume_with_args = dict() # {Function call id, function calls}
|
||||
|
||||
for function_call in function_calls:
|
||||
if (
|
||||
function_call.id
|
||||
not in request_confirmation_function_responses.keys()
|
||||
):
|
||||
continue
|
||||
|
||||
args = function_call.args
|
||||
if 'originalFunctionCall' not in args:
|
||||
continue
|
||||
original_function_call = types.FunctionCall(
|
||||
**args['originalFunctionCall']
|
||||
)
|
||||
tools_to_resume_with_confirmation[original_function_call.id] = (
|
||||
request_confirmation_function_responses[function_call.id]
|
||||
)
|
||||
tools_to_resume_with_args[original_function_call.id] = (
|
||||
original_function_call
|
||||
)
|
||||
if not tools_to_resume_with_confirmation:
|
||||
continue
|
||||
|
||||
# Remove the tools that have already been confirmed.
|
||||
for i in range(len(events) - 1, confirmation_event_index, -1):
|
||||
event = events[i]
|
||||
function_response = event.get_function_responses()
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
for function_response in event.get_function_responses():
|
||||
if function_response.id in tools_to_resume_with_confirmation:
|
||||
tools_to_resume_with_confirmation.pop(function_response.id)
|
||||
tools_to_resume_with_args.pop(function_response.id)
|
||||
if not tools_to_resume_with_confirmation:
|
||||
break
|
||||
|
||||
if not tools_to_resume_with_confirmation:
|
||||
continue
|
||||
|
||||
if function_response_event := await functions.handle_function_call_list_async(
|
||||
invocation_context,
|
||||
tools_to_resume_with_args.values(),
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# There could be parallel function calls that require input
|
||||
# response would be a dict keyed by function call id
|
||||
tools_to_resume_with_confirmation.keys(),
|
||||
tools_to_resume_with_confirmation,
|
||||
):
|
||||
yield function_response_event
|
||||
return
|
||||
|
||||
|
||||
request_processor = _RequestConfirmationLlmRequestProcessor()
|
||||
@@ -25,6 +25,7 @@ from . import basic
|
||||
from . import contents
|
||||
from . import identity
|
||||
from . import instructions
|
||||
from . import request_confirmation
|
||||
from ...auth import auth_preprocessor
|
||||
from .base_llm_flow import BaseLlmFlow
|
||||
|
||||
@@ -43,6 +44,7 @@ class SingleFlow(BaseLlmFlow):
|
||||
self.request_processors += [
|
||||
basic.request_processor,
|
||||
auth_preprocessor.request_processor,
|
||||
request_confirmation.request_processor,
|
||||
instructions.request_processor,
|
||||
identity.request_processor,
|
||||
contents.request_processor,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
@@ -25,8 +26,11 @@ from typing_extensions import override
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ._automatic_function_calling_util import build_function_declaration
|
||||
from .base_tool import BaseTool
|
||||
from .tool_confirmation import ToolConfirmation
|
||||
from .tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class FunctionTool(BaseTool):
|
||||
"""A tool that wraps a user-defined Python function.
|
||||
@@ -35,8 +39,15 @@ class FunctionTool(BaseTool):
|
||||
func: The function to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[..., Any]):
|
||||
"""Extract metadata from a callable object."""
|
||||
def __init__(
|
||||
self, func: Callable[..., Any], *, require_confirmation: bool = False
|
||||
):
|
||||
"""Initializes the FunctionTool. Extracts metadata from a callable object.
|
||||
|
||||
Args:
|
||||
func: The function to wrap.
|
||||
require_confirmation: Whether the tool call requires user confirmation.
|
||||
"""
|
||||
name = ''
|
||||
doc = ''
|
||||
# Handle different types of callables
|
||||
@@ -61,6 +72,7 @@ class FunctionTool(BaseTool):
|
||||
super().__init__(name=name, description=doc)
|
||||
self.func = func
|
||||
self._ignore_params = ['tool_context', 'input_stream']
|
||||
self._require_confirmation = require_confirmation
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
@@ -106,6 +118,29 @@ class FunctionTool(BaseTool):
|
||||
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||
return {'error': error_str}
|
||||
|
||||
if self._require_confirmation:
|
||||
if not tool_context.tool_confirmation:
|
||||
args_to_show = args_to_call.copy()
|
||||
if 'tool_context' in args_to_show:
|
||||
args_to_show.pop('tool_context')
|
||||
|
||||
tool_context.request_confirmation(
|
||||
hint=(
|
||||
f'Please approve or reject the tool call {self.name}() by'
|
||||
' responding with a FunctionResponse with an expected'
|
||||
' ToolConfirmation payload.'
|
||||
),
|
||||
)
|
||||
return {
|
||||
'error': (
|
||||
'This tool call requires confirmation, please approve or'
|
||||
' reject.'
|
||||
)
|
||||
}
|
||||
else:
|
||||
if not tool_context.tool_confirmation.confirmed:
|
||||
return {'error': 'This tool call is rejected.'}
|
||||
|
||||
# Functions are callable objects, but not all callable objects are functions
|
||||
# checking coroutine function is not enough. We also need to check whether
|
||||
# Callable's __call__ function is a coroutine funciton
|
||||
@@ -137,6 +172,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th
|
||||
].stream
|
||||
if 'tool_context' in signature.parameters:
|
||||
args_to_call['tool_context'] = tool_context
|
||||
|
||||
# TODO: support tool confirmation for live mode.
|
||||
async with Aclosing(self.func(**args_to_call)) as agen:
|
||||
async for item in agen:
|
||||
yield item
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import alias_generators
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..utils.feature_decorator import experimental
|
||||
|
||||
|
||||
@experimental
|
||||
class ToolConfirmation(BaseModel):
|
||||
"""Represents a tool confirmation configuration."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
alias_generator=alias_generators.to_camel,
|
||||
populate_by_name=True,
|
||||
)
|
||||
"""The pydantic model config."""
|
||||
|
||||
hint: str = ""
|
||||
"""The hint text for why the input is needed."""
|
||||
confirmed: bool = False
|
||||
"""Whether the tool excution is confirmed."""
|
||||
payload: Optional[Any] = None
|
||||
"""The custom data payload needed from the user to continue the flow.
|
||||
It should be JSON serializable."""
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -21,6 +22,7 @@ from ..agents.callback_context import CallbackContext
|
||||
from ..auth.auth_credential import AuthCredential
|
||||
from ..auth.auth_handler import AuthHandler
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
from .tool_confirmation import ToolConfirmation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
@@ -43,6 +45,7 @@ class ToolContext(CallbackContext):
|
||||
If LLM didn't return this id, ADK will assign one to it. This id is used
|
||||
to map function call response to the original function call.
|
||||
event_actions: The event actions of the current tool call.
|
||||
tool_confirmation: The tool confirmation of the current tool call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -51,9 +54,11 @@ class ToolContext(CallbackContext):
|
||||
*,
|
||||
function_call_id: Optional[str] = None,
|
||||
event_actions: Optional[EventActions] = None,
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
):
|
||||
super().__init__(invocation_context, event_actions=event_actions)
|
||||
self.function_call_id = function_call_id
|
||||
self.tool_confirmation = tool_confirmation
|
||||
|
||||
@property
|
||||
def actions(self) -> EventActions:
|
||||
@@ -69,6 +74,27 @@ class ToolContext(CallbackContext):
|
||||
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
|
||||
return AuthHandler(auth_config).get_auth_response(self.state)
|
||||
|
||||
def request_confirmation(
|
||||
self,
|
||||
*,
|
||||
hint: Optional[str] = None,
|
||||
payload: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Requests confirmation for the given function call.
|
||||
|
||||
Args:
|
||||
hint: A hint to the user on how to confirm the tool call.
|
||||
payload: The payload used to confirm the tool call.
|
||||
"""
|
||||
if not self.function_call_id:
|
||||
raise ValueError('function_call_id is not set.')
|
||||
self._event_actions.requested_tool_confirmations[self.function_call_id] = (
|
||||
ToolConfirmation(
|
||||
hint=hint,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
|
||||
async def search_memory(self, query: str) -> SearchMemoryResponse:
|
||||
"""Searches the memory of the current user."""
|
||||
if self._invocation_context.memory_service is None:
|
||||
|
||||
@@ -162,6 +162,60 @@ def test_get_contents_filters_empty_events():
|
||||
assert contents_result[0].parts[0].text == "Hello"
|
||||
|
||||
|
||||
def test_get_contents_filters_auth_and_confirmation_events():
|
||||
"""Test _get_contents filters out auth and request confirmation events."""
|
||||
auth_event = Event(
|
||||
invocation_id="test_inv",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
id="auth_func",
|
||||
name=contents.REQUEST_EUC_FUNCTION_CALL_NAME,
|
||||
args={},
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
confirmation_event = Event(
|
||||
invocation_id="test_inv",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionResponse(
|
||||
id="confirm_func",
|
||||
name=contents.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
response={
|
||||
"confirmed": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
valid_event = Event(
|
||||
invocation_id="test_inv",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user", parts=[types.Part.from_text(text="Hello")]
|
||||
),
|
||||
)
|
||||
|
||||
contents_result = _get_contents(
|
||||
None, [auth_event, confirmation_event, valid_event], "test_agent"
|
||||
)
|
||||
assert len(contents_result) == 1
|
||||
assert contents_result[0].role == "user"
|
||||
assert contents_result[0].parts[0].text == "Hello"
|
||||
|
||||
|
||||
def test_convert_foreign_event():
|
||||
"""Test _convert_foreign_event function."""
|
||||
agent_event = Event(
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows import functions
|
||||
from google.adk.flows.llm_flows.request_confirmation import request_processor
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.tools.tool_confirmation import ToolConfirmation
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import testing_utils
|
||||
|
||||
MOCK_TOOL_NAME = "mock_tool"
|
||||
MOCK_FUNCTION_CALL_ID = "mock_function_call_id"
|
||||
MOCK_CONFIRMATION_FUNCTION_CALL_ID = "mock_confirmation_function_call_id"
|
||||
|
||||
|
||||
def mock_tool(param1: str):
|
||||
"""Mock tool function."""
|
||||
return f"Mock tool result with {param1}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_confirmation_processor_no_events():
|
||||
"""Test that the processor returns None when there are no events."""
|
||||
agent = LlmAgent(name="test_agent", tools=[mock_tool])
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
llm_request = LlmRequest()
|
||||
|
||||
events = []
|
||||
async for event in request_processor.run_async(
|
||||
invocation_context, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert not events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_confirmation_processor_no_function_responses():
|
||||
"""Test that the processor returns None when the user event has no function responses."""
|
||||
agent = LlmAgent(name="test_agent", tools=[mock_tool])
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
llm_request = LlmRequest()
|
||||
|
||||
invocation_context.session.events.append(
|
||||
Event(author="user", content=types.Content())
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in request_processor.run_async(
|
||||
invocation_context, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert not events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_confirmation_processor_no_confirmation_function_response():
|
||||
"""Test that the processor returns None when no confirmation function response is present."""
|
||||
agent = LlmAgent(name="test_agent", tools=[mock_tool])
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
llm_request = LlmRequest()
|
||||
|
||||
invocation_context.session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name="other_function", response={}
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in request_processor.run_async(
|
||||
invocation_context, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert not events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_confirmation_processor_success():
|
||||
"""Test the successful processing of a tool confirmation."""
|
||||
agent = LlmAgent(name="test_agent", tools=[mock_tool])
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
llm_request = LlmRequest()
|
||||
|
||||
original_function_call = types.FunctionCall(
|
||||
name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID
|
||||
)
|
||||
|
||||
tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint")
|
||||
tool_confirmation_args = {
|
||||
"originalFunctionCall": original_function_call.model_dump(
|
||||
exclude_none=True, by_alias=True
|
||||
),
|
||||
"toolConfirmation": tool_confirmation.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
# Event with the request for confirmation
|
||||
invocation_context.session.events.append(
|
||||
Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
args=tool_confirmation_args,
|
||||
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Event with the user's confirmation
|
||||
user_confirmation = ToolConfirmation(confirmed=True)
|
||||
invocation_context.session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
|
||||
response={
|
||||
"response": user_confirmation.model_dump_json()
|
||||
},
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
expected_event = Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name=MOCK_TOOL_NAME,
|
||||
id=MOCK_FUNCTION_CALL_ID,
|
||||
response={"result": "Mock tool result with test"},
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.flows.llm_flows.functions.handle_function_call_list_async"
|
||||
) as mock_handle_function_call_list_async:
|
||||
mock_handle_function_call_list_async.return_value = expected_event
|
||||
|
||||
events = []
|
||||
async for event in request_processor.run_async(
|
||||
invocation_context, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0] == expected_event
|
||||
|
||||
mock_handle_function_call_list_async.assert_called_once()
|
||||
args, _ = mock_handle_function_call_list_async.call_args
|
||||
|
||||
assert list(args[1]) == [original_function_call] # function_calls
|
||||
assert args[3] == {MOCK_FUNCTION_CALL_ID} # tools_to_confirm
|
||||
assert (
|
||||
args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation
|
||||
) # tool_confirmation_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_confirmation_processor_tool_not_confirmed():
|
||||
"""Test when the tool execution is not confirmed by the user."""
|
||||
agent = LlmAgent(name="test_agent", tools=[mock_tool])
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
llm_request = LlmRequest()
|
||||
|
||||
original_function_call = types.FunctionCall(
|
||||
name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID
|
||||
)
|
||||
|
||||
tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint")
|
||||
tool_confirmation_args = {
|
||||
"originalFunctionCall": original_function_call.model_dump(
|
||||
exclude_none=True, by_alias=True
|
||||
),
|
||||
"toolConfirmation": tool_confirmation.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
|
||||
invocation_context.session.events.append(
|
||||
Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
args=tool_confirmation_args,
|
||||
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
user_confirmation = ToolConfirmation(confirmed=False)
|
||||
invocation_context.session.events.append(
|
||||
Event(
|
||||
author="user",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
|
||||
response={
|
||||
"response": user_confirmation.model_dump_json()
|
||||
},
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.flows.llm_flows.functions.handle_function_call_list_async"
|
||||
) as mock_handle_function_call_list_async:
|
||||
mock_handle_function_call_list_async.return_value = Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name=MOCK_TOOL_NAME,
|
||||
id=MOCK_FUNCTION_CALL_ID,
|
||||
response={"error": "Tool execution not confirmed"},
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in request_processor.run_async(
|
||||
invocation_context, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 1
|
||||
mock_handle_function_call_list_async.assert_called_once()
|
||||
args, _ = mock_handle_function_call_list_async.call_args
|
||||
assert (
|
||||
args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation
|
||||
) # tool_confirmation_dict
|
||||
@@ -17,6 +17,7 @@ from unittest.mock import MagicMock
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_confirmation import ToolConfirmation
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
@@ -345,3 +346,51 @@ async def test_run_async_with_tool_context_and_unexpected_argument():
|
||||
"received_arg": "world",
|
||||
"context_present": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_require_confirmation():
|
||||
"""Test that run_async handles require_confirmation flag."""
|
||||
|
||||
def sample_func(arg1: str):
|
||||
return {"received_arg": arg1}
|
||||
|
||||
tool = FunctionTool(sample_func, require_confirmation=True)
|
||||
mock_invocation_context = MagicMock(spec=InvocationContext)
|
||||
mock_invocation_context.session = MagicMock(spec=Session)
|
||||
mock_invocation_context.session.state = MagicMock()
|
||||
mock_invocation_context.agent = MagicMock()
|
||||
mock_invocation_context.agent.name = "test_agent"
|
||||
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
|
||||
tool_context_mock.function_call_id = "test_function_call_id"
|
||||
|
||||
# First call, should request confirmation
|
||||
result = await tool.run_async(
|
||||
args={"arg1": "hello"},
|
||||
tool_context=tool_context_mock,
|
||||
)
|
||||
assert result == {
|
||||
"error": "This tool call requires confirmation, please approve or reject."
|
||||
}
|
||||
assert tool_context_mock._event_actions.requested_tool_confirmations[
|
||||
"test_function_call_id"
|
||||
].hint == (
|
||||
"Please approve or reject the tool call sample_func() by responding with"
|
||||
" a FunctionResponse with an expected ToolConfirmation payload."
|
||||
)
|
||||
|
||||
# Second call, user rejects
|
||||
tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=False)
|
||||
result = await tool.run_async(
|
||||
args={"arg1": "hello"},
|
||||
tool_context=tool_context_mock,
|
||||
)
|
||||
assert result == {"error": "This tool call is rejected."}
|
||||
|
||||
# Third call, user approves
|
||||
tool_context_mock.tool_confirmation = ToolConfirmation(confirmed=True)
|
||||
result = await tool.run_async(
|
||||
args={"arg1": "hello"},
|
||||
tool_context=tool_context_mock,
|
||||
)
|
||||
assert result == {"received_arg": "hello"}
|
||||
|
||||
Reference in New Issue
Block a user