You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Use json schema for agent tool declaration when feature enabled
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 854329254
This commit is contained in:
committed by
Copybara-Service
parent
86e7664006
commit
6b241f5ef2
@@ -14,8 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Generator
|
||||
import warnings
|
||||
|
||||
from ..utils.env_utils import is_env_enabled
|
||||
@@ -264,3 +266,52 @@ def _emit_non_stable_warning_once(
|
||||
f"[{feature_stage.name.upper()}] feature {feature_name} is enabled."
|
||||
)
|
||||
warnings.warn(full_message, category=UserWarning, stacklevel=4)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_feature_override(
|
||||
feature_name: FeatureName,
|
||||
enabled: bool,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Temporarily override a feature's enabled state within a context.
|
||||
|
||||
This context manager is useful for testing or temporarily enabling/disabling
|
||||
a feature within a specific scope. The original state is restored when the
|
||||
context exits.
|
||||
|
||||
Args:
|
||||
feature_name: The feature name to override.
|
||||
enabled: Whether the feature should be enabled.
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Example:
|
||||
```python
|
||||
from google.adk.features import FeatureName, temporary_feature_override
|
||||
|
||||
# Temporarily enable a feature for testing
|
||||
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
|
||||
# Feature is enabled here
|
||||
result = some_function_that_checks_feature()
|
||||
# Feature is restored to original state here
|
||||
```
|
||||
"""
|
||||
config = _get_feature_config(feature_name)
|
||||
if config is None:
|
||||
raise ValueError(f"Feature {feature_name} is not registered.")
|
||||
|
||||
# Save the original override state
|
||||
had_override = feature_name in _FEATURE_OVERRIDES
|
||||
original_value = _FEATURE_OVERRIDES.get(feature_name)
|
||||
|
||||
# Apply the temporary override
|
||||
_FEATURE_OVERRIDES[feature_name] = enabled
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore the original state
|
||||
if had_override:
|
||||
_FEATURE_OVERRIDES[feature_name] = original_value
|
||||
else:
|
||||
_FEATURE_OVERRIDES.pop(feature_name, None)
|
||||
|
||||
@@ -23,6 +23,8 @@ from typing_extensions import override
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from ..agents.common_configs import AgentRefConfig
|
||||
from ..features import FeatureName
|
||||
from ..features import is_feature_enabled
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ._forwarding_artifact_service import ForwardingArtifactService
|
||||
@@ -82,29 +84,48 @@ class AgentTool(BaseTool):
|
||||
# Override the description with the agent's description
|
||||
result.description = self.agent.description
|
||||
else:
|
||||
result = types.FunctionDeclaration(
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
'request': types.Schema(
|
||||
type=types.Type.STRING,
|
||||
),
|
||||
},
|
||||
required=['request'],
|
||||
),
|
||||
description=self.agent.description,
|
||||
name=self.name,
|
||||
)
|
||||
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
|
||||
result = types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.agent.description,
|
||||
parameters_json_schema={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'request': {'type': 'string'},
|
||||
},
|
||||
'required': ['request'],
|
||||
},
|
||||
)
|
||||
else:
|
||||
result = types.FunctionDeclaration(
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
'request': types.Schema(
|
||||
type=types.Type.STRING,
|
||||
),
|
||||
},
|
||||
required=['request'],
|
||||
),
|
||||
description=self.agent.description,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
# Set response schema for non-GEMINI_API variants
|
||||
if self._api_variant != GoogleLLMVariant.GEMINI_API:
|
||||
# Determine response type based on agent's output schema
|
||||
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
|
||||
# Agent has structured output schema - response is an object
|
||||
result.response = types.Schema(type=types.Type.OBJECT)
|
||||
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
|
||||
result.response_json_schema = {'type': 'object'}
|
||||
else:
|
||||
result.response = types.Schema(type=types.Type.OBJECT)
|
||||
else:
|
||||
# Agent returns text - response is a string
|
||||
result.response = types.Schema(type=types.Type.STRING)
|
||||
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
|
||||
result.response_json_schema = {'type': 'string'}
|
||||
else:
|
||||
result.response = types.Schema(type=types.Type.STRING)
|
||||
|
||||
result.name = self.name
|
||||
return result
|
||||
|
||||
@@ -22,6 +22,8 @@ from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
|
||||
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
@@ -129,17 +131,16 @@ class TestMCPTool:
|
||||
assert declaration.description == "Test tool description"
|
||||
assert declaration.parameters is not None
|
||||
|
||||
def test_get_declaration_with_json_schema_for_func_decl_enabled(
|
||||
self, monkeypatch
|
||||
):
|
||||
def test_get_declaration_with_json_schema_for_func_decl_enabled(self):
|
||||
"""Test function declaration generation with json schema for func decl enabled."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
|
||||
with temporary_feature_override(
|
||||
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
|
||||
):
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
@@ -151,7 +152,7 @@ class TestMCPTool:
|
||||
assert declaration.response_json_schema is None
|
||||
|
||||
def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enabled(
|
||||
self, monkeypatch
|
||||
self,
|
||||
):
|
||||
"""Test function declaration generation with an output schema and json schema for func decl enabled."""
|
||||
output_schema = {
|
||||
@@ -169,8 +170,9 @@ class TestMCPTool:
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
|
||||
with temporary_feature_override(
|
||||
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
|
||||
):
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
@@ -178,7 +180,7 @@ class TestMCPTool:
|
||||
assert declaration.response_json_schema == output_schema
|
||||
|
||||
def test_get_declaration_with_empty_output_schema_and_json_schema_for_func_decl_enabled(
|
||||
self, monkeypatch
|
||||
self,
|
||||
):
|
||||
"""Test function declaration with an empty output schema and json schema for func decl enabled."""
|
||||
tool = MCPTool(
|
||||
@@ -186,8 +188,9 @@ class TestMCPTool:
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
|
||||
with temporary_feature_override(
|
||||
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
|
||||
):
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
assert declaration.response is None
|
||||
|
||||
@@ -21,6 +21,8 @@ from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
@@ -33,6 +35,7 @@ from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
from .. import testing_utils
|
||||
@@ -702,3 +705,198 @@ def test_agent_tool_description_with_input_schema():
|
||||
|
||||
# The description should come from the agent, not the Pydantic model
|
||||
assert declaration.description == agent_description
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_json_schema_feature():
|
||||
"""Fixture to enable JSON_SCHEMA_FOR_FUNC_DECL feature for a test."""
|
||||
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
|
||||
yield
|
||||
|
||||
|
||||
def test_agent_tool_no_schema_with_json_schema_feature(
|
||||
enable_json_schema_feature,
|
||||
):
|
||||
"""Test AgentTool without input_schema uses parameters_json_schema when feature enabled."""
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
description='A tool agent for testing.',
|
||||
model=testing_utils.MockModel.create(responses=['test response']),
|
||||
)
|
||||
|
||||
agent_tool = AgentTool(agent=tool_agent)
|
||||
declaration = agent_tool._get_declaration()
|
||||
|
||||
assert declaration.model_dump(exclude_none=True) == {
|
||||
'name': 'tool_agent',
|
||||
'description': 'A tool agent for testing.',
|
||||
'parameters_json_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'request': {'type': 'string'},
|
||||
},
|
||||
'required': ['request'],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'VERTEX', # Test VERTEX_AI variant
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_tool_response_json_schema_no_output_schema_vertex_ai(
|
||||
env_variables,
|
||||
enable_json_schema_feature,
|
||||
):
|
||||
"""Test AgentTool with no output schema uses response_json_schema for VERTEX_AI when feature enabled."""
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
description='A tool agent for testing.',
|
||||
model=testing_utils.MockModel.create(responses=['test response']),
|
||||
)
|
||||
|
||||
agent_tool = AgentTool(agent=tool_agent)
|
||||
declaration = agent_tool._get_declaration()
|
||||
|
||||
assert declaration.model_dump(exclude_none=True) == {
|
||||
'name': 'tool_agent',
|
||||
'description': 'A tool agent for testing.',
|
||||
'parameters_json_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'request': {'type': 'string'},
|
||||
},
|
||||
'required': ['request'],
|
||||
},
|
||||
'response_json_schema': {'type': 'string'},
|
||||
}
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'VERTEX', # Test VERTEX_AI variant
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_tool_response_json_schema_with_output_schema_vertex_ai(
|
||||
env_variables,
|
||||
enable_json_schema_feature,
|
||||
):
|
||||
"""Test AgentTool with output schema uses response_json_schema for VERTEX_AI when feature enabled."""
|
||||
|
||||
class CustomOutput(BaseModel):
|
||||
custom_output: str
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
description='A tool agent for testing.',
|
||||
model=testing_utils.MockModel.create(responses=['test response']),
|
||||
output_schema=CustomOutput,
|
||||
)
|
||||
|
||||
agent_tool = AgentTool(agent=tool_agent)
|
||||
declaration = agent_tool._get_declaration()
|
||||
|
||||
assert declaration.model_dump(exclude_none=True) == {
|
||||
'name': 'tool_agent',
|
||||
'description': 'A tool agent for testing.',
|
||||
'parameters_json_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'request': {'type': 'string'},
|
||||
},
|
||||
'required': ['request'],
|
||||
},
|
||||
'response_json_schema': {'type': 'object'},
|
||||
}
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'GOOGLE_AI', # Test GEMINI_API variant
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_tool_no_response_json_schema_gemini_api(
|
||||
env_variables,
|
||||
enable_json_schema_feature,
|
||||
):
|
||||
"""Test AgentTool with GEMINI_API variant has no response_json_schema when feature enabled."""
|
||||
|
||||
class CustomOutput(BaseModel):
|
||||
custom_output: str
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
description='A tool agent for testing.',
|
||||
model=testing_utils.MockModel.create(responses=['test response']),
|
||||
output_schema=CustomOutput,
|
||||
)
|
||||
|
||||
agent_tool = AgentTool(agent=tool_agent)
|
||||
declaration = agent_tool._get_declaration()
|
||||
|
||||
# GEMINI_API should not have response_json_schema
|
||||
assert declaration.model_dump(exclude_none=True) == {
|
||||
'name': 'tool_agent',
|
||||
'description': 'A tool agent for testing.',
|
||||
'parameters_json_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'request': {'type': 'string'},
|
||||
},
|
||||
'required': ['request'],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'VERTEX', # Test VERTEX_AI variant
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_tool_with_input_schema_uses_json_schema_feature(
|
||||
env_variables,
|
||||
enable_json_schema_feature,
|
||||
):
|
||||
"""Test AgentTool with input_schema uses parameters_json_schema when feature enabled."""
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
custom_input: str
|
||||
|
||||
class CustomOutput(BaseModel):
|
||||
custom_output: str
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
description='A tool agent for testing.',
|
||||
model=testing_utils.MockModel.create(responses=['test response']),
|
||||
input_schema=CustomInput,
|
||||
output_schema=CustomOutput,
|
||||
)
|
||||
|
||||
agent_tool = AgentTool(agent=tool_agent)
|
||||
declaration = agent_tool._get_declaration()
|
||||
|
||||
# When input_schema is provided, build_function_declaration uses Pydantic's
|
||||
# model_json_schema() which includes additional fields like 'title'
|
||||
assert declaration.model_dump(exclude_none=True) == {
|
||||
'name': 'tool_agent',
|
||||
'description': 'A tool agent for testing.',
|
||||
'parameters_json_schema': {
|
||||
'properties': {
|
||||
'custom_input': {'title': 'Custom Input', 'type': 'string'},
|
||||
},
|
||||
'required': ['custom_input'],
|
||||
'title': 'CustomInput',
|
||||
'type': 'object',
|
||||
},
|
||||
'response_json_schema': {'type': 'object'},
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.tools import _automatic_function_calling_util
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
@@ -435,11 +435,8 @@ class TestJsonSchemaFeatureFlagEnabled:
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_feature_flag(self):
|
||||
"""Enable the JSON_SCHEMA_FOR_FUNC_DECL feature flag for all tests."""
|
||||
with mock.patch.object(
|
||||
_automatic_function_calling_util,
|
||||
'is_feature_enabled',
|
||||
autospec=True,
|
||||
side_effect=lambda f: f == FeatureName.JSON_SCHEMA_FOR_FUNC_DECL,
|
||||
with temporary_feature_override(
|
||||
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
Reference in New Issue
Block a user