fix: Invoke on_tool_error_callback for missing tools in live mode

In live mode, when the model calls an unregistered tool, ADK now runs on_tool_error_callback before failing. If the callback returns a response, ADK emits
that function response and continues; otherwise it keeps the old ValueError

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 872996178
This commit is contained in:
George Weale
2026-02-20 11:20:31 -08:00
committed by Copybara-Service
parent 7478bdaa98
commit e6b601a2ab
2 changed files with 131 additions and 3 deletions
+54 -3
View File
@@ -660,14 +660,65 @@ async def _execute_single_function_call_live(
streaming_lock: asyncio.Lock,
) -> Optional[Event]:
"""Execute a single function call for live mode with thread safety."""
tool, tool_context = _get_tool_and_context(
invocation_context, function_call, tools_dict
)
async def _run_on_tool_error_callbacks(
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict[str, Any]]:
"""Runs the on_tool_error_callbacks for the given tool."""
error_response = (
await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=tool_args,
tool_context=tool_context,
error=error,
)
)
if error_response is not None:
return error_response
for callback in agent.canonical_on_tool_error_callbacks:
error_response = callback(
tool=tool,
args=tool_args,
tool_context=tool_context,
error=error,
)
if inspect.isawaitable(error_response):
error_response = await error_response
if error_response is not None:
return error_response
return None
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
tool_context = _create_tool_context(invocation_context, function_call)
try:
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
raise tool_error
async def _run_with_trace():
nonlocal function_args
@@ -386,3 +386,80 @@ async def test_live_callback_compatibility_with_async():
async_response = async_result.content.parts[0].function_response.response
live_response = live_result.content.parts[0].function_response.response
assert async_response == live_response == {"bypassed": "by_before_callback"}
@pytest.mark.asyncio
async def test_live_on_tool_error_callback_tool_not_found_noop():
"""Test that on_tool_error_callback is a no-op when the tool is not found."""
def noop_on_tool_error_callback(tool, args, tool_context, error):
return None
def simple_fn(**kwargs) -> Dict[str, Any]:
return {"initial": "response"}
tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name="agent",
model=model,
tools=[tool],
on_tool_error_callback=noop_on_tool_error_callback,
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=""
)
function_call = types.FunctionCall(name="nonexistent_function", args={})
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}
with pytest.raises(ValueError):
await handle_function_calls_live(invocation_context, event, tools_dict)
@pytest.mark.asyncio
async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response():
"""Test that on_tool_error_callback modifies tool response when tool is not found."""
def mock_on_tool_error_callback(tool, args, tool_context, error):
return {"result": "on_tool_error_callback_response"}
def simple_fn(**kwargs) -> Dict[str, Any]:
return {"initial": "response"}
tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name="agent",
model=model,
tools=[tool],
on_tool_error_callback=mock_on_tool_error_callback,
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=""
)
function_call = types.FunctionCall(name="nonexistent_function", args={})
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}
result_event = await handle_function_calls_live(
invocation_context,
event,
tools_dict,
)
assert result_event is not None
part = result_event.content.parts[0]
assert part.function_response.response == {
"result": "on_tool_error_callback_response"
}