You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Invoke on_tool_error_callback for missing tools in live mode
In live mode, when the model calls an unregistered tool, ADK now runs on_tool_error_callback before failing. If the callback returns a response, ADK emits that function response and continues; otherwise it keeps the old ValueError Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 872996178
This commit is contained in:
committed by
Copybara-Service
parent
7478bdaa98
commit
e6b601a2ab
@@ -660,14 +660,65 @@ async def _execute_single_function_call_live(
|
||||
streaming_lock: asyncio.Lock,
|
||||
) -> Optional[Event]:
|
||||
"""Execute a single function call for live mode with thread safety."""
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context, function_call, tools_dict
|
||||
)
|
||||
|
||||
async def _run_on_tool_error_callbacks(
|
||||
*,
|
||||
tool: BaseTool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
error: Exception,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Runs the on_tool_error_callbacks for the given tool."""
|
||||
error_response = (
|
||||
await invocation_context.plugin_manager.run_on_tool_error_callback(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
tool_context=tool_context,
|
||||
error=error,
|
||||
)
|
||||
)
|
||||
if error_response is not None:
|
||||
return error_response
|
||||
|
||||
for callback in agent.canonical_on_tool_error_callbacks:
|
||||
error_response = callback(
|
||||
tool=tool,
|
||||
args=tool_args,
|
||||
tool_context=tool_context,
|
||||
error=error,
|
||||
)
|
||||
if inspect.isawaitable(error_response):
|
||||
error_response = await error_response
|
||||
if error_response is not None:
|
||||
return error_response
|
||||
|
||||
return None
|
||||
|
||||
# Do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
# Make a deep copy to avoid being modified.
|
||||
function_args = (
|
||||
copy.deepcopy(function_call.args) if function_call.args else {}
|
||||
)
|
||||
|
||||
tool_context = _create_tool_context(invocation_context, function_call)
|
||||
|
||||
try:
|
||||
tool = _get_tool(function_call, tools_dict)
|
||||
except ValueError as tool_error:
|
||||
tool = BaseTool(name=function_call.name, description='Tool not found')
|
||||
error_response = await _run_on_tool_error_callbacks(
|
||||
tool=tool,
|
||||
tool_args=function_args,
|
||||
tool_context=tool_context,
|
||||
error=tool_error,
|
||||
)
|
||||
if error_response is not None:
|
||||
return __build_response_event(
|
||||
tool, error_response, tool_context, invocation_context
|
||||
)
|
||||
raise tool_error
|
||||
|
||||
async def _run_with_trace():
|
||||
nonlocal function_args
|
||||
|
||||
|
||||
@@ -386,3 +386,80 @@ async def test_live_callback_compatibility_with_async():
|
||||
async_response = async_result.content.parts[0].function_response.response
|
||||
live_response = live_result.content.parts[0].function_response.response
|
||||
assert async_response == live_response == {"bypassed": "by_before_callback"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_on_tool_error_callback_tool_not_found_noop():
|
||||
"""Test that on_tool_error_callback is a no-op when the tool is not found."""
|
||||
|
||||
def noop_on_tool_error_callback(tool, args, tool_context, error):
|
||||
return None
|
||||
|
||||
def simple_fn(**kwargs) -> Dict[str, Any]:
|
||||
return {"initial": "response"}
|
||||
|
||||
tool = FunctionTool(simple_fn)
|
||||
model = testing_utils.MockModel.create(responses=[])
|
||||
agent = Agent(
|
||||
name="agent",
|
||||
model=model,
|
||||
tools=[tool],
|
||||
on_tool_error_callback=noop_on_tool_error_callback,
|
||||
)
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
function_call = types.FunctionCall(name="nonexistent_function", args={})
|
||||
content = types.Content(parts=[types.Part(function_call=function_call)])
|
||||
event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=agent.name,
|
||||
content=content,
|
||||
)
|
||||
tools_dict = {tool.name: tool}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await handle_function_calls_live(invocation_context, event, tools_dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response():
|
||||
"""Test that on_tool_error_callback modifies tool response when tool is not found."""
|
||||
|
||||
def mock_on_tool_error_callback(tool, args, tool_context, error):
|
||||
return {"result": "on_tool_error_callback_response"}
|
||||
|
||||
def simple_fn(**kwargs) -> Dict[str, Any]:
|
||||
return {"initial": "response"}
|
||||
|
||||
tool = FunctionTool(simple_fn)
|
||||
model = testing_utils.MockModel.create(responses=[])
|
||||
agent = Agent(
|
||||
name="agent",
|
||||
model=model,
|
||||
tools=[tool],
|
||||
on_tool_error_callback=mock_on_tool_error_callback,
|
||||
)
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
function_call = types.FunctionCall(name="nonexistent_function", args={})
|
||||
content = types.Content(parts=[types.Part(function_call=function_call)])
|
||||
event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=agent.name,
|
||||
content=content,
|
||||
)
|
||||
tools_dict = {tool.name: tool}
|
||||
|
||||
result_event = await handle_function_calls_live(
|
||||
invocation_context,
|
||||
event,
|
||||
tools_dict,
|
||||
)
|
||||
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == {
|
||||
"result": "on_tool_error_callback_response"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user