You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Extend ReflectAndRetryToolPlugin to support hallucinating function calls
PiperOrigin-RevId: 820051762
This commit is contained in:
committed by
Copybara-Service
parent
3734ceaa6c
commit
f51380f9ea
@@ -46,8 +46,30 @@ You can run the agent with:
|
||||
$ adk web contributing/samples/plugin_reflect_tool_retry
|
||||
```
|
||||
|
||||
You can provide the following prompt to see the agent retrying tool calls:
|
||||
Select "basic" and provide the following prompt to see the agent retrying tool
|
||||
calls:
|
||||
|
||||
```
|
||||
Please guess a number! Tell me what number you guess and how is it.
|
||||
```
|
||||
|
||||
### Hallucinating tool calls
|
||||
|
||||
The "hallucinating_func_name" agent is an example to show the plugin can retry
|
||||
hallucinating tool calls.
|
||||
|
||||
For example, we used the `after_model_callback` to hack a tool call with the
|
||||
wrong name then the agent can retry calling with the right tool name.
|
||||
|
||||
You can run the agent with:
|
||||
|
||||
```bash
|
||||
$ adk web contributing/samples/plugin_reflect_tool_retry
|
||||
```
|
||||
|
||||
Select "hallucinating_func_name" and provide the following prompt to see the
|
||||
agent retrying tool calls:
|
||||
|
||||
```
|
||||
Roll a 6 sided die
|
||||
```
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
from google.adk.agents import LlmAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.plugins import ReflectAndRetryToolPlugin
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
|
||||
APP_NAME = "hallucinating_func_name"
|
||||
USER_ID = "test_user"
|
||||
|
||||
hallucinated = False # Whether the tool name is hallucinated
|
||||
|
||||
|
||||
def roll_die(sides: int, tool_context: ToolContext) -> int:
|
||||
"""Roll a die and return the rolled result.
|
||||
|
||||
Args:
|
||||
sides: The integer number of sides the die has.
|
||||
|
||||
Returns:
|
||||
An integer of the result of rolling the die.
|
||||
"""
|
||||
result = random.randint(1, sides)
|
||||
if not "rolls" in tool_context.state:
|
||||
tool_context.state["rolls"] = []
|
||||
|
||||
tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
|
||||
return result
|
||||
|
||||
|
||||
def after_model_callback(
|
||||
callback_context: CallbackContext, llm_response: LlmResponse
|
||||
):
|
||||
"""After model callback to produce one hallucinating tool call."""
|
||||
global hallucinated
|
||||
|
||||
if hallucinated:
|
||||
return None
|
||||
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts[0].function_call.name == "roll_die"
|
||||
):
|
||||
llm_response.content.parts[0].function_call.name = "roll_die_wrong_name"
|
||||
hallucinated = True
|
||||
return None
|
||||
|
||||
|
||||
root_agent = LlmAgent(
|
||||
name="hello_world",
|
||||
description="Helpful agent",
|
||||
instruction="""Use guess_number_tool to guess a number.""",
|
||||
model="gemini-2.5-flash",
|
||||
tools=[roll_die],
|
||||
after_model_callback=after_model_callback,
|
||||
)
|
||||
|
||||
|
||||
app = App(
|
||||
name=APP_NAME,
|
||||
root_agent=root_agent,
|
||||
plugins=[
|
||||
ReflectAndRetryToolPlugin(max_retries=3),
|
||||
],
|
||||
)
|
||||
@@ -275,21 +275,37 @@ async def _execute_single_function_call_async(
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Execute a single function call with thread safety for state modifications."""
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context,
|
||||
function_call,
|
||||
tools_dict,
|
||||
tool_confirmation,
|
||||
# Do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
# Make a deep copy to avoid being modified.
|
||||
function_args = (
|
||||
copy.deepcopy(function_call.args) if function_call.args else {}
|
||||
)
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
# Do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
# Make a deep copy to avoid being modified.
|
||||
function_args = (
|
||||
copy.deepcopy(function_call.args) if function_call.args else {}
|
||||
)
|
||||
tool_context = _create_tool_context(
|
||||
invocation_context, function_call, tool_confirmation
|
||||
)
|
||||
|
||||
try:
|
||||
tool = _get_tool(function_call, tools_dict)
|
||||
except ValueError as tool_error:
|
||||
tool = BaseTool(name=function_call.name, description='Tool not found')
|
||||
error_response = (
|
||||
await invocation_context.plugin_manager.run_on_tool_error_callback(
|
||||
tool=tool,
|
||||
tool_args=function_args,
|
||||
tool_context=tool_context,
|
||||
error=tool_error,
|
||||
)
|
||||
)
|
||||
if error_response is not None:
|
||||
return __build_response_event(
|
||||
tool, error_response, tool_context, invocation_context
|
||||
)
|
||||
else:
|
||||
raise tool_error
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
# Step 1: Check if plugin before_tool_callback overrides the function
|
||||
# response.
|
||||
function_response = (
|
||||
@@ -639,25 +655,46 @@ async def _process_function_live_helper(
|
||||
return function_response
|
||||
|
||||
|
||||
def _get_tool(
|
||||
function_call: types.FunctionCall, tools_dict: dict[str, BaseTool]
|
||||
):
|
||||
"""Returns the tool corresponding to the function call."""
|
||||
if function_call.name not in tools_dict:
|
||||
raise ValueError(
|
||||
f'Function {function_call.name} is not found in the tools_dict:'
|
||||
f' {tools_dict.keys()}.'
|
||||
)
|
||||
|
||||
return tools_dict[function_call.name]
|
||||
|
||||
|
||||
def _create_tool_context(
|
||||
invocation_context: InvocationContext,
|
||||
function_call: types.FunctionCall,
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
):
|
||||
"""Creates a ToolContext object."""
|
||||
return ToolContext(
|
||||
invocation_context=invocation_context,
|
||||
function_call_id=function_call.id,
|
||||
tool_confirmation=tool_confirmation,
|
||||
)
|
||||
|
||||
|
||||
def _get_tool_and_context(
|
||||
invocation_context: InvocationContext,
|
||||
function_call: types.FunctionCall,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
tool_confirmation: Optional[ToolConfirmation] = None,
|
||||
):
|
||||
if function_call.name not in tools_dict:
|
||||
raise ValueError(
|
||||
f'Function {function_call.name} is not found in the tools_dict.'
|
||||
)
|
||||
|
||||
tool_context = ToolContext(
|
||||
invocation_context=invocation_context,
|
||||
function_call_id=function_call.id,
|
||||
tool_confirmation=tool_confirmation,
|
||||
"""Returns the tool and tool context corresponding to the function call."""
|
||||
tool = _get_tool(function_call, tools_dict)
|
||||
tool_context = _create_tool_context(
|
||||
invocation_context,
|
||||
function_call,
|
||||
tool_confirmation,
|
||||
)
|
||||
|
||||
tool = tools_dict[function_call.name]
|
||||
|
||||
return (tool, tool_context)
|
||||
|
||||
|
||||
|
||||
@@ -142,8 +142,20 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
result: Any,
|
||||
) -> Optional[dict]:
|
||||
"""Handles successful tool calls or extracts and processes errors."""
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Handles successful tool calls or extracts and processes errors.
|
||||
|
||||
Args:
|
||||
tool: The tool that was called.
|
||||
tool_args: The arguments passed to the tool.
|
||||
tool_context: The context of the tool call.
|
||||
result: The result of the tool call.
|
||||
|
||||
Returns:
|
||||
An optional dictionary containing reflection guidance if an error is
|
||||
detected, or None if the tool call was successful or the
|
||||
response is already a reflection message.
|
||||
"""
|
||||
if (
|
||||
isinstance(result, dict)
|
||||
and result.get("response_type") == REFLECT_AND_RETRY_RESPONSE_TYPE
|
||||
@@ -157,7 +169,8 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
if error:
|
||||
return await self._handle_tool_error(tool, tool_args, tool_context, error)
|
||||
|
||||
# On success, reset the failure count for this specific tool within its scope.
|
||||
# On success, reset the failure count for this specific tool within
|
||||
# its scope.
|
||||
await self._reset_failures_for_tool(tool_context, tool.name)
|
||||
return None
|
||||
|
||||
@@ -168,7 +181,7 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
result: Any,
|
||||
) -> Optional[Any]:
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Extracts an error from a successful tool result and triggers retry logic.
|
||||
|
||||
This is useful when tool call finishes successfully but the result contains
|
||||
@@ -176,6 +189,15 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
|
||||
By overriding this method, you can trigger retry logic on these successful
|
||||
results that contain errors.
|
||||
|
||||
Args:
|
||||
tool: The tool that was called.
|
||||
tool_args: The arguments passed to the tool.
|
||||
tool_context: The context of the tool call.
|
||||
result: The result of the tool call.
|
||||
|
||||
Returns:
|
||||
The extracted error if any, or None if no error was detected.
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -186,8 +208,18 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
error: Exception,
|
||||
) -> Optional[dict]:
|
||||
"""Handles tool exceptions by providing reflection guidance."""
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Handles tool exceptions by providing reflection guidance.
|
||||
|
||||
Args:
|
||||
tool: The tool that was called.
|
||||
tool_args: The arguments passed to the tool.
|
||||
tool_context: The context of the tool call.
|
||||
error: The exception raised by the tool.
|
||||
|
||||
Returns:
|
||||
An optional dictionary containing reflection guidance for the error.
|
||||
"""
|
||||
return await self._handle_tool_error(tool, tool_args, tool_context, error)
|
||||
|
||||
async def _handle_tool_error(
|
||||
@@ -196,8 +228,18 @@ class ReflectAndRetryToolPlugin(BasePlugin):
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
error: Any,
|
||||
) -> Optional[dict]:
|
||||
"""Central, thread-safe logic for processing tool errors."""
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Central, thread-safe logic for processing tool errors.
|
||||
|
||||
Args:
|
||||
tool: The tool that was called.
|
||||
tool_args: The arguments passed to the tool.
|
||||
tool_context: The context of the tool call.
|
||||
error: The error to be handled.
|
||||
|
||||
Returns:
|
||||
An optional dictionary containing reflection guidance for the error.
|
||||
"""
|
||||
if self.max_retries == 0:
|
||||
if self.throw_exception_if_retry_exceeded:
|
||||
raise error
|
||||
@@ -285,6 +327,7 @@ This is retry attempt **{retry_count} of {self.max_retries}**. Analyze the error
|
||||
2. **State or Preconditions**: Did a previous step fail or not produce the necessary state/resource for this tool to succeed?
|
||||
3. **Alternative Approach**: Is this the right tool for the job? Could another tool or a different sequence of steps achieve the goal?
|
||||
4. **Simplify the Task**: Can you break the problem down into smaller, simpler steps?
|
||||
5. **Wrong Function Name**: Does the error indicates the tool is not found? Please check again and only use available tools.
|
||||
|
||||
Formulate a new plan based on your analysis and try a corrected or different approach.
|
||||
"""
|
||||
|
||||
@@ -16,10 +16,14 @@ from typing import Any
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.plugins.reflect_retry_tool_plugin import REFLECT_AND_RETRY_RESPONSE_TYPE
|
||||
from google.adk.plugins.reflect_retry_tool_plugin import ReflectAndRetryToolPlugin
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
from .. import testing_utils
|
||||
|
||||
|
||||
class MockTool(BaseTool):
|
||||
@@ -524,3 +528,56 @@ class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase):
|
||||
result=custom_error,
|
||||
)
|
||||
self.assertEqual(result4["retry_count"], 1)
|
||||
|
||||
async def test_hallucinating_tool_name(self):
|
||||
"""Test that hallucinating tool name is handled correctly."""
|
||||
wrong_function_call = types.Part.from_function_call(
|
||||
name="increase_by_one", args={"x": 1}
|
||||
)
|
||||
correct_function_call = types.Part.from_function_call(
|
||||
name="increase", args={"x": 1}
|
||||
)
|
||||
responses: list[types.Content] = [
|
||||
wrong_function_call,
|
||||
correct_function_call,
|
||||
"response1",
|
||||
]
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
function_called = 0
|
||||
|
||||
def increase(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
agent = LlmAgent(name="root_agent", model=mock_model, tools=[increase])
|
||||
runner = testing_utils.TestInMemoryRunner(
|
||||
agent=agent, plugins=[self.get_plugin()]
|
||||
)
|
||||
|
||||
events = await runner.run_async_with_new_session("test")
|
||||
|
||||
# Assert that the first event is a function call with the wrong name
|
||||
assert events[0].content.parts[0].function_call.name == "increase_by_one"
|
||||
|
||||
# Assert that the second event is a function response with the
|
||||
# reflection_guidance
|
||||
assert (
|
||||
events[1].content.parts[0].function_response.response["error_type"]
|
||||
== "ValueError"
|
||||
)
|
||||
assert (
|
||||
events[1].content.parts[0].function_response.response["retry_count"]
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
"Wrong Function Name"
|
||||
in events[1]
|
||||
.content.parts[0]
|
||||
.function_response.response["reflection_guidance"]
|
||||
)
|
||||
|
||||
# Assert that the third event is a function call with the correct name
|
||||
assert events[2].content.parts[0].function_call.name == "increase"
|
||||
self.assertEqual(function_called, 1)
|
||||
|
||||
Reference in New Issue
Block a user