You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Convert argument to pydantic model when tool declare to accept pydantic model as argument
PiperOrigin-RevId: 814273005
This commit is contained in:
committed by
Copybara-Service
parent
c46308b7cf
commit
571c802fba
@@ -18,16 +18,18 @@ import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import get_args
|
||||
from typing import get_origin
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
import pydantic
|
||||
from typing_extensions import override
|
||||
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ._automatic_function_calling_util import build_function_declaration
|
||||
from .base_tool import BaseTool
|
||||
from .tool_confirmation import ToolConfirmation
|
||||
from .tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
@@ -95,11 +97,69 @@ class FunctionTool(BaseTool):
|
||||
|
||||
return function_decl
|
||||
|
||||
def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess and convert function arguments before invocation.
|
||||
|
||||
Currently handles:
|
||||
- Converting JSON dictionaries to Pydantic model instances where expected
|
||||
|
||||
Future extensions could include:
|
||||
- Type coercion for other complex types
|
||||
- Validation and sanitization
|
||||
- Custom conversion logic
|
||||
|
||||
Args:
|
||||
args: Raw arguments from the LLM tool call
|
||||
|
||||
Returns:
|
||||
Processed arguments ready for function invocation
|
||||
"""
|
||||
signature = inspect.signature(self.func)
|
||||
converted_args = args.copy()
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param_name in args and param.annotation != inspect.Parameter.empty:
|
||||
target_type = param.annotation
|
||||
|
||||
# Handle Optional[PydanticModel] types
|
||||
if get_origin(param.annotation) is Union:
|
||||
union_args = get_args(param.annotation)
|
||||
# Find the non-None type in Optional[T] (which is Union[T, None])
|
||||
non_none_types = [arg for arg in union_args if arg is not type(None)]
|
||||
if len(non_none_types) == 1:
|
||||
target_type = non_none_types[0]
|
||||
|
||||
# Check if the target type is a Pydantic model
|
||||
if inspect.isclass(target_type) and issubclass(
|
||||
target_type, pydantic.BaseModel
|
||||
):
|
||||
# Skip conversion if the value is None and the parameter is Optional
|
||||
if args[param_name] is None:
|
||||
continue
|
||||
|
||||
# Convert to Pydantic model if it's not already the correct type
|
||||
if not isinstance(args[param_name], target_type):
|
||||
try:
|
||||
converted_args[param_name] = target_type.model_validate(
|
||||
args[param_name]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to convert argument '{param_name}' to Pydantic model"
|
||||
f' {target_type.__name__}: {e}'
|
||||
)
|
||||
# Keep the original value if conversion fails
|
||||
pass
|
||||
|
||||
return converted_args
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
args_to_call = args.copy()
|
||||
# Preprocess arguments (includes Pydantic model conversion)
|
||||
args_to_call = self._preprocess_args(args)
|
||||
|
||||
signature = inspect.signature(self.func)
|
||||
valid_params = {param for param in signature.parameters}
|
||||
if 'tool_context' in valid_params:
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Pydantic model conversion tests
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
|
||||
class UserModel(pydantic.BaseModel):
|
||||
"""Test Pydantic model for user data."""
|
||||
|
||||
name: str
|
||||
age: int
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
class PreferencesModel(pydantic.BaseModel):
|
||||
"""Test Pydantic model for preferences."""
|
||||
|
||||
theme: str = "light"
|
||||
notifications: bool = True
|
||||
|
||||
|
||||
def sync_function_with_pydantic_model(user: UserModel) -> dict:
|
||||
"""Sync function that takes a Pydantic model."""
|
||||
return {
|
||||
"name": user.name,
|
||||
"age": user.age,
|
||||
"email": user.email,
|
||||
"type": str(type(user).__name__),
|
||||
}
|
||||
|
||||
|
||||
async def async_function_with_pydantic_model(user: UserModel) -> dict:
|
||||
"""Async function that takes a Pydantic model."""
|
||||
return {
|
||||
"name": user.name,
|
||||
"age": user.age,
|
||||
"email": user.email,
|
||||
"type": str(type(user).__name__),
|
||||
}
|
||||
|
||||
|
||||
def function_with_optional_pydantic_model(
|
||||
user: UserModel, preferences: Optional[PreferencesModel] = None
|
||||
) -> dict:
|
||||
"""Function with required and optional Pydantic models."""
|
||||
result = {
|
||||
"user_name": user.name,
|
||||
"user_type": str(type(user).__name__),
|
||||
}
|
||||
if preferences:
|
||||
result.update({
|
||||
"theme": preferences.theme,
|
||||
"notifications": preferences.notifications,
|
||||
"preferences_type": str(type(preferences).__name__),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def function_with_mixed_args(
|
||||
name: str, user: UserModel, count: int = 5
|
||||
) -> dict:
|
||||
"""Function with mixed argument types including Pydantic model."""
|
||||
return {
|
||||
"name": name,
|
||||
"user_name": user.name,
|
||||
"user_type": str(type(user).__name__),
|
||||
"count": count,
|
||||
}
|
||||
|
||||
|
||||
def test_preprocess_args_with_dict_to_pydantic_conversion():
|
||||
"""Test _preprocess_args converts dict to Pydantic model."""
|
||||
tool = FunctionTool(sync_function_with_pydantic_model)
|
||||
|
||||
input_args = {
|
||||
"user": {"name": "Alice", "age": 30, "email": "alice@example.com"}
|
||||
}
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Check that the dict was converted to a Pydantic model
|
||||
assert "user" in processed_args
|
||||
user = processed_args["user"]
|
||||
assert isinstance(user, UserModel)
|
||||
assert user.name == "Alice"
|
||||
assert user.age == 30
|
||||
assert user.email == "alice@example.com"
|
||||
|
||||
|
||||
def test_preprocess_args_with_existing_pydantic_model():
|
||||
"""Test _preprocess_args leaves existing Pydantic model unchanged."""
|
||||
tool = FunctionTool(sync_function_with_pydantic_model)
|
||||
|
||||
# Create an existing Pydantic model
|
||||
existing_user = UserModel(name="Bob", age=25)
|
||||
input_args = {"user": existing_user}
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Check that the existing model was not changed (same object)
|
||||
assert "user" in processed_args
|
||||
user = processed_args["user"]
|
||||
assert user is existing_user
|
||||
assert isinstance(user, UserModel)
|
||||
assert user.name == "Bob"
|
||||
|
||||
|
||||
def test_preprocess_args_with_optional_pydantic_model_none():
|
||||
"""Test _preprocess_args handles None for optional Pydantic models."""
|
||||
tool = FunctionTool(function_with_optional_pydantic_model)
|
||||
|
||||
input_args = {"user": {"name": "Charlie", "age": 35}, "preferences": None}
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Check user conversion
|
||||
assert isinstance(processed_args["user"], UserModel)
|
||||
assert processed_args["user"].name == "Charlie"
|
||||
|
||||
# Check preferences remains None
|
||||
assert processed_args["preferences"] is None
|
||||
|
||||
|
||||
def test_preprocess_args_with_optional_pydantic_model_dict():
|
||||
"""Test _preprocess_args converts dict for optional Pydantic models."""
|
||||
tool = FunctionTool(function_with_optional_pydantic_model)
|
||||
|
||||
input_args = {
|
||||
"user": {"name": "Diana", "age": 28},
|
||||
"preferences": {"theme": "dark", "notifications": False},
|
||||
}
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Check both conversions
|
||||
assert isinstance(processed_args["user"], UserModel)
|
||||
assert processed_args["user"].name == "Diana"
|
||||
|
||||
assert isinstance(processed_args["preferences"], PreferencesModel)
|
||||
assert processed_args["preferences"].theme == "dark"
|
||||
assert processed_args["preferences"].notifications is False
|
||||
|
||||
|
||||
def test_preprocess_args_with_mixed_types():
|
||||
"""Test _preprocess_args handles mixed argument types correctly."""
|
||||
tool = FunctionTool(function_with_mixed_args)
|
||||
|
||||
input_args = {
|
||||
"name": "test_name",
|
||||
"user": {"name": "Eve", "age": 40},
|
||||
"count": 10,
|
||||
}
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Check that only Pydantic model was converted
|
||||
assert processed_args["name"] == "test_name" # string unchanged
|
||||
assert processed_args["count"] == 10 # int unchanged
|
||||
|
||||
# Check Pydantic model conversion
|
||||
assert isinstance(processed_args["user"], UserModel)
|
||||
assert processed_args["user"].name == "Eve"
|
||||
assert processed_args["user"].age == 40
|
||||
|
||||
|
||||
def test_preprocess_args_with_invalid_data_graceful_failure():
|
||||
"""Test _preprocess_args handles invalid data gracefully."""
|
||||
tool = FunctionTool(sync_function_with_pydantic_model)
|
||||
|
||||
# Invalid data that can't be converted to UserModel
|
||||
input_args = {"user": "invalid_string"} # string instead of dict/model
|
||||
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Should keep original value when conversion fails
|
||||
assert processed_args["user"] == "invalid_string"
|
||||
|
||||
|
||||
def test_preprocess_args_with_non_pydantic_parameters():
|
||||
"""Test _preprocess_args ignores non-Pydantic parameters."""
|
||||
|
||||
def simple_function(name: str, age: int) -> dict:
|
||||
return {"name": name, "age": age}
|
||||
|
||||
tool = FunctionTool(simple_function)
|
||||
|
||||
input_args = {"name": "test", "age": 25}
|
||||
processed_args = tool._preprocess_args(input_args)
|
||||
|
||||
# Should remain unchanged (no Pydantic models to convert)
|
||||
assert processed_args == input_args
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_pydantic_model_conversion_sync_function():
|
||||
"""Test run_async with Pydantic model conversion for sync function."""
|
||||
tool = FunctionTool(sync_function_with_pydantic_model)
|
||||
|
||||
tool_context_mock = MagicMock(spec=ToolContext)
|
||||
invocation_context_mock = MagicMock(spec=InvocationContext)
|
||||
session_mock = MagicMock(spec=Session)
|
||||
invocation_context_mock.session = session_mock
|
||||
tool_context_mock.invocation_context = invocation_context_mock
|
||||
|
||||
args = {"user": {"name": "Frank", "age": 45, "email": "frank@example.com"}}
|
||||
|
||||
result = await tool.run_async(args=args, tool_context=tool_context_mock)
|
||||
|
||||
# Verify the function received a proper Pydantic model
|
||||
assert result["name"] == "Frank"
|
||||
assert result["age"] == 45
|
||||
assert result["email"] == "frank@example.com"
|
||||
assert result["type"] == "UserModel"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_pydantic_model_conversion_async_function():
|
||||
"""Test run_async with Pydantic model conversion for async function."""
|
||||
tool = FunctionTool(async_function_with_pydantic_model)
|
||||
|
||||
tool_context_mock = MagicMock(spec=ToolContext)
|
||||
invocation_context_mock = MagicMock(spec=InvocationContext)
|
||||
session_mock = MagicMock(spec=Session)
|
||||
invocation_context_mock.session = session_mock
|
||||
tool_context_mock.invocation_context = invocation_context_mock
|
||||
|
||||
args = {"user": {"name": "Grace", "age": 32}}
|
||||
|
||||
result = await tool.run_async(args=args, tool_context=tool_context_mock)
|
||||
|
||||
# Verify the function received a proper Pydantic model
|
||||
assert result["name"] == "Grace"
|
||||
assert result["age"] == 32
|
||||
assert result["email"] is None # default value
|
||||
assert result["type"] == "UserModel"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_optional_pydantic_models():
|
||||
"""Test run_async with optional Pydantic models."""
|
||||
tool = FunctionTool(function_with_optional_pydantic_model)
|
||||
|
||||
tool_context_mock = MagicMock(spec=ToolContext)
|
||||
invocation_context_mock = MagicMock(spec=InvocationContext)
|
||||
session_mock = MagicMock(spec=Session)
|
||||
invocation_context_mock.session = session_mock
|
||||
tool_context_mock.invocation_context = invocation_context_mock
|
||||
|
||||
# Test with both required and optional models
|
||||
args = {
|
||||
"user": {"name": "Henry", "age": 50},
|
||||
"preferences": {"theme": "dark", "notifications": True},
|
||||
}
|
||||
|
||||
result = await tool.run_async(args=args, tool_context=tool_context_mock)
|
||||
|
||||
assert result["user_name"] == "Henry"
|
||||
assert result["user_type"] == "UserModel"
|
||||
assert result["theme"] == "dark"
|
||||
assert result["notifications"] is True
|
||||
assert result["preferences_type"] == "PreferencesModel"
|
||||
assert result["preferences_type"] == "PreferencesModel"
|
||||
assert result["preferences_type"] == "PreferencesModel"
|
||||
Reference in New Issue
Block a user