feat: Extend ReflectAndRetryToolPlugin to support hallucinating function calls

PiperOrigin-RevId: 820051762
This commit is contained in:
Xuan Yang
2025-10-15 21:42:56 -07:00
committed by Copybara-Service
parent 3734ceaa6c
commit f51380f9ea
6 changed files with 287 additions and 32 deletions
@@ -46,8 +46,30 @@ You can run the agent with:
$ adk web contributing/samples/plugin_reflect_tool_retry
```
You can provide the following prompt to see the agent retrying tool calls:
Select "basic" and provide the following prompt to see the agent retrying tool
calls:
```
Please guess a number! Tell me what number you guess and how is it.
```
### Hallucinating tool calls
The "hallucinating_func_name" agent demonstrates that the plugin can also
recover from hallucinated tool calls, i.e. calls to a tool that does not exist.
To simulate this, the sample uses an `after_model_callback` that rewrites the
model's first `roll_die` tool call to use a wrong name; the plugin then guides
the agent to retry the call with the correct tool name.
You can run the agent with:
```bash
$ adk web contributing/samples/plugin_reflect_tool_retry
```
Select "hallucinating_func_name" and provide the following prompt to see the
agent retrying tool calls:
```
Roll a 6 sided die
```
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,81 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk.agents import LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.apps.app import App
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins import ReflectAndRetryToolPlugin
from google.adk.tools.tool_context import ToolContext
# Name given to the `App` instance created at the bottom of this module.
APP_NAME = "hallucinating_func_name"
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably kept for parity with the other samples - confirm before removing.
USER_ID = "test_user"
# Module-level flag read and set by `after_model_callback`: once it is True the
# callback stops renaming tool calls, so only one call is ever corrupted. It is
# never reset in the visible code.
hallucinated = False # Whether the tool name is hallucinated
def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  The result is also recorded in the tool context state under the "rolls" key,
  which holds the list of all results rolled so far.

  Args:
    sides: The integer number of sides the die has.
    tool_context: The context of the current tool call; its `state` mapping is
      read and updated with the new roll.

  Returns:
    An integer of the result of rolling the die.

  Raises:
    ValueError: If `sides` is smaller than 1 (raised by `random.randint`).
  """
  result = random.randint(1, sides)
  previous_rolls = (
      tool_context.state["rolls"] if "rolls" in tool_context.state else []
  )
  # Store a new list rather than appending in place, exactly as before, so the
  # state always receives an explicit assignment.
  tool_context.state["rolls"] = previous_rolls + [result]
  return result
def after_model_callback(
    callback_context: CallbackContext, llm_response: LlmResponse
):
  """After model callback to produce one hallucinating tool call.

  The first `roll_die` function call found in a model response is renamed to
  `roll_die_wrong_name`. The module-level `hallucinated` flag is then set so
  that every later response is left untouched.

  Responses without content, without parts, or without a `roll_die` function
  call (for example plain text answers) are ignored instead of raising.

  Args:
    callback_context: The context of the callback; unused.
    llm_response: The model response, modified in place when a `roll_die`
      function call is found.

  Returns:
    Always None; the response is only ever mutated in place.
  """
  global hallucinated
  if hallucinated:
    return None

  content = llm_response.content
  if not content or not content.parts:
    return None

  for part in content.parts:
    function_call = part.function_call
    if function_call and function_call.name == "roll_die":
      function_call.name = "roll_die_wrong_name"
      hallucinated = True
      break
  return None
# Agent whose only tool is `roll_die`. `after_model_callback` renames the first
# `roll_die` call it sees, so the model appears to have called a tool that does
# not exist.
root_agent = LlmAgent(
    name="hello_world",
    description="Helpful agent",
    # The instruction must name a tool this agent actually has; the previous
    # text referred to `guess_number_tool`, which belongs to the basic sample.
    instruction="""Use the roll_die tool to roll a die when asked.""",
    model="gemini-2.5-flash",
    tools=[roll_die],
    after_model_callback=after_model_callback,
)

# App wiring the agent together with the reflect-and-retry plugin, configured
# with `max_retries=3`.
app = App(
    name=APP_NAME,
    root_agent=root_agent,
    plugins=[
        ReflectAndRetryToolPlugin(max_retries=3),
    ],
)
+60 -23
View File
@@ -275,21 +275,37 @@ async def _execute_single_function_call_async(
tool_confirmation: Optional[ToolConfirmation] = None,
) -> Optional[Event]:
"""Execute a single function call with thread safety for state modifications."""
tool, tool_context = _get_tool_and_context(
invocation_context,
function_call,
tools_dict,
tool_confirmation,
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
tool_context = _create_tool_context(
invocation_context, function_call, tool_confirmation
)
try:
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = (
await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
else:
raise tool_error
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# Step 1: Check if plugin before_tool_callback overrides the function
# response.
function_response = (
@@ -639,25 +655,46 @@ async def _process_function_live_helper(
return function_response
def _get_tool(
    function_call: types.FunctionCall, tools_dict: dict[str, BaseTool]
):
  """Looks up the tool that a function call refers to.

  Args:
    function_call: The function call to resolve; only its `name` is used.
    tools_dict: Mapping from tool name to tool.

  Returns:
    The entry of `tools_dict` stored under `function_call.name`.

  Raises:
    ValueError: If `function_call.name` is not a key of `tools_dict`. The
      message includes the keys that are available.
  """
  requested_name = function_call.name
  if requested_name in tools_dict:
    return tools_dict[requested_name]

  raise ValueError(
      f'Function {function_call.name} is not found in the tools_dict:'
      f' {tools_dict.keys()}.'
  )
def _create_tool_context(
    invocation_context: InvocationContext,
    function_call: types.FunctionCall,
    tool_confirmation: Optional[ToolConfirmation] = None,
):
  """Builds the ToolContext for a single function call.

  Args:
    invocation_context: The invocation context the function call belongs to.
    function_call: The function call; its `id` becomes the context's
      `function_call_id`.
    tool_confirmation: Optional tool confirmation forwarded to the context.

  Returns:
    A new ToolContext constructed from the three arguments.
  """
  tool_context = ToolContext(
      invocation_context=invocation_context,
      function_call_id=function_call.id,
      tool_confirmation=tool_confirmation,
  )
  return tool_context
def _get_tool_and_context(
invocation_context: InvocationContext,
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
tool_confirmation: Optional[ToolConfirmation] = None,
):
if function_call.name not in tools_dict:
raise ValueError(
f'Function {function_call.name} is not found in the tools_dict.'
)
tool_context = ToolContext(
invocation_context=invocation_context,
function_call_id=function_call.id,
tool_confirmation=tool_confirmation,
"""Returns the tool and tool context corresponding to the function call."""
tool = _get_tool(function_call, tools_dict)
tool_context = _create_tool_context(
invocation_context,
function_call,
tool_confirmation,
)
tool = tools_dict[function_call.name]
return (tool, tool_context)
@@ -142,8 +142,20 @@ class ReflectAndRetryToolPlugin(BasePlugin):
tool_args: dict[str, Any],
tool_context: ToolContext,
result: Any,
) -> Optional[dict]:
"""Handles successful tool calls or extracts and processes errors."""
) -> Optional[dict[str, Any]]:
"""Handles successful tool calls or extracts and processes errors.
Args:
tool: The tool that was called.
tool_args: The arguments passed to the tool.
tool_context: The context of the tool call.
result: The result of the tool call.
Returns:
An optional dictionary containing reflection guidance if an error is
detected, or None if the tool call was successful or the
response is already a reflection message.
"""
if (
isinstance(result, dict)
and result.get("response_type") == REFLECT_AND_RETRY_RESPONSE_TYPE
@@ -157,7 +169,8 @@ class ReflectAndRetryToolPlugin(BasePlugin):
if error:
return await self._handle_tool_error(tool, tool_args, tool_context, error)
# On success, reset the failure count for this specific tool within its scope.
# On success, reset the failure count for this specific tool within
# its scope.
await self._reset_failures_for_tool(tool_context, tool.name)
return None
@@ -168,7 +181,7 @@ class ReflectAndRetryToolPlugin(BasePlugin):
tool_args: dict[str, Any],
tool_context: ToolContext,
result: Any,
) -> Optional[Any]:
) -> Optional[dict[str, Any]]:
"""Extracts an error from a successful tool result and triggers retry logic.
This is useful when tool call finishes successfully but the result contains
@@ -176,6 +189,15 @@ class ReflectAndRetryToolPlugin(BasePlugin):
By overriding this method, you can trigger retry logic on these successful
results that contain errors.
Args:
tool: The tool that was called.
tool_args: The arguments passed to the tool.
tool_context: The context of the tool call.
result: The result of the tool call.
Returns:
The extracted error if any, or None if no error was detected.
"""
return None
@@ -186,8 +208,18 @@ class ReflectAndRetryToolPlugin(BasePlugin):
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Handles tool exceptions by providing reflection guidance."""
) -> Optional[dict[str, Any]]:
"""Handles tool exceptions by providing reflection guidance.
Args:
tool: The tool that was called.
tool_args: The arguments passed to the tool.
tool_context: The context of the tool call.
error: The exception raised by the tool.
Returns:
An optional dictionary containing reflection guidance for the error.
"""
return await self._handle_tool_error(tool, tool_args, tool_context, error)
async def _handle_tool_error(
@@ -196,8 +228,18 @@ class ReflectAndRetryToolPlugin(BasePlugin):
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Any,
) -> Optional[dict]:
"""Central, thread-safe logic for processing tool errors."""
) -> Optional[dict[str, Any]]:
"""Central, thread-safe logic for processing tool errors.
Args:
tool: The tool that was called.
tool_args: The arguments passed to the tool.
tool_context: The context of the tool call.
error: The error to be handled.
Returns:
An optional dictionary containing reflection guidance for the error.
"""
if self.max_retries == 0:
if self.throw_exception_if_retry_exceeded:
raise error
@@ -285,6 +327,7 @@ This is retry attempt **{retry_count} of {self.max_retries}**. Analyze the error
2. **State or Preconditions**: Did a previous step fail or not produce the necessary state/resource for this tool to succeed?
3. **Alternative Approach**: Is this the right tool for the job? Could another tool or a different sequence of steps achieve the goal?
4. **Simplify the Task**: Can you break the problem down into smaller, simpler steps?
5. **Wrong Function Name**: Does the error indicates the tool is not found? Please check again and only use available tools.
Formulate a new plan based on your analysis and try a corrected or different approach.
"""
@@ -16,10 +16,14 @@ from typing import Any
from unittest import IsolatedAsyncioTestCase
from unittest.mock import Mock
from google.adk.agents.llm_agent import LlmAgent
from google.adk.plugins.reflect_retry_tool_plugin import REFLECT_AND_RETRY_RESPONSE_TYPE
from google.adk.plugins.reflect_retry_tool_plugin import ReflectAndRetryToolPlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
from .. import testing_utils
class MockTool(BaseTool):
@@ -524,3 +528,56 @@ class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase):
result=custom_error,
)
self.assertEqual(result4["retry_count"], 1)
async def test_hallucinating_tool_name(self):
  """Test that a hallucinated tool name triggers reflection and a retry.

  The mock model first calls a tool that does not exist (`increase_by_one`),
  then the real tool (`increase`), then answers with text. The plugin is
  expected to turn the unknown-tool error into a reflection response, after
  which the real tool is executed exactly once.
  """
  wrong_function_call = types.Part.from_function_call(
      name="increase_by_one", args={"x": 1}
  )
  correct_function_call = types.Part.from_function_call(
      name="increase", args={"x": 1}
  )
  # The list mixes function-call parts and a plain string, so it is left
  # unannotated. NOTE(review): MockModel presumably replays the responses in
  # order - the event assertions below rely on that.
  responses = [
      wrong_function_call,
      correct_function_call,
      "response1",
  ]
  mock_model = testing_utils.MockModel.create(responses=responses)

  function_called = 0

  def increase(x: int) -> int:
    nonlocal function_called
    function_called += 1
    return x + 1

  agent = LlmAgent(name="root_agent", model=mock_model, tools=[increase])
  runner = testing_utils.TestInMemoryRunner(
      agent=agent, plugins=[self.get_plugin()]
  )

  events = await runner.run_async_with_new_session("test")

  # The first event is the function call with the wrong name.
  self.assertEqual(
      events[0].content.parts[0].function_call.name, "increase_by_one"
  )

  # The second event is the function response carrying the reflection
  # guidance produced for the unknown tool.
  reflection = events[1].content.parts[0].function_response.response
  self.assertEqual(reflection["error_type"], "ValueError")
  self.assertEqual(reflection["retry_count"], 1)
  self.assertIn("Wrong Function Name", reflection["reflection_guidance"])

  # The third event is the retried function call with the correct name.
  self.assertEqual(events[2].content.parts[0].function_call.name, "increase")

  # Only the correctly named call reaches the real function.
  self.assertEqual(function_called, 1)