feat: Use json schema for agent tool declaration when feature enabled

Co-authored-by: Xuan Yang <xygoogle@google.com>
PiperOrigin-RevId: 854329254
This commit is contained in:
Xuan Yang
2026-01-09 13:44:30 -08:00
committed by Copybara-Service
parent 86e7664006
commit 6b241f5ef2
5 changed files with 302 additions and 32 deletions
@@ -14,8 +14,10 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Generator
import warnings
from ..utils.env_utils import is_env_enabled
@@ -264,3 +266,52 @@ def _emit_non_stable_warning_once(
f"[{feature_stage.name.upper()}] feature {feature_name} is enabled."
)
warnings.warn(full_message, category=UserWarning, stacklevel=4)
@contextmanager
def temporary_feature_override(
feature_name: FeatureName,
enabled: bool,
) -> Generator[None, None, None]:
"""Temporarily override a feature's enabled state within a context.
This context manager is useful for testing or temporarily enabling/disabling
a feature within a specific scope. The original state is restored when the
context exits.
Args:
feature_name: The feature name to override.
enabled: Whether the feature should be enabled.
Yields:
None
Example:
```python
from google.adk.features import FeatureName, temporary_feature_override
# Temporarily enable a feature for testing
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
# Feature is enabled here
result = some_function_that_checks_feature()
# Feature is restored to original state here
```
"""
config = _get_feature_config(feature_name)
if config is None:
raise ValueError(f"Feature {feature_name} is not registered.")
# Save the original override state
had_override = feature_name in _FEATURE_OVERRIDES
original_value = _FEATURE_OVERRIDES.get(feature_name)
# Apply the temporary override
_FEATURE_OVERRIDES[feature_name] = enabled
try:
yield
finally:
# Restore the original state
if had_override:
_FEATURE_OVERRIDES[feature_name] = original_value
else:
_FEATURE_OVERRIDES.pop(feature_name, None)
+36 -15
View File
@@ -23,6 +23,8 @@ from typing_extensions import override
from . import _automatic_function_calling_util
from ..agents.common_configs import AgentRefConfig
from ..features import FeatureName
from ..features import is_feature_enabled
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
@@ -82,29 +84,48 @@ class AgentTool(BaseTool):
# Override the description with the agent's description
result.description = self.agent.description
else:
result = types.FunctionDeclaration(
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'request': types.Schema(
type=types.Type.STRING,
),
},
required=['request'],
),
description=self.agent.description,
name=self.name,
)
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
result = types.FunctionDeclaration(
name=self.name,
description=self.agent.description,
parameters_json_schema={
'type': 'object',
'properties': {
'request': {'type': 'string'},
},
'required': ['request'],
},
)
else:
result = types.FunctionDeclaration(
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'request': types.Schema(
type=types.Type.STRING,
),
},
required=['request'],
),
description=self.agent.description,
name=self.name,
)
# Set response schema for non-GEMINI_API variants
if self._api_variant != GoogleLLMVariant.GEMINI_API:
# Determine response type based on agent's output schema
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
# Agent has structured output schema - response is an object
result.response = types.Schema(type=types.Type.OBJECT)
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
result.response_json_schema = {'type': 'object'}
else:
result.response = types.Schema(type=types.Type.OBJECT)
else:
# Agent returns text - response is a string
result.response = types.Schema(type=types.Type.STRING)
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
result.response_json_schema = {'type': 'string'}
else:
result.response = types.Schema(type=types.Type.STRING)
result.name = self.name
return result
+14 -11
View File
@@ -22,6 +22,8 @@ from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.tool_context import ToolContext
@@ -129,17 +131,16 @@ class TestMCPTool:
assert declaration.description == "Test tool description"
assert declaration.parameters is not None
def test_get_declaration_with_json_schema_for_func_decl_enabled(
self, monkeypatch
):
def test_get_declaration_with_json_schema_for_func_decl_enabled(self):
"""Test function declaration generation with json schema for func decl enabled."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
with monkeypatch.context() as m:
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
with temporary_feature_override(
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
):
declaration = tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
@@ -151,7 +152,7 @@ class TestMCPTool:
assert declaration.response_json_schema is None
def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enabled(
self, monkeypatch
self,
):
"""Test function declaration generation with an output schema and json schema for func decl enabled."""
output_schema = {
@@ -169,8 +170,9 @@ class TestMCPTool:
mcp_session_manager=self.mock_session_manager,
)
with monkeypatch.context() as m:
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
with temporary_feature_override(
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
):
declaration = tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
@@ -178,7 +180,7 @@ class TestMCPTool:
assert declaration.response_json_schema == output_schema
def test_get_declaration_with_empty_output_schema_and_json_schema_for_func_decl_enabled(
self, monkeypatch
self,
):
"""Test function declaration with an empty output schema and json schema for func decl enabled."""
tool = MCPTool(
@@ -186,8 +188,9 @@ class TestMCPTool:
mcp_session_manager=self.mock_session_manager,
)
with monkeypatch.context() as m:
m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true")
with temporary_feature_override(
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
):
declaration = tool._get_declaration()
assert declaration.response is None
+198
View File
@@ -21,6 +21,8 @@ from google.adk.agents.llm_agent import Agent
from google.adk.agents.run_config import RunConfig
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
@@ -33,6 +35,7 @@ from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
import pytest
from pytest import mark
from .. import testing_utils
@@ -702,3 +705,198 @@ def test_agent_tool_description_with_input_schema():
# The description should come from the agent, not the Pydantic model
assert declaration.description == agent_description
@pytest.fixture
def enable_json_schema_feature():
"""Fixture to enable JSON_SCHEMA_FOR_FUNC_DECL feature for a test."""
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
yield
def test_agent_tool_no_schema_with_json_schema_feature(
enable_json_schema_feature,
):
"""Test AgentTool without input_schema uses parameters_json_schema when feature enabled."""
tool_agent = Agent(
name='tool_agent',
description='A tool agent for testing.',
model=testing_utils.MockModel.create(responses=['test response']),
)
agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()
assert declaration.model_dump(exclude_none=True) == {
'name': 'tool_agent',
'description': 'A tool agent for testing.',
'parameters_json_schema': {
'type': 'object',
'properties': {
'request': {'type': 'string'},
},
'required': ['request'],
},
}
@mark.parametrize(
'env_variables',
[
'VERTEX', # Test VERTEX_AI variant
],
indirect=True,
)
def test_agent_tool_response_json_schema_no_output_schema_vertex_ai(
env_variables,
enable_json_schema_feature,
):
"""Test AgentTool with no output schema uses response_json_schema for VERTEX_AI when feature enabled."""
tool_agent = Agent(
name='tool_agent',
description='A tool agent for testing.',
model=testing_utils.MockModel.create(responses=['test response']),
)
agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()
assert declaration.model_dump(exclude_none=True) == {
'name': 'tool_agent',
'description': 'A tool agent for testing.',
'parameters_json_schema': {
'type': 'object',
'properties': {
'request': {'type': 'string'},
},
'required': ['request'],
},
'response_json_schema': {'type': 'string'},
}
@mark.parametrize(
'env_variables',
[
'VERTEX', # Test VERTEX_AI variant
],
indirect=True,
)
def test_agent_tool_response_json_schema_with_output_schema_vertex_ai(
env_variables,
enable_json_schema_feature,
):
"""Test AgentTool with output schema uses response_json_schema for VERTEX_AI when feature enabled."""
class CustomOutput(BaseModel):
custom_output: str
tool_agent = Agent(
name='tool_agent',
description='A tool agent for testing.',
model=testing_utils.MockModel.create(responses=['test response']),
output_schema=CustomOutput,
)
agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()
assert declaration.model_dump(exclude_none=True) == {
'name': 'tool_agent',
'description': 'A tool agent for testing.',
'parameters_json_schema': {
'type': 'object',
'properties': {
'request': {'type': 'string'},
},
'required': ['request'],
},
'response_json_schema': {'type': 'object'},
}
@mark.parametrize(
'env_variables',
[
'GOOGLE_AI', # Test GEMINI_API variant
],
indirect=True,
)
def test_agent_tool_no_response_json_schema_gemini_api(
env_variables,
enable_json_schema_feature,
):
"""Test AgentTool with GEMINI_API variant has no response_json_schema when feature enabled."""
class CustomOutput(BaseModel):
custom_output: str
tool_agent = Agent(
name='tool_agent',
description='A tool agent for testing.',
model=testing_utils.MockModel.create(responses=['test response']),
output_schema=CustomOutput,
)
agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()
# GEMINI_API should not have response_json_schema
assert declaration.model_dump(exclude_none=True) == {
'name': 'tool_agent',
'description': 'A tool agent for testing.',
'parameters_json_schema': {
'type': 'object',
'properties': {
'request': {'type': 'string'},
},
'required': ['request'],
},
}
@mark.parametrize(
'env_variables',
[
'VERTEX', # Test VERTEX_AI variant
],
indirect=True,
)
def test_agent_tool_with_input_schema_uses_json_schema_feature(
env_variables,
enable_json_schema_feature,
):
"""Test AgentTool with input_schema uses parameters_json_schema when feature enabled."""
class CustomInput(BaseModel):
custom_input: str
class CustomOutput(BaseModel):
custom_output: str
tool_agent = Agent(
name='tool_agent',
description='A tool agent for testing.',
model=testing_utils.MockModel.create(responses=['test response']),
input_schema=CustomInput,
output_schema=CustomOutput,
)
agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()
# When input_schema is provided, build_function_declaration uses Pydantic's
# model_json_schema() which includes additional fields like 'title'
assert declaration.model_dump(exclude_none=True) == {
'name': 'tool_agent',
'description': 'A tool agent for testing.',
'parameters_json_schema': {
'properties': {
'custom_input': {'title': 'Custom Input', 'type': 'string'},
},
'required': ['custom_input'],
'title': 'CustomInput',
'type': 'object',
},
'response_json_schema': {'type': 'object'},
}
@@ -13,9 +13,9 @@
# limitations under the License.
from enum import Enum
from unittest import mock
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools import _automatic_function_calling_util
from google.adk.tools.tool_context import ToolContext
from google.adk.utils.variant_utils import GoogleLLMVariant
@@ -435,11 +435,8 @@ class TestJsonSchemaFeatureFlagEnabled:
@pytest.fixture(autouse=True)
def enable_feature_flag(self):
"""Enable the JSON_SCHEMA_FOR_FUNC_DECL feature flag for all tests."""
with mock.patch.object(
_automatic_function_calling_util,
'is_feature_enabled',
autospec=True,
side_effect=lambda f: f == FeatureName.JSON_SCHEMA_FOR_FUNC_DECL,
with temporary_feature_override(
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
):
yield