feat: Use json schema for base_retrieval_tool, load_artifacts_tool, and load_memory_tool declaration when feature enabled

Co-authored-by: Xuan Yang <xygoogle@google.com>
PiperOrigin-RevId: 858435881
This commit is contained in:
Xuan Yang
2026-01-19 23:36:23 -08:00
committed by Copybara-Service
parent 4b29d15b3e
commit 69ad605bc4
6 changed files with 181 additions and 0 deletions
@@ -24,6 +24,8 @@ from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from ..features import FeatureName
from ..features import is_feature_enabled
from .base_tool import BaseTool
# MIME types Gemini accepts for inline data in requests.
@@ -132,6 +134,20 @@ web UI)."""),
)
def _get_declaration(self) -> types.FunctionDeclaration | None:
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema={
'type': 'object',
'properties': {
'artifact_names': {
'type': 'array',
'items': {'type': 'string'},
},
},
},
)
return types.FunctionDeclaration(
name=self.name,
description=self.description,
+14
View File
@@ -21,6 +21,8 @@ from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from ..features import FeatureName
from ..features import is_feature_enabled
from ..memory.memory_entry import MemoryEntry
from .function_tool import FunctionTool
from .tool_context import ToolContext
@@ -59,6 +61,18 @@ class LoadMemoryTool(FunctionTool):
@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema={
'type': 'object',
'properties': {
'query': {'type': 'string'},
},
'required': ['query'],
},
)
return types.FunctionDeclaration(
name=self.name,
description=self.description,
@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.genai import types
from typing_extensions import override
from ...features import FeatureName
from ...features import is_feature_enabled
from ..base_tool import BaseTool
@@ -22,6 +26,20 @@ class BaseRetrievalTool(BaseTool):
@override
def _get_declaration(self) -> types.FunctionDeclaration:
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema={
'type': 'object',
'properties': {
'query': {
'type': 'string',
'description': 'The query to retrieve.',
},
},
},
)
return types.FunctionDeclaration(
name=self.name,
description=self.description,
@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
from google.genai import types
class _TestRetrievalTool(BaseRetrievalTool):
"""Concrete implementation of BaseRetrievalTool for testing."""
def __init__(self):
super().__init__(
name='test_retrieval',
description='A test retrieval tool.',
)
async def run_async(self, *, args, tool_context):
return {'result': 'test'}
def test_get_declaration_with_json_schema_feature_disabled():
"""Test that _get_declaration uses parameters when feature is disabled."""
tool = _TestRetrievalTool()
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, False):
declaration = tool._get_declaration()
assert declaration.name == 'test_retrieval'
assert declaration.description == 'A test retrieval tool.'
assert declaration.parameters_json_schema is None
assert isinstance(declaration.parameters, types.Schema)
assert declaration.parameters.type == types.Type.OBJECT
assert 'query' in declaration.parameters.properties
def test_get_declaration_with_json_schema_feature_enabled():
"""Test that _get_declaration uses parameters_json_schema when feature is enabled."""
tool = _TestRetrievalTool()
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
declaration = tool._get_declaration()
assert declaration.name == 'test_retrieval'
assert declaration.description == 'A test retrieval tool.'
assert declaration.parameters is None
assert declaration.parameters_json_schema == {
'type': 'object',
'properties': {
'query': {
'type': 'string',
'description': 'The query to retrieve.',
},
},
}
@@ -14,6 +14,8 @@
import base64
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.models.llm_request import LlmRequest
from google.adk.tools.load_artifacts_tool import _maybe_base64_to_bytes
from google.adk.tools.load_artifacts_tool import load_artifacts_tool
@@ -160,3 +162,21 @@ def test_maybe_base64_to_bytes_returns_none_for_invalid():
"""Invalid base64 strings return None."""
# Single character is invalid (base64 requires length % 4 == 0 after padding)
assert _maybe_base64_to_bytes('x') is None
def test_get_declaration_with_json_schema_feature_enabled():
"""Test that _get_declaration uses parameters_json_schema when feature is enabled."""
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
declaration = load_artifacts_tool._get_declaration()
assert declaration.name == 'load_artifacts'
assert declaration.parameters is None
assert declaration.parameters_json_schema == {
'type': 'object',
'properties': {
'artifact_names': {
'type': 'array',
'items': {'type': 'string'},
},
},
}
@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.load_memory_tool import load_memory_tool
from google.genai import types
def test_get_declaration_with_json_schema_feature_disabled():
"""Test that _get_declaration uses parameters when feature is disabled."""
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, False):
declaration = load_memory_tool._get_declaration()
assert declaration.name == 'load_memory'
assert declaration.parameters_json_schema is None
assert isinstance(declaration.parameters, types.Schema)
assert declaration.parameters.type == types.Type.OBJECT
assert 'query' in declaration.parameters.properties
def test_get_declaration_with_json_schema_feature_enabled():
"""Test that _get_declaration uses parameters_json_schema when feature is enabled."""
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
declaration = load_memory_tool._get_declaration()
assert declaration.name == 'load_memory'
assert declaration.parameters is None
assert declaration.parameters_json_schema == {
'type': 'object',
'properties': {
'query': {'type': 'string'},
},
'required': ['query'],
}