From 0959b06dbdf3037fe4121f12b6d25edca8fb9afc Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 26 Jun 2025 03:13:23 +0000 Subject: [PATCH] Fix: Handle unexpected 'parameters' argument in FunctionTool.run_async The LLM occasionally includes an unexpected 'parameters' argument when calling tools, specifically observed with 'transfer_to_agent'. This change makes FunctionTool.run_async more robust by filtering arguments against the function signature before invocation. This resolves issue #1637. Update test_function_tool.py fix typing fix: add `from __future__ import annotations` --- src/google/adk/tools/function_tool.py | 8 +++- tests/unittests/tools/test_function_tool.py | 51 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index a3bebd91..2687f120 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect from typing import Any from typing import Callable @@ -79,9 +81,13 @@ class FunctionTool(BaseTool): ) -> Any: args_to_call = args.copy() signature = inspect.signature(self.func) - if 'tool_context' in signature.parameters: + valid_params = {param for param in signature.parameters} + if 'tool_context' in valid_params: args_to_call['tool_context'] = tool_context + # Filter args_to_call to only include valid parameters for the function + args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} + # Before invoking the function, we check for if the list of args passed in # has all the mandatory arguments or not. # If the check fails, then we don't invoke the tool and let the Agent know diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 9d325ed0..871f58dc 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -14,7 +14,10 @@ from unittest.mock import MagicMock +from google.adk.agents.invocation_context import InvocationContext +from google.adk.sessions.session import Session from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext import pytest @@ -294,3 +297,51 @@ async def test_run_async_with_optional_args_not_set_async_func(): args = {"arg1": "test_value_1", "arg3": "test_value_3"} result = await tool.run_async(args=args, tool_context=MagicMock()) assert result == "test_value_1,test_value_3" + + +@pytest.mark.asyncio +async def test_run_async_with_unexpected_argument(): + """Test that run_async filters out unexpected arguments.""" + + def sample_func(expected_arg: str): + return {"received_arg": expected_arg} + + tool = FunctionTool(sample_func) + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + # Add the missing state attribute to the session mock + mock_invocation_context.session.state = MagicMock() + tool_context_mock = ToolContext(invocation_context=mock_invocation_context) + + result = await tool.run_async( + args={"expected_arg": "hello", "parameters": "should_be_filtered"}, + tool_context=tool_context_mock, + ) + assert result == {"received_arg": "hello"} + + +@pytest.mark.asyncio +async def test_run_async_with_tool_context_and_unexpected_argument(): + """Test that run_async handles tool_context and filters out unexpected arguments.""" + + def sample_func_with_context(expected_arg: str, tool_context: ToolContext): + return {"received_arg": expected_arg, "context_present": bool(tool_context)} + + tool = FunctionTool(sample_func_with_context) + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + # Add the missing state attribute to the session mock + mock_invocation_context.session.state = MagicMock() + mock_tool_context = ToolContext(invocation_context=mock_invocation_context) + + result = await tool.run_async( + args={ + "expected_arg": "world", + "parameters": "should_also_be_filtered", + }, + tool_context=mock_tool_context, + ) + assert result == { + "received_arg": "world", + "context_present": True, + }