From f51380f9ea4534591eda76bef27407c0aa7c3fae Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 15 Oct 2025 21:42:56 -0700 Subject: [PATCH] feat: Extend `ReflectAndRetryToolPlugin` to support hallucinating function calls PiperOrigin-RevId: 820051762 --- .../plugin_reflect_tool_retry/README.md | 24 +++++- .../hallucinating_func_name/__init__.py | 15 ++++ .../hallucinating_func_name/agent.py | 81 ++++++++++++++++++ src/google/adk/flows/llm_flows/functions.py | 83 ++++++++++++++----- .../adk/plugins/reflect_retry_tool_plugin.py | 59 +++++++++++-- .../plugins/test_reflect_retry_tool_plugin.py | 57 +++++++++++++ 6 files changed, 287 insertions(+), 32 deletions(-) create mode 100644 contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/__init__.py create mode 100644 contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/agent.py diff --git a/contributing/samples/plugin_reflect_tool_retry/README.md b/contributing/samples/plugin_reflect_tool_retry/README.md index 084862e4..77362316 100644 --- a/contributing/samples/plugin_reflect_tool_retry/README.md +++ b/contributing/samples/plugin_reflect_tool_retry/README.md @@ -46,8 +46,30 @@ You can run the agent with: $ adk web contributing/samples/plugin_reflect_tool_retry ``` -You can provide the following prompt to see the agent retrying tool calls: +Select "basic" and provide the following prompt to see the agent retrying tool +calls: ``` Please guess a number! Tell me what number you guess and how is it. ``` + +### Hallucinating tool calls + +The "hallucinating_func_name" agent is an example showing that the plugin can +retry hallucinated tool calls. + +For example, we use the `after_model_callback` to inject a tool call with the +wrong name, so that the agent can retry the call with the right tool name. 
+ +You can run the agent with: + +```bash +$ adk web contributing/samples/plugin_reflect_tool_retry +``` + +Select "hallucinating_func_name" and provide the following prompt to see the +agent retrying tool calls: + +``` +Roll a 6 sided die +``` diff --git a/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/__init__.py b/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/__init__.py new file mode 100644 index 00000000..c48963cd --- /dev/null +++ b/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/agent.py b/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/agent.py new file mode 100644 index 00000000..5b8d9262 --- /dev/null +++ b/contributing/samples/plugin_reflect_tool_retry/hallucinating_func_name/agent.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk.agents import LlmAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.apps.app import App +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins import ReflectAndRetryToolPlugin +from google.adk.tools.tool_context import ToolContext + +APP_NAME = "hallucinating_func_name" +USER_ID = "test_user" + +hallucinated = False # Whether the tool name is hallucinated + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. 
+ """ + result = random.randint(1, sides) + if not "rolls" in tool_context.state: + tool_context.state["rolls"] = [] + + tool_context.state["rolls"] = tool_context.state["rolls"] + [result] + return result + + +def after_model_callback( + callback_context: CallbackContext, llm_response: LlmResponse +): + """After model callback to produce one hallucinating tool call.""" + global hallucinated + + if hallucinated: + return None + + if ( + llm_response.content + and llm_response.content.parts[0].function_call.name == "roll_die" + ): + llm_response.content.parts[0].function_call.name = "roll_die_wrong_name" + hallucinated = True + return None + + +root_agent = LlmAgent( + name="hello_world", + description="Helpful agent", + instruction="""Use guess_number_tool to guess a number.""", + model="gemini-2.5-flash", + tools=[roll_die], + after_model_callback=after_model_callback, +) + + +app = App( + name=APP_NAME, + root_agent=root_agent, + plugins=[ + ReflectAndRetryToolPlugin(max_retries=3), + ], +) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index b7508aee..4380322b 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -275,21 +275,37 @@ async def _execute_single_function_call_async( tool_confirmation: Optional[ToolConfirmation] = None, ) -> Optional[Event]: """Execute a single function call with thread safety for state modifications.""" - tool, tool_context = _get_tool_and_context( - invocation_context, - function_call, - tools_dict, - tool_confirmation, + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} ) - with tracer.start_as_current_span(f'execute_tool {tool.name}'): - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. 
- # Make a deep copy to avoid being modified. - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} - ) + tool_context = _create_tool_context( + invocation_context, function_call, tool_confirmation + ) + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description='Tool not found') + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + else: + raise tool_error + + with tracer.start_as_current_span(f'execute_tool {tool.name}'): # Step 1: Check if plugin before_tool_callback overrides the function # response. function_response = ( @@ -639,25 +655,46 @@ async def _process_function_live_helper( return function_response +def _get_tool( + function_call: types.FunctionCall, tools_dict: dict[str, BaseTool] +): + """Returns the tool corresponding to the function call.""" + if function_call.name not in tools_dict: + raise ValueError( + f'Function {function_call.name} is not found in the tools_dict:' + f' {tools_dict.keys()}.' 
+ ) + + return tools_dict[function_call.name] + + +def _create_tool_context( + invocation_context: InvocationContext, + function_call: types.FunctionCall, + tool_confirmation: Optional[ToolConfirmation] = None, +): + """Creates a ToolContext object.""" + return ToolContext( + invocation_context=invocation_context, + function_call_id=function_call.id, + tool_confirmation=tool_confirmation, + ) + + def _get_tool_and_context( invocation_context: InvocationContext, function_call: types.FunctionCall, tools_dict: dict[str, BaseTool], tool_confirmation: Optional[ToolConfirmation] = None, ): - if function_call.name not in tools_dict: - raise ValueError( - f'Function {function_call.name} is not found in the tools_dict.' - ) - - tool_context = ToolContext( - invocation_context=invocation_context, - function_call_id=function_call.id, - tool_confirmation=tool_confirmation, + """Returns the tool and tool context corresponding to the function call.""" + tool = _get_tool(function_call, tools_dict) + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation, ) - tool = tools_dict[function_call.name] - return (tool, tool_context) diff --git a/src/google/adk/plugins/reflect_retry_tool_plugin.py b/src/google/adk/plugins/reflect_retry_tool_plugin.py index cc501a93..a3a0cc25 100644 --- a/src/google/adk/plugins/reflect_retry_tool_plugin.py +++ b/src/google/adk/plugins/reflect_retry_tool_plugin.py @@ -142,8 +142,20 @@ class ReflectAndRetryToolPlugin(BasePlugin): tool_args: dict[str, Any], tool_context: ToolContext, result: Any, - ) -> Optional[dict]: - """Handles successful tool calls or extracts and processes errors.""" + ) -> Optional[dict[str, Any]]: + """Handles successful tool calls or extracts and processes errors. + + Args: + tool: The tool that was called. + tool_args: The arguments passed to the tool. + tool_context: The context of the tool call. + result: The result of the tool call. 
+ + Returns: + An optional dictionary containing reflection guidance if an error is + detected, or None if the tool call was successful or the + response is already a reflection message. + """ if ( isinstance(result, dict) and result.get("response_type") == REFLECT_AND_RETRY_RESPONSE_TYPE @@ -157,7 +169,8 @@ class ReflectAndRetryToolPlugin(BasePlugin): if error: return await self._handle_tool_error(tool, tool_args, tool_context, error) - # On success, reset the failure count for this specific tool within its scope. + # On success, reset the failure count for this specific tool within + # its scope. await self._reset_failures_for_tool(tool_context, tool.name) return None @@ -168,7 +181,7 @@ class ReflectAndRetryToolPlugin(BasePlugin): tool_args: dict[str, Any], tool_context: ToolContext, result: Any, - ) -> Optional[Any]: + ) -> Optional[dict[str, Any]]: """Extracts an error from a successful tool result and triggers retry logic. This is useful when tool call finishes successfully but the result contains @@ -176,6 +189,15 @@ class ReflectAndRetryToolPlugin(BasePlugin): By overriding this method, you can trigger retry logic on these successful results that contain errors. + + Args: + tool: The tool that was called. + tool_args: The arguments passed to the tool. + tool_context: The context of the tool call. + result: The result of the tool call. + + Returns: + The extracted error if any, or None if no error was detected. """ return None @@ -186,8 +208,18 @@ class ReflectAndRetryToolPlugin(BasePlugin): tool_args: dict[str, Any], tool_context: ToolContext, error: Exception, - ) -> Optional[dict]: - """Handles tool exceptions by providing reflection guidance.""" + ) -> Optional[dict[str, Any]]: + """Handles tool exceptions by providing reflection guidance. + + Args: + tool: The tool that was called. + tool_args: The arguments passed to the tool. + tool_context: The context of the tool call. + error: The exception raised by the tool. 
+ + Returns: + An optional dictionary containing reflection guidance for the error. + """ return await self._handle_tool_error(tool, tool_args, tool_context, error) async def _handle_tool_error( @@ -196,8 +228,18 @@ class ReflectAndRetryToolPlugin(BasePlugin): tool_args: dict[str, Any], tool_context: ToolContext, error: Any, - ) -> Optional[dict]: - """Central, thread-safe logic for processing tool errors.""" + ) -> Optional[dict[str, Any]]: + """Central, thread-safe logic for processing tool errors. + + Args: + tool: The tool that was called. + tool_args: The arguments passed to the tool. + tool_context: The context of the tool call. + error: The error to be handled. + + Returns: + An optional dictionary containing reflection guidance for the error. + """ if self.max_retries == 0: if self.throw_exception_if_retry_exceeded: raise error @@ -285,6 +327,7 @@ This is retry attempt **{retry_count} of {self.max_retries}**. Analyze the error 2. **State or Preconditions**: Did a previous step fail or not produce the necessary state/resource for this tool to succeed? 3. **Alternative Approach**: Is this the right tool for the job? Could another tool or a different sequence of steps achieve the goal? 4. **Simplify the Task**: Can you break the problem down into smaller, simpler steps? +5. **Wrong Function Name**: Does the error indicates the tool is not found? Please check again and only use available tools. Formulate a new plan based on your analysis and try a corrected or different approach. 
""" diff --git a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py index ac085020..1e15f338 100644 --- a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py +++ b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py @@ -16,10 +16,14 @@ from typing import Any from unittest import IsolatedAsyncioTestCase from unittest.mock import Mock +from google.adk.agents.llm_agent import LlmAgent from google.adk.plugins.reflect_retry_tool_plugin import REFLECT_AND_RETRY_RESPONSE_TYPE from google.adk.plugins.reflect_retry_tool_plugin import ReflectAndRetryToolPlugin from google.adk.tools.base_tool import BaseTool from google.adk.tools.tool_context import ToolContext +from google.genai import types + +from .. import testing_utils class MockTool(BaseTool): @@ -524,3 +528,56 @@ class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase): result=custom_error, ) self.assertEqual(result4["retry_count"], 1) + + async def test_hallucinating_tool_name(self): + """Test that hallucinating tool name is handled correctly.""" + wrong_function_call = types.Part.from_function_call( + name="increase_by_one", args={"x": 1} + ) + correct_function_call = types.Part.from_function_call( + name="increase", args={"x": 1} + ) + responses: list[types.Content] = [ + wrong_function_call, + correct_function_call, + "response1", + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + function_called = 0 + + def increase(x: int) -> int: + nonlocal function_called + function_called += 1 + return x + 1 + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[increase]) + runner = testing_utils.TestInMemoryRunner( + agent=agent, plugins=[self.get_plugin()] + ) + + events = await runner.run_async_with_new_session("test") + + # Assert that the first event is a function call with the wrong name + assert events[0].content.parts[0].function_call.name == "increase_by_one" + + # Assert that the second event is a 
function response with the + # reflection_guidance + assert ( + events[1].content.parts[0].function_response.response["error_type"] + == "ValueError" + ) + assert ( + events[1].content.parts[0].function_response.response["retry_count"] + == 1 + ) + assert ( + "Wrong Function Name" + in events[1] + .content.parts[0] + .function_response.response["reflection_guidance"] + ) + + # Assert that the third event is a function call with the correct name + assert events[2].content.parts[0].function_call.name == "increase" + self.assertEqual(function_called, 1)