fix: Convert argument to pydantic model when tool declare to accept pydantic model as argument

PiperOrigin-RevId: 814273005
This commit is contained in:
Xiang (Sean) Zhou
2025-10-02 09:46:58 -07:00
committed by Copybara-Service
parent c46308b7cf
commit 571c802fba
2 changed files with 346 additions and 2 deletions
+62 -2
View File
@@ -18,16 +18,18 @@ import inspect
import logging
from typing import Any
from typing import Callable
from typing import get_args
from typing import get_origin
from typing import Optional
from typing import Union
from google.genai import types
import pydantic
from typing_extensions import override
from ..utils.context_utils import Aclosing
from ._automatic_function_calling_util import build_function_declaration
from .base_tool import BaseTool
from .tool_confirmation import ToolConfirmation
from .tool_context import ToolContext
logger = logging.getLogger('google_adk.' + __name__)
@@ -95,11 +97,69 @@ class FunctionTool(BaseTool):
return function_decl
def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
"""Preprocess and convert function arguments before invocation.
Currently handles:
- Converting JSON dictionaries to Pydantic model instances where expected
Future extensions could include:
- Type coercion for other complex types
- Validation and sanitization
- Custom conversion logic
Args:
args: Raw arguments from the LLM tool call
Returns:
Processed arguments ready for function invocation
"""
signature = inspect.signature(self.func)
converted_args = args.copy()
for param_name, param in signature.parameters.items():
if param_name in args and param.annotation != inspect.Parameter.empty:
target_type = param.annotation
# Handle Optional[PydanticModel] types
if get_origin(param.annotation) is Union:
union_args = get_args(param.annotation)
# Find the non-None type in Optional[T] (which is Union[T, None])
non_none_types = [arg for arg in union_args if arg is not type(None)]
if len(non_none_types) == 1:
target_type = non_none_types[0]
# Check if the target type is a Pydantic model
if inspect.isclass(target_type) and issubclass(
target_type, pydantic.BaseModel
):
# Skip conversion if the value is None and the parameter is Optional
if args[param_name] is None:
continue
# Convert to Pydantic model if it's not already the correct type
if not isinstance(args[param_name], target_type):
try:
converted_args[param_name] = target_type.model_validate(
args[param_name]
)
except Exception as e:
logger.warning(
f"Failed to convert argument '{param_name}' to Pydantic model"
f' {target_type.__name__}: {e}'
)
# Keep the original value if conversion fails
pass
return converted_args
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
args_to_call = args.copy()
# Preprocess arguments (includes Pydantic model conversion)
args_to_call = self._preprocess_args(args)
signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}
if 'tool_context' in valid_params:
@@ -0,0 +1,284 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Pydantic model conversion tests
from typing import Optional
from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
import pydantic
import pytest
class UserModel(pydantic.BaseModel):
"""Test Pydantic model for user data."""
name: str
age: int
email: Optional[str] = None
class PreferencesModel(pydantic.BaseModel):
"""Test Pydantic model for preferences."""
theme: str = "light"
notifications: bool = True
def sync_function_with_pydantic_model(user: UserModel) -> dict:
"""Sync function that takes a Pydantic model."""
return {
"name": user.name,
"age": user.age,
"email": user.email,
"type": str(type(user).__name__),
}
async def async_function_with_pydantic_model(user: UserModel) -> dict:
"""Async function that takes a Pydantic model."""
return {
"name": user.name,
"age": user.age,
"email": user.email,
"type": str(type(user).__name__),
}
def function_with_optional_pydantic_model(
user: UserModel, preferences: Optional[PreferencesModel] = None
) -> dict:
"""Function with required and optional Pydantic models."""
result = {
"user_name": user.name,
"user_type": str(type(user).__name__),
}
if preferences:
result.update({
"theme": preferences.theme,
"notifications": preferences.notifications,
"preferences_type": str(type(preferences).__name__),
})
return result
def function_with_mixed_args(
name: str, user: UserModel, count: int = 5
) -> dict:
"""Function with mixed argument types including Pydantic model."""
return {
"name": name,
"user_name": user.name,
"user_type": str(type(user).__name__),
"count": count,
}
def test_preprocess_args_with_dict_to_pydantic_conversion():
"""Test _preprocess_args converts dict to Pydantic model."""
tool = FunctionTool(sync_function_with_pydantic_model)
input_args = {
"user": {"name": "Alice", "age": 30, "email": "alice@example.com"}
}
processed_args = tool._preprocess_args(input_args)
# Check that the dict was converted to a Pydantic model
assert "user" in processed_args
user = processed_args["user"]
assert isinstance(user, UserModel)
assert user.name == "Alice"
assert user.age == 30
assert user.email == "alice@example.com"
def test_preprocess_args_with_existing_pydantic_model():
"""Test _preprocess_args leaves existing Pydantic model unchanged."""
tool = FunctionTool(sync_function_with_pydantic_model)
# Create an existing Pydantic model
existing_user = UserModel(name="Bob", age=25)
input_args = {"user": existing_user}
processed_args = tool._preprocess_args(input_args)
# Check that the existing model was not changed (same object)
assert "user" in processed_args
user = processed_args["user"]
assert user is existing_user
assert isinstance(user, UserModel)
assert user.name == "Bob"
def test_preprocess_args_with_optional_pydantic_model_none():
"""Test _preprocess_args handles None for optional Pydantic models."""
tool = FunctionTool(function_with_optional_pydantic_model)
input_args = {"user": {"name": "Charlie", "age": 35}, "preferences": None}
processed_args = tool._preprocess_args(input_args)
# Check user conversion
assert isinstance(processed_args["user"], UserModel)
assert processed_args["user"].name == "Charlie"
# Check preferences remains None
assert processed_args["preferences"] is None
def test_preprocess_args_with_optional_pydantic_model_dict():
"""Test _preprocess_args converts dict for optional Pydantic models."""
tool = FunctionTool(function_with_optional_pydantic_model)
input_args = {
"user": {"name": "Diana", "age": 28},
"preferences": {"theme": "dark", "notifications": False},
}
processed_args = tool._preprocess_args(input_args)
# Check both conversions
assert isinstance(processed_args["user"], UserModel)
assert processed_args["user"].name == "Diana"
assert isinstance(processed_args["preferences"], PreferencesModel)
assert processed_args["preferences"].theme == "dark"
assert processed_args["preferences"].notifications is False
def test_preprocess_args_with_mixed_types():
"""Test _preprocess_args handles mixed argument types correctly."""
tool = FunctionTool(function_with_mixed_args)
input_args = {
"name": "test_name",
"user": {"name": "Eve", "age": 40},
"count": 10,
}
processed_args = tool._preprocess_args(input_args)
# Check that only Pydantic model was converted
assert processed_args["name"] == "test_name" # string unchanged
assert processed_args["count"] == 10 # int unchanged
# Check Pydantic model conversion
assert isinstance(processed_args["user"], UserModel)
assert processed_args["user"].name == "Eve"
assert processed_args["user"].age == 40
def test_preprocess_args_with_invalid_data_graceful_failure():
"""Test _preprocess_args handles invalid data gracefully."""
tool = FunctionTool(sync_function_with_pydantic_model)
# Invalid data that can't be converted to UserModel
input_args = {"user": "invalid_string"} # string instead of dict/model
processed_args = tool._preprocess_args(input_args)
# Should keep original value when conversion fails
assert processed_args["user"] == "invalid_string"
def test_preprocess_args_with_non_pydantic_parameters():
"""Test _preprocess_args ignores non-Pydantic parameters."""
def simple_function(name: str, age: int) -> dict:
return {"name": name, "age": age}
tool = FunctionTool(simple_function)
input_args = {"name": "test", "age": 25}
processed_args = tool._preprocess_args(input_args)
# Should remain unchanged (no Pydantic models to convert)
assert processed_args == input_args
@pytest.mark.asyncio
async def test_run_async_with_pydantic_model_conversion_sync_function():
"""Test run_async with Pydantic model conversion for sync function."""
tool = FunctionTool(sync_function_with_pydantic_model)
tool_context_mock = MagicMock(spec=ToolContext)
invocation_context_mock = MagicMock(spec=InvocationContext)
session_mock = MagicMock(spec=Session)
invocation_context_mock.session = session_mock
tool_context_mock.invocation_context = invocation_context_mock
args = {"user": {"name": "Frank", "age": 45, "email": "frank@example.com"}}
result = await tool.run_async(args=args, tool_context=tool_context_mock)
# Verify the function received a proper Pydantic model
assert result["name"] == "Frank"
assert result["age"] == 45
assert result["email"] == "frank@example.com"
assert result["type"] == "UserModel"
@pytest.mark.asyncio
async def test_run_async_with_pydantic_model_conversion_async_function():
"""Test run_async with Pydantic model conversion for async function."""
tool = FunctionTool(async_function_with_pydantic_model)
tool_context_mock = MagicMock(spec=ToolContext)
invocation_context_mock = MagicMock(spec=InvocationContext)
session_mock = MagicMock(spec=Session)
invocation_context_mock.session = session_mock
tool_context_mock.invocation_context = invocation_context_mock
args = {"user": {"name": "Grace", "age": 32}}
result = await tool.run_async(args=args, tool_context=tool_context_mock)
# Verify the function received a proper Pydantic model
assert result["name"] == "Grace"
assert result["age"] == 32
assert result["email"] is None # default value
assert result["type"] == "UserModel"
@pytest.mark.asyncio
async def test_run_async_with_optional_pydantic_models():
"""Test run_async with optional Pydantic models."""
tool = FunctionTool(function_with_optional_pydantic_model)
tool_context_mock = MagicMock(spec=ToolContext)
invocation_context_mock = MagicMock(spec=InvocationContext)
session_mock = MagicMock(spec=Session)
invocation_context_mock.session = session_mock
tool_context_mock.invocation_context = invocation_context_mock
# Test with both required and optional models
args = {
"user": {"name": "Henry", "age": 50},
"preferences": {"theme": "dark", "notifications": True},
}
result = await tool.run_async(args=args, tool_context=tool_context_mock)
assert result["user_name"] == "Henry"
assert result["user_type"] == "UserModel"
assert result["theme"] == "dark"
assert result["notifications"] is True
assert result["preferences_type"] == "PreferencesModel"
assert result["preferences_type"] == "PreferencesModel"
assert result["preferences_type"] == "PreferencesModel"