diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 90c09d65..2ab01306 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -14,8 +14,10 @@ from __future__ import annotations +from contextlib import contextmanager from dataclasses import dataclass from enum import Enum +from typing import Generator import warnings from ..utils.env_utils import is_env_enabled @@ -264,3 +266,52 @@ def _emit_non_stable_warning_once( f"[{feature_stage.name.upper()}] feature {feature_name} is enabled." ) warnings.warn(full_message, category=UserWarning, stacklevel=4) + + +@contextmanager +def temporary_feature_override( + feature_name: FeatureName, + enabled: bool, +) -> Generator[None, None, None]: + """Temporarily override a feature's enabled state within a context. + + This context manager is useful for testing or temporarily enabling/disabling + a feature within a specific scope. The original state is restored when the + context exits. + + Args: + feature_name: The feature name to override. + enabled: Whether the feature should be enabled. + + Yields: + None + + Example: + ```python + from google.adk.features import FeatureName, temporary_feature_override + + # Temporarily enable a feature for testing + with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True): + # Feature is enabled here + result = some_function_that_checks_feature() + # Feature is restored to original state here + ``` + """ + config = _get_feature_config(feature_name) + if config is None: + raise ValueError(f"Feature {feature_name} is not registered.") + + # Save the original override state + had_override = feature_name in _FEATURE_OVERRIDES + original_value = _FEATURE_OVERRIDES.get(feature_name) + + # Apply the temporary override + _FEATURE_OVERRIDES[feature_name] = enabled + try: + yield + finally: + # Restore the original state + if had_override: + _FEATURE_OVERRIDES[feature_name] = original_value + else: + _FEATURE_OVERRIDES.pop(feature_name, None) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 799f0ea4..ea40bee0 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -23,6 +23,8 @@ from typing_extensions import override from . import _automatic_function_calling_util from ..agents.common_configs import AgentRefConfig +from ..features import FeatureName +from ..features import is_feature_enabled from ..memory.in_memory_memory_service import InMemoryMemoryService from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService @@ -82,29 +84,48 @@ class AgentTool(BaseTool): # Override the description with the agent's description result.description = self.agent.description else: - result = types.FunctionDeclaration( - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - 'request': types.Schema( - type=types.Type.STRING, - ), - }, - required=['request'], - ), - description=self.agent.description, - name=self.name, - ) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result = types.FunctionDeclaration( + name=self.name, + description=self.agent.description, + parameters_json_schema={ + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + ) + else: + result = types.FunctionDeclaration( + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'request': types.Schema( + type=types.Type.STRING, + ), + }, + required=['request'], + ), + description=self.agent.description, + name=self.name, + ) # Set response schema for non-GEMINI_API variants if self._api_variant != GoogleLLMVariant.GEMINI_API: # Determine response type based on agent's output schema if isinstance(self.agent, LlmAgent) and self.agent.output_schema: # Agent has structured output schema - response is an object - result.response = types.Schema(type=types.Type.OBJECT) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result.response_json_schema = {'type': 'object'} + else: + result.response = types.Schema(type=types.Type.OBJECT) else: # Agent returns text - response is a string - result.response = types.Schema(type=types.Type.STRING) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result.response_json_schema = {'type': 'string'} + else: + result.response = types.Schema(type=types.Type.STRING) result.name = self.name return result diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 1284e73b..23583019 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -22,6 +22,8 @@ from google.adk.auth.auth_credential import HttpAuth from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount +from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.tool_context import ToolContext @@ -129,17 +131,16 @@ class TestMCPTool: assert declaration.description == "Test tool description" assert declaration.parameters is not None - def test_get_declaration_with_json_schema_for_func_decl_enabled( - self, monkeypatch - ): + def test_get_declaration_with_json_schema_for_func_decl_enabled(self): """Test function declaration generation with json schema for func decl enabled.""" tool = MCPTool( mcp_tool=self.mock_mcp_tool, mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert isinstance(declaration, FunctionDeclaration) @@ -151,7 +152,7 @@ class TestMCPTool: assert declaration.response_json_schema is None def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enabled( - self, monkeypatch + self, ): """Test function declaration generation with an output schema and json schema for func decl enabled.""" output_schema = { @@ -169,8 +170,9 @@ class TestMCPTool: mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert isinstance(declaration, FunctionDeclaration) @@ -178,7 +180,7 @@ class TestMCPTool: assert declaration.response_json_schema == output_schema def test_get_declaration_with_empty_output_schema_and_json_schema_for_func_decl_enabled( - self, monkeypatch + self, ): """Test function declaration with an empty output schema and json schema for func decl enabled.""" tool = MCPTool( @@ -186,8 +188,9 @@ class TestMCPTool: mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert declaration.response is None diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index a9723b43..48a7a995 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -21,6 +21,8 @@ from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse @@ -33,6 +35,7 @@ from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.types import Part from pydantic import BaseModel +import pytest from pytest import mark from .. import testing_utils @@ -702,3 +705,198 @@ def test_agent_tool_description_with_input_schema(): # The description should come from the agent, not the Pydantic model assert declaration.description == agent_description + + +@pytest.fixture +def enable_json_schema_feature(): + """Fixture to enable JSON_SCHEMA_FOR_FUNC_DECL feature for a test.""" + with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True): + yield + + +def test_agent_tool_no_schema_with_json_schema_feature( + enable_json_schema_feature, +): + """Test AgentTool without input_schema uses parameters_json_schema when feature enabled.""" + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_response_json_schema_no_output_schema_vertex_ai( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with no output schema uses response_json_schema for VERTEX_AI when feature enabled.""" + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + 'response_json_schema': {'type': 'string'}, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_response_json_schema_with_output_schema_vertex_ai( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with output schema uses response_json_schema for VERTEX_AI when feature enabled.""" + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + 'response_json_schema': {'type': 'object'}, + } + + +@mark.parametrize( + 'env_variables', + [ + 'GOOGLE_AI', # Test GEMINI_API variant + ], + indirect=True, +) +def test_agent_tool_no_response_json_schema_gemini_api( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with GEMINI_API variant has no response_json_schema when feature enabled.""" + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # GEMINI_API should not have response_json_schema + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_with_input_schema_uses_json_schema_feature( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with input_schema uses parameters_json_schema when feature enabled.""" + + class CustomInput(BaseModel): + custom_input: str + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + input_schema=CustomInput, + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # When input_schema is provided, build_function_declaration uses Pydantic's + # model_json_schema() which includes additional fields like 'title' + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'properties': { + 'custom_input': {'title': 'Custom Input', 'type': 'string'}, + }, + 'required': ['custom_input'], + 'title': 'CustomInput', + 'type': 'object', + }, + 'response_json_schema': {'type': 'object'}, + } diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index dd85b20c..3797c4ed 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -13,9 +13,9 @@ # limitations under the License. from enum import Enum -from unittest import mock from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools import _automatic_function_calling_util from google.adk.tools.tool_context import ToolContext from google.adk.utils.variant_utils import GoogleLLMVariant @@ -435,11 +435,8 @@ class TestJsonSchemaFeatureFlagEnabled: @pytest.fixture(autouse=True) def enable_feature_flag(self): """Enable the JSON_SCHEMA_FOR_FUNC_DECL feature flag for all tests.""" - with mock.patch.object( - _automatic_function_calling_util, - 'is_feature_enabled', - autospec=True, - side_effect=lambda f: f == FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True ): yield