Fix: Handle unexpected 'parameters' argument in FunctionTool.run_async

The LLM occasionally includes an unexpected 'parameters' argument when calling tools, specifically observed with 'transfer_to_agent'. This change makes FunctionTool.run_async more robust by filtering arguments against the function signature before invocation.

This resolves issue #1637.

Update test_function_tool.py

fix typing

fix: add `from __future__ import annotations`
This commit is contained in:
google-labs-jules[bot]
2025-06-26 03:13:23 +00:00
committed by Hangfei Lin
parent 9af2394e0a
commit 0959b06dbd
2 changed files with 58 additions and 1 deletions
+7 -1
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
from typing import Any
from typing import Callable
@@ -79,9 +81,13 @@ class FunctionTool(BaseTool):
) -> Any:
args_to_call = args.copy()
signature = inspect.signature(self.func)
if 'tool_context' in signature.parameters:
valid_params = {param for param in signature.parameters}
if 'tool_context' in valid_params:
args_to_call['tool_context'] = tool_context
# Filter args_to_call to only include valid parameters for the function
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
# Before invoking the function, we check for if the list of args passed in
# has all the mandatory arguments or not.
# If the check fails, then we don't invoke the tool and let the Agent know
@@ -14,7 +14,10 @@
from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
import pytest
@@ -294,3 +297,51 @@ async def test_run_async_with_optional_args_not_set_async_func():
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"
@pytest.mark.asyncio
async def test_run_async_with_unexpected_argument():
"""Test that run_async filters out unexpected arguments."""
def sample_func(expected_arg: str):
return {"received_arg": expected_arg}
tool = FunctionTool(sample_func)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
# Add the missing state attribute to the session mock
mock_invocation_context.session.state = MagicMock()
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
result = await tool.run_async(
args={"expected_arg": "hello", "parameters": "should_be_filtered"},
tool_context=tool_context_mock,
)
assert result == {"received_arg": "hello"}
@pytest.mark.asyncio
async def test_run_async_with_tool_context_and_unexpected_argument():
"""Test that run_async handles tool_context and filters out unexpected arguments."""
def sample_func_with_context(expected_arg: str, tool_context: ToolContext):
return {"received_arg": expected_arg, "context_present": bool(tool_context)}
tool = FunctionTool(sample_func_with_context)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
# Add the missing state attribute to the session mock
mock_invocation_context.session.state = MagicMock()
mock_tool_context = ToolContext(invocation_context=mock_invocation_context)
result = await tool.run_async(
args={
"expected_arg": "world",
"parameters": "should_also_be_filtered",
},
tool_context=mock_tool_context,
)
assert result == {
"received_arg": "world",
"context_present": True,
}