You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
Fix: Handle unexpected 'parameters' argument in FunctionTool.run_async
The LLM occasionally includes an unexpected 'parameters' argument when calling tools, specifically observed with 'transfer_to_agent'. This change makes FunctionTool.run_async more robust by filtering arguments against the function signature before invocation. This resolves issue #1637. Update test_function_tool.py fix typing fix: add `from __future__ import annotations`
This commit is contained in:
committed by
Hangfei Lin
parent
9af2394e0a
commit
0959b06dbd
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
@@ -79,9 +81,13 @@ class FunctionTool(BaseTool):
|
||||
) -> Any:
|
||||
args_to_call = args.copy()
|
||||
signature = inspect.signature(self.func)
|
||||
if 'tool_context' in signature.parameters:
|
||||
valid_params = {param for param in signature.parameters}
|
||||
if 'tool_context' in valid_params:
|
||||
args_to_call['tool_context'] = tool_context
|
||||
|
||||
# Filter args_to_call to only include valid parameters for the function
|
||||
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
|
||||
|
||||
# Before invoking the function, we check for if the list of args passed in
|
||||
# has all the mandatory arguments or not.
|
||||
# If the check fails, then we don't invoke the tool and let the Agent know
|
||||
|
||||
@@ -14,7 +14,10 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -294,3 +297,51 @@ async def test_run_async_with_optional_args_not_set_async_func():
|
||||
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
||||
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||
assert result == "test_value_1,test_value_3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_unexpected_argument():
|
||||
"""Test that run_async filters out unexpected arguments."""
|
||||
|
||||
def sample_func(expected_arg: str):
|
||||
return {"received_arg": expected_arg}
|
||||
|
||||
tool = FunctionTool(sample_func)
|
||||
mock_invocation_context = MagicMock(spec=InvocationContext)
|
||||
mock_invocation_context.session = MagicMock(spec=Session)
|
||||
# Add the missing state attribute to the session mock
|
||||
mock_invocation_context.session.state = MagicMock()
|
||||
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"expected_arg": "hello", "parameters": "should_be_filtered"},
|
||||
tool_context=tool_context_mock,
|
||||
)
|
||||
assert result == {"received_arg": "hello"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_tool_context_and_unexpected_argument():
|
||||
"""Test that run_async handles tool_context and filters out unexpected arguments."""
|
||||
|
||||
def sample_func_with_context(expected_arg: str, tool_context: ToolContext):
|
||||
return {"received_arg": expected_arg, "context_present": bool(tool_context)}
|
||||
|
||||
tool = FunctionTool(sample_func_with_context)
|
||||
mock_invocation_context = MagicMock(spec=InvocationContext)
|
||||
mock_invocation_context.session = MagicMock(spec=Session)
|
||||
# Add the missing state attribute to the session mock
|
||||
mock_invocation_context.session.state = MagicMock()
|
||||
mock_tool_context = ToolContext(invocation_context=mock_invocation_context)
|
||||
|
||||
result = await tool.run_async(
|
||||
args={
|
||||
"expected_arg": "world",
|
||||
"parameters": "should_also_be_filtered",
|
||||
},
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
assert result == {
|
||||
"received_arg": "world",
|
||||
"context_present": True,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user