diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6f34e8fe..4d045fac 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -660,14 +660,65 @@ async def _execute_single_function_call_live( streaming_lock: asyncio.Lock, ) -> Optional[Event]: """Execute a single function call for live mode with thread safety.""" - tool, tool_context = _get_tool_and_context( - invocation_context, function_call, tools_dict - ) + async def _run_on_tool_error_callbacks( + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict[str, Any]]: + """Runs the on_tool_error_callbacks for the given tool.""" + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_tool_error_callbacks: + error_response = callback( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. function_args = ( copy.deepcopy(function_call.args) if function_call.args else {} ) + tool_context = _create_tool_context(invocation_context, function_call) + + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description='Tool not found') + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + raise tool_error + async def _run_with_trace(): nonlocal function_args diff --git a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py index caab8f3f..016e9b49 100644 --- a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py @@ -386,3 +386,80 @@ async def test_live_callback_compatibility_with_async(): async_response = async_result.content.parts[0].function_response.response live_response = live_result.content.parts[0].function_response.response assert async_response == live_response == {"bypassed": "by_before_callback"} + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_noop(): + """Test that on_tool_error_callback is a no-op when the tool is not found.""" + + def noop_on_tool_error_callback(tool, args, tool_context, error): + return None + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=noop_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + with pytest.raises(ValueError): + await handle_function_calls_live(invocation_context, event, tools_dict) + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response(): + """Test that on_tool_error_callback modifies tool response when tool is not found.""" + + def mock_on_tool_error_callback(tool, args, tool_context, error): + return {"result": "on_tool_error_callback_response"} + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=mock_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result_event = await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == { + "result": "on_tool_error_callback_response" + }