diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 088b2fe1..1ab32d42 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -18,16 +18,18 @@ import inspect import logging from typing import Any from typing import Callable +from typing import get_args +from typing import get_origin from typing import Optional from typing import Union from google.genai import types +import pydantic from typing_extensions import override from ..utils.context_utils import Aclosing from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool -from .tool_confirmation import ToolConfirmation from .tool_context import ToolContext logger = logging.getLogger('google_adk.' + __name__) @@ -95,11 +97,69 @@ class FunctionTool(BaseTool): return function_decl + def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: + """Preprocess and convert function arguments before invocation. + + Currently handles: + - Converting JSON dictionaries to Pydantic model instances where expected + + Future extensions could include: + - Type coercion for other complex types + - Validation and sanitization + - Custom conversion logic + + Args: + args: Raw arguments from the LLM tool call + + Returns: + Processed arguments ready for function invocation + """ + signature = inspect.signature(self.func) + converted_args = args.copy() + + for param_name, param in signature.parameters.items(): + if param_name in args and param.annotation != inspect.Parameter.empty: + target_type = param.annotation + + # Handle Optional[PydanticModel] types + if get_origin(param.annotation) is Union: + union_args = get_args(param.annotation) + # Find the non-None type in Optional[T] (which is Union[T, None]) + non_none_types = [arg for arg in union_args if arg is not type(None)] + if len(non_none_types) == 1: + target_type = non_none_types[0] + + # Check if the target type is a Pydantic model + if inspect.isclass(target_type) and issubclass( + target_type, pydantic.BaseModel + ): + # Skip conversion if the value is None and the parameter is Optional + if args[param_name] is None: + continue + + # Convert to Pydantic model if it's not already the correct type + if not isinstance(args[param_name], target_type): + try: + converted_args[param_name] = target_type.model_validate( + args[param_name] + ) + except Exception as e: + logger.warning( + f"Failed to convert argument '{param_name}' to Pydantic model" + f' {target_type.__name__}: {e}' + ) + # Keep the original value if conversion fails + pass + + return converted_args + @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext ) -> Any: - args_to_call = args.copy() + # Preprocess arguments (includes Pydantic model conversion) + args_to_call = self._preprocess_args(args) + signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} if 'tool_context' in valid_params: diff --git a/tests/unittests/tools/test_function_tool_pydantic.py b/tests/unittests/tools/test_function_tool_pydantic.py new file mode 100644 index 00000000..1af5d683 --- /dev/null +++ b/tests/unittests/tools/test_function_tool_pydantic.py @@ -0,0 +1,284 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Pydantic model conversion tests + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.sessions.session import Session +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +import pydantic +import pytest + + +class UserModel(pydantic.BaseModel): + """Test Pydantic model for user data.""" + + name: str + age: int + email: Optional[str] = None + + +class PreferencesModel(pydantic.BaseModel): + """Test Pydantic model for preferences.""" + + theme: str = "light" + notifications: bool = True + + +def sync_function_with_pydantic_model(user: UserModel) -> dict: + """Sync function that takes a Pydantic model.""" + return { + "name": user.name, + "age": user.age, + "email": user.email, + "type": str(type(user).__name__), + } + + +async def async_function_with_pydantic_model(user: UserModel) -> dict: + """Async function that takes a Pydantic model.""" + return { + "name": user.name, + "age": user.age, + "email": user.email, + "type": str(type(user).__name__), + } + + +def function_with_optional_pydantic_model( + user: UserModel, preferences: Optional[PreferencesModel] = None +) -> dict: + """Function with required and optional Pydantic models.""" + result = { + "user_name": user.name, + "user_type": str(type(user).__name__), + } + if preferences: + result.update({ + "theme": preferences.theme, + "notifications": preferences.notifications, + "preferences_type": str(type(preferences).__name__), + }) + return result + + +def function_with_mixed_args( + name: str, user: UserModel, count: int = 5 +) -> dict: + """Function with mixed argument types including Pydantic model.""" + return { + "name": name, + "user_name": user.name, + "user_type": str(type(user).__name__), + "count": count, + } + + +def test_preprocess_args_with_dict_to_pydantic_conversion(): + """Test _preprocess_args converts dict to Pydantic model.""" + tool = FunctionTool(sync_function_with_pydantic_model) + + input_args = { + "user": {"name": "Alice", "age": 30, "email": "alice@example.com"} + } + + processed_args = tool._preprocess_args(input_args) + + # Check that the dict was converted to a Pydantic model + assert "user" in processed_args + user = processed_args["user"] + assert isinstance(user, UserModel) + assert user.name == "Alice" + assert user.age == 30 + assert user.email == "alice@example.com" + + +def test_preprocess_args_with_existing_pydantic_model(): + """Test _preprocess_args leaves existing Pydantic model unchanged.""" + tool = FunctionTool(sync_function_with_pydantic_model) + + # Create an existing Pydantic model + existing_user = UserModel(name="Bob", age=25) + input_args = {"user": existing_user} + + processed_args = tool._preprocess_args(input_args) + + # Check that the existing model was not changed (same object) + assert "user" in processed_args + user = processed_args["user"] + assert user is existing_user + assert isinstance(user, UserModel) + assert user.name == "Bob" + + +def test_preprocess_args_with_optional_pydantic_model_none(): + """Test _preprocess_args handles None for optional Pydantic models.""" + tool = FunctionTool(function_with_optional_pydantic_model) + + input_args = {"user": {"name": "Charlie", "age": 35}, "preferences": None} + + processed_args = tool._preprocess_args(input_args) + + # Check user conversion + assert isinstance(processed_args["user"], UserModel) + assert processed_args["user"].name == "Charlie" + + # Check preferences remains None + assert processed_args["preferences"] is None + + +def test_preprocess_args_with_optional_pydantic_model_dict(): + """Test _preprocess_args converts dict for optional Pydantic models.""" + tool = FunctionTool(function_with_optional_pydantic_model) + + input_args = { + "user": {"name": "Diana", "age": 28}, + "preferences": {"theme": "dark", "notifications": False}, + } + + processed_args = tool._preprocess_args(input_args) + + # Check both conversions + assert isinstance(processed_args["user"], UserModel) + assert processed_args["user"].name == "Diana" + + assert isinstance(processed_args["preferences"], PreferencesModel) + assert processed_args["preferences"].theme == "dark" + assert processed_args["preferences"].notifications is False + + +def test_preprocess_args_with_mixed_types(): + """Test _preprocess_args handles mixed argument types correctly.""" + tool = FunctionTool(function_with_mixed_args) + + input_args = { + "name": "test_name", + "user": {"name": "Eve", "age": 40}, + "count": 10, + } + + processed_args = tool._preprocess_args(input_args) + + # Check that only Pydantic model was converted + assert processed_args["name"] == "test_name" # string unchanged + assert processed_args["count"] == 10 # int unchanged + + # Check Pydantic model conversion + assert isinstance(processed_args["user"], UserModel) + assert processed_args["user"].name == "Eve" + assert processed_args["user"].age == 40 + + +def test_preprocess_args_with_invalid_data_graceful_failure(): + """Test _preprocess_args handles invalid data gracefully.""" + tool = FunctionTool(sync_function_with_pydantic_model) + + # Invalid data that can't be converted to UserModel + input_args = {"user": "invalid_string"} # string instead of dict/model + + processed_args = tool._preprocess_args(input_args) + + # Should keep original value when conversion fails + assert processed_args["user"] == "invalid_string" + + +def test_preprocess_args_with_non_pydantic_parameters(): + """Test _preprocess_args ignores non-Pydantic parameters.""" + + def simple_function(name: str, age: int) -> dict: + return {"name": name, "age": age} + + tool = FunctionTool(simple_function) + + input_args = {"name": "test", "age": 25} + processed_args = tool._preprocess_args(input_args) + + # Should remain unchanged (no Pydantic models to convert) + assert processed_args == input_args + + +@pytest.mark.asyncio +async def test_run_async_with_pydantic_model_conversion_sync_function(): + """Test run_async with Pydantic model conversion for sync function.""" + tool = FunctionTool(sync_function_with_pydantic_model) + + tool_context_mock = MagicMock(spec=ToolContext) + invocation_context_mock = MagicMock(spec=InvocationContext) + session_mock = MagicMock(spec=Session) + invocation_context_mock.session = session_mock + tool_context_mock.invocation_context = invocation_context_mock + + args = {"user": {"name": "Frank", "age": 45, "email": "frank@example.com"}} + + result = await tool.run_async(args=args, tool_context=tool_context_mock) + + # Verify the function received a proper Pydantic model + assert result["name"] == "Frank" + assert result["age"] == 45 + assert result["email"] == "frank@example.com" + assert result["type"] == "UserModel" + + +@pytest.mark.asyncio +async def test_run_async_with_pydantic_model_conversion_async_function(): + """Test run_async with Pydantic model conversion for async function.""" + tool = FunctionTool(async_function_with_pydantic_model) + + tool_context_mock = MagicMock(spec=ToolContext) + invocation_context_mock = MagicMock(spec=InvocationContext) + session_mock = MagicMock(spec=Session) + invocation_context_mock.session = session_mock + tool_context_mock.invocation_context = invocation_context_mock + + args = {"user": {"name": "Grace", "age": 32}} + + result = await tool.run_async(args=args, tool_context=tool_context_mock) + + # Verify the function received a proper Pydantic model + assert result["name"] == "Grace" + assert result["age"] == 32 + assert result["email"] is None # default value + assert result["type"] == "UserModel" + + +@pytest.mark.asyncio +async def test_run_async_with_optional_pydantic_models(): + """Test run_async with optional Pydantic models.""" + tool = FunctionTool(function_with_optional_pydantic_model) + + tool_context_mock = MagicMock(spec=ToolContext) + invocation_context_mock = MagicMock(spec=InvocationContext) + session_mock = MagicMock(spec=Session) + invocation_context_mock.session = session_mock + tool_context_mock.invocation_context = invocation_context_mock + + # Test with both required and optional models + args = { + "user": {"name": "Henry", "age": 50}, + "preferences": {"theme": "dark", "notifications": True}, + } + + result = await tool.run_async(args=args, tool_context=tool_context_mock) + + assert result["user_name"] == "Henry" + assert result["user_type"] == "UserModel" + assert result["theme"] == "dark" + assert result["notifications"] is True + assert result["preferences_type"] == "PreferencesModel" + assert result["preferences_type"] == "PreferencesModel" + assert result["preferences_type"] == "PreferencesModel"