From eaf50ce37e9cb1d4b5201ccf84fbe9f91f731644 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 19 Feb 2026 12:01:58 -0800 Subject: [PATCH 001/102] chore: provide a way to disable model check for builtin tools Co-authored-by: George Weale PiperOrigin-RevId: 872503435 --- .../code_executors/built_in_code_executor.py | 4 +- .../adk/tools/enterprise_search_tool.py | 9 +- .../adk/tools/google_maps_grounding_tool.py | 4 +- src/google/adk/tools/google_search_tool.py | 4 +- .../retrieval/vertex_ai_rag_retrieval.py | 4 +- src/google/adk/tools/url_context_tool.py | 4 +- src/google/adk/tools/vertex_ai_search_tool.py | 9 +- src/google/adk/utils/model_name_utils.py | 13 +++ .../test_built_in_code_executor.py | 16 ++++ .../retrieval/test_vertex_ai_rag_retrieval.py | 40 ++++++++ .../tools/test_enterprise_web_search_tool.py | 19 ++++ .../tools/test_google_maps_grounding_tool.py | 92 +++++++++++++++++++ .../tools/test_google_search_tool.py | 21 +++++ .../unittests/tools/test_url_context_tool.py | 21 +++++ .../tools/test_vertex_ai_search_tool.py | 23 +++++ .../unittests/utils/test_model_name_utils.py | 13 +++ 16 files changed, 285 insertions(+), 11 deletions(-) create mode 100644 tests/unittests/tools/test_google_maps_grounding_tool.py diff --git a/src/google/adk/code_executors/built_in_code_executor.py b/src/google/adk/code_executors/built_in_code_executor.py index 50a0b9f4..a4e32034 100644 --- a/src/google/adk/code_executors/built_in_code_executor.py +++ b/src/google/adk/code_executors/built_in_code_executor.py @@ -20,6 +20,7 @@ from typing_extensions import override from ..agents.invocation_context import InvocationContext from ..models import LlmRequest from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -42,7 +43,8 @@ class 
BuiltInCodeExecutor(BaseCodeExecutor): def process_llm_request(self, llm_request: LlmRequest) -> None: """Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool.""" - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 4f7a0d7f..c114fdb4 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -54,14 +55,16 @@ class EnterpriseWebSearchTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Enterprise Web Search tool cannot be used with other tools in' ' Gemini 1.x.' 
) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( types.Tool(enterprise_web_search=types.EnterpriseWebSearch()) ) diff --git a/src/google/adk/tools/google_maps_grounding_tool.py b/src/google/adk/tools/google_maps_grounding_tool.py index bade0a33..d4b105ec 100644 --- a/src/google/adk/tools/google_maps_grounding_tool.py +++ b/src/google/adk/tools/google_maps_grounding_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -49,13 +50,14 @@ class GoogleMapsGroundingTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError( 'Google Maps grounding tool cannot be used with Gemini 1.x models.' 
) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_maps=types.GoogleMaps()) ) diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index 406ad218..1c11e091 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -67,6 +68,7 @@ class GoogleSearchTool(BaseTool): if self.model is not None: llm_request.model = self.model + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): @@ -77,7 +79,7 @@ class GoogleSearchTool(BaseTool): llm_request.config.tools.append( types.Tool(google_search_retrieval=types.GoogleSearchRetrieval()) ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_search=types.GoogleSearch()) ) diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index 206819a9..4d564ca1 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -24,6 +24,7 @@ from google.genai import types from typing_extensions import override from ...utils.model_name_utils import is_gemini_2_or_above +from ...utils.model_name_utils import is_gemini_model_id_check_disabled from ..tool_context import ToolContext from .base_retrieval_tool import 
BaseRetrievalTool @@ -63,7 +64,8 @@ class VertexAiRagRetrieval(BaseRetrievalTool): llm_request: LlmRequest, ) -> None: # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = ( types.GenerateContentConfig() if not llm_request.config diff --git a/src/google/adk/tools/url_context_tool.py b/src/google/adk/tools/url_context_tool.py index fcdf76da..5e923e74 100644 --- a/src/google/adk/tools/url_context_tool.py +++ b/src/google/adk/tools/url_context_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -46,11 +47,12 @@ class UrlContextTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError('Url context tool cannot be used in Gemini 1.x.') - elif is_gemini_2_or_above(llm_request.model): + elif is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(url_context=types.UrlContext()) ) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 91fe60e5..46104c5e 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -24,6 +24,7 @@ from typing_extensions import override from ..agents.readonly_context import ReadonlyContext from ..utils.model_name_utils import is_gemini_1_model from 
..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -141,14 +142,16 @@ class VertexAiSearchTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Vertex AI search tool cannot be used with other tools in Gemini' ' 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] # Build the search config (can be overridden by subclasses) vertex_ai_search_config = self._build_vertex_ai_search_config( diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 4960b0b7..57103fb2 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -22,6 +22,19 @@ from typing import Optional from packaging.version import InvalidVersion from packaging.version import Version +from .env_utils import is_env_enabled + +_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK' + + +def is_gemini_model_id_check_disabled() -> bool: + """Returns True when Gemini model-id validation should be bypassed. + + This opt-in environment variable is intended for internal usage where model + ids may not follow the public ``gemini-*`` naming convention. + """ + return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR) + def extract_model_name(model_string: str) -> str: """Extract the actual model name from either simple or path-based format. 
diff --git a/tests/unittests/code_executors/test_built_in_code_executor.py b/tests/unittests/code_executors/test_built_in_code_executor.py index 58f54c7c..cbf128fb 100644 --- a/tests/unittests/code_executors/test_built_in_code_executor.py +++ b/tests/unittests/code_executors/test_built_in_code_executor.py @@ -97,6 +97,22 @@ def test_process_llm_request_non_gemini_2_model( ) +def test_process_llm_request_non_gemini_2_model_with_disabled_check( + built_in_executor: BuiltInCodeExecutor, + monkeypatch, +): + """Tests non-Gemini models pass when model-id check is disabled.""" + monkeypatch.setenv("ADK_DISABLE_GEMINI_MODEL_ID_CHECK", "true") + llm_request = LlmRequest(model="internal-model-v1") + + built_in_executor.process_llm_request(llm_request) + + assert llm_request.config is not None + assert llm_request.config.tools == [ + types.Tool(code_execution=types.ToolCodeExecution()) + ] + + def test_process_llm_request_no_model_name( built_in_executor: BuiltInCodeExecutor, ): diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 3b5aa26f..0a86d07c 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -145,3 +145,43 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ) ] assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + + +def test_vertex_rag_retrieval_for_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + responses = [ + 'response1', + ] + mockModel = testing_utils.MockModel.create(responses=responses) + mockModel.model = 'internal-model-v1' + + agent = Agent( + name='root_agent', + model=mockModel, + tools=[ + VertexAiRagRetrieval( + name='rag_retrieval', + description='rag_retrieval', + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ], + ) + ], + ) + runner = 
testing_utils.InMemoryRunner(agent) + runner.run('test1') + + assert len(mockModel.requests) == 1 + assert len(mockModel.requests[0].config.tools) == 1 + assert mockModel.requests[0].config.tools == [ + types.Tool( + retrieval=types.Retrieval( + vertex_rag_store=types.VertexRagStore( + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ] + ) + ) + ) + ] + assert 'rag_retrieval' not in mockModel.requests[0].tools_dict diff --git a/tests/unittests/tools/test_enterprise_web_search_tool.py b/tests/unittests/tools/test_enterprise_web_search_tool.py index ed471596..7b28d858 100644 --- a/tests/unittests/tools/test_enterprise_web_search_tool.py +++ b/tests/unittests/tools/test_enterprise_web_search_tool.py @@ -76,6 +76,25 @@ async def test_process_llm_request_failure_with_non_gemini_models(): assert 'is not supported for model' in str(exc_info.value) +@pytest.mark.asyncio +async def test_process_llm_request_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = EnterpriseWebSearchTool() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + tool_context = await _create_tool_context() + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert ( + llm_request.config.tools[0].enterprise_web_search + == types.EnterpriseWebSearch() + ) + + @pytest.mark.asyncio async def test_process_llm_request_failure_with_multiple_tools_gemini_1_models(): tool = EnterpriseWebSearchTool() diff --git a/tests/unittests/tools/test_google_maps_grounding_tool.py b/tests/unittests/tools/test_google_maps_grounding_tool.py new file mode 100644 index 00000000..0cd2c4fa --- /dev/null +++ b/tests/unittests/tools/test_google_maps_grounding_tool.py @@ -0,0 +1,92 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.google_maps_grounding_tool import GoogleMapsGroundingTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +async def _create_tool_context() -> ToolContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=agent, + session=session, + session_service=session_service, + ) + return ToolContext(invocation_context=invocation_context) + + +class TestGoogleMapsGroundingTool: + """Tests for GoogleMapsGroundingTool.""" + + @pytest.mark.asyncio + async def test_process_llm_request_with_gemini_2_model(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='gemini-2.5-pro', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None + + @pytest.mark.asyncio + async def 
test_process_llm_request_with_non_gemini_model_raises_error(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='claude-3-sonnet', config=types.GenerateContentConfig() + ) + + with pytest.raises( + ValueError, + match='Google maps tool is not supported for model claude-3-sonnet', + ): + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_and_disabled_check( + self, monkeypatch + ): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py index ad5d46b5..d71061b8 100644 --- a/tests/unittests/tools/test_google_search_tool.py +++ b/tests/unittests/tools/test_google_search_tool.py @@ -268,6 +268,27 @@ class TestGoogleSearchTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleSearchTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None 
+ assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_url_context_tool.py b/tests/unittests/tools/test_url_context_tool.py index 53ee7e62..8fd44b59 100644 --- a/tests/unittests/tools/test_url_context_tool.py +++ b/tests/unittests/tools/test_url_context_tool.py @@ -190,6 +190,27 @@ class TestUrlContextTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = UrlContextTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].url_context is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 3ade634d..b15d3a1f 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -376,6 +376,29 @@ class TestVertexAiSearchTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + 
tool = VertexAiSearchTool(data_store_id='test_data_store') + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index cbac37e3..2af1584b 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -18,6 +18,7 @@ from google.adk.utils.model_name_utils import extract_model_name from google.adk.utils.model_name_utils import is_gemini_1_model from google.adk.utils.model_name_utils import is_gemini_2_or_above from google.adk.utils.model_name_utils import is_gemini_model +from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled class TestExtractModelName: @@ -318,3 +319,15 @@ class TestModelNameUtilsIntegration: f'Inconsistent Gemini 2.0+ classification for {simple_model} vs' f' {path_model}' ) + + +class TestGeminiModelIdCheckFlag: + """Tests for Gemini model-id check override flag.""" + + def test_default_is_disabled(self, monkeypatch): + monkeypatch.delenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', raising=False) + assert is_gemini_model_id_check_disabled() is False + + def test_true_enables_check_bypass(self, monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + assert is_gemini_model_id_check_disabled() is True From 4a88804ec7d17fb4031b238c362f27d240df0a13 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 19 Feb 2026 
13:59:08 -0800 Subject: [PATCH 002/102] feat: Add support for memory consolidation via Vertex AI Memory Bank This change allows `add_memory` to use the `memories.generate` API with `direct_memories_source` when `custom_metadata["enable_consolidation"]` is set to True. This enables server-side consolidation of the provided memories Co-authored-by: George Weale PiperOrigin-RevId: 872554004 --- .../memory/vertex_ai_memory_bank_service.py | 99 ++++++++++-- .../test_vertex_ai_memory_bank_service.py | 142 ++++++++++++++++++ 2 files changed, 231 insertions(+), 10 deletions(-) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 7bb18efa..2218c874 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -65,6 +65,11 @@ _CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({ 'wait_for_completion', }) +_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation' +# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow +# at most 5 direct_memories per request. +_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5 + def _supports_generate_memories_metadata() -> bool: """Returns whether installed Vertex SDK supports config.metadata.""" @@ -160,6 +165,11 @@ class VertexAiMemoryBankService(BaseMemoryService): not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ + if not agent_engine_id: + raise ValueError( + 'agent_engine_id is required for VertexAiMemoryBankService.' 
+ ) + self._project = project self._location = location self._agent_engine_id = agent_engine_id @@ -219,7 +229,22 @@ class VertexAiMemoryBankService(BaseMemoryService): memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: - """Adds explicit memory items via Vertex memories.create.""" + """Adds explicit memory items using Vertex Memory Bank. + + By default, this writes directly via `memories.create`. + If `custom_metadata["enable_consolidation"]` is set to True, this uses + `memories.generate` with `direct_memories_source` so provided memories are + consolidated server-side. + """ + if _is_consolidation_enabled(custom_metadata): + await self._add_memories_via_generate_direct_memories_source( + app_name=app_name, + user_id=user_id, + memories=memories, + custom_metadata=custom_metadata, + ) + return + await self._add_memories_via_create( app_name=app_name, user_id=user_id, @@ -235,9 +260,6 @@ class VertexAiMemoryBankService(BaseMemoryService): events_to_process: Sequence[Event], custom_metadata: Mapping[str, object] | None = None, ) -> None: - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - direct_events = [] for event in events_to_process: if _should_filter_out_event(event.content): @@ -272,9 +294,6 @@ class VertexAiMemoryBankService(BaseMemoryService): custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds direct memory items without server-side extraction.""" - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - normalized_memories = _normalize_memories_for_create(memories) api_client = self._get_api_client() for index, memory in enumerate(normalized_memories): @@ -300,11 +319,41 @@ class VertexAiMemoryBankService(BaseMemoryService): logger.info('Create memory response received.') logger.debug('Create memory response: %s', operation) + async def _add_memories_via_generate_direct_memories_source( + self, + *, 
+ app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds memories via generate API with direct_memories_source.""" + normalized_memories = _normalize_memories_for_create(memories) + memory_texts = [ + _memory_entry_to_fact(m, index=i) + for i, m in enumerate(normalized_memories) + ] + api_client = self._get_api_client() + config = _build_generate_memories_config(custom_metadata) + for memory_batch in _iter_memory_batches(memory_texts): + operation = await api_client.agent_engines.memories.generate( + name='reasoningEngines/' + self._agent_engine_id, + direct_memories_source={ + 'direct_memories': [ + {'fact': memory_text} for memory_text in memory_batch + ] + }, + scope={ + 'app_name': app_name, + 'user_id': user_id, + }, + config=config, + ) + logger.info('Generate direct memory response received.') + logger.debug('Generate direct memory response: %s', operation) + @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - api_client = self._get_api_client() retrieved_memories_iterator = ( await api_client.agent_engines.memories.retrieve( @@ -379,6 +428,8 @@ def _build_generate_memories_config( metadata_by_key: dict[str, object] = {} for key, value in custom_metadata.items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'ttl': if value is None: continue @@ -456,6 +507,8 @@ def _build_create_memory_config( metadata_by_key: dict[str, object] = {} custom_revision_labels: dict[str, str] = {} for key, value in (custom_metadata or {}).items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'metadata': if value is None: continue @@ -641,6 +694,32 @@ def _extract_revision_labels( return revision_labels +def _is_consolidation_enabled( + custom_metadata: Mapping[str, object] | None, +) -> bool: + """Returns whether direct memories 
should be consolidated via generate API.""" + if not custom_metadata: + return False + enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY) + if enable_consolidation is None: + return False + if not isinstance(enable_consolidation, bool): + raise TypeError( + f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.' + ) + return enable_consolidation + + +def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]: + """Returns memory slices that comply with direct_memories limits.""" + memory_batches: list[Sequence[str]] = [] + for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL): + memory_batches.append( + memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL] + ) + return memory_batches + + def _build_vertex_metadata( metadata_by_key: Mapping[str, object], ) -> dict[str, object]: diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 6f342a08..c498b833 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -230,6 +230,14 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +def test_initialize_without_agent_engine_id_error(): + with pytest.raises( + ValueError, + match='agent_engine_id is required for VertexAiMemoryBankService', + ): + mock_vertex_ai_memory_bank_service(agent_engine_id=None) + + @pytest.mark.asyncio async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() @@ -481,6 +489,7 @@ async def test_add_memory_calls_create( ), ], custom_metadata={ + 'enable_consolidation': False, 'ttl': '6000s', 'source': 'agent', }, @@ -518,6 +527,139 @@ async def test_add_memory_calls_create( vertex_common_types.AgentEngineMemoryConfig(**create_config) +@pytest.mark.asyncio +async def 
test_add_memory_enable_consolidation_calls_generate_direct_source( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + 'source': 'agent', + }, + ) + + expected_config = {'wait_for_completion': False} + if _supports_generate_memories_metadata(): + expected_config['metadata'] = {'source': {'string_value': 'agent'}} + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_batches_generate_calls( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact three')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact four')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact five')]) + ), + MemoryEntry( + 
content=types.Content(parts=[types.Part(text='fact six')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + }, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_has_awaits([ + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + {'fact': 'fact three'}, + {'fact': 'fact four'}, + {'fact': 'fact five'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact six'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + ]) + assert mock_vertexai_client.agent_engines.memories.generate.await_count == 2 + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_invalid_enable_consolidation_type_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + TypeError, + match=r'custom_metadata\["enable_consolidation"\] must be a bool', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ) + ], + custom_metadata={'enable_consolidation': 'yes'}, + ) + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_memory_calls_create_with_memory_entry_metadata( mock_vertexai_client, From a39ca946d6ad2e937a89b7599d48560d5fd11ca0 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 09:49:12 -0800 Subject: [PATCH 003/102] chore: Add sqlite_span_exporter for .adk folder traces Co-authored-by: George Weale PiperOrigin-RevId: 
872948208 --- .../adk/telemetry/sqlite_span_exporter.py | 234 +++++++++ .../telemetry/test_sqlite_span_exporter.py | 462 ++++++++++++++++++ 2 files changed, 696 insertions(+) create mode 100644 src/google/adk/telemetry/sqlite_span_exporter.py create mode 100644 tests/unittests/telemetry/test_sqlite_span_exporter.py diff --git a/src/google/adk/telemetry/sqlite_span_exporter.py b/src/google/adk/telemetry/sqlite_span_exporter.py new file mode 100644 index 00000000..1d535908 --- /dev/null +++ b/src/google/adk/telemetry/sqlite_span_exporter.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQLite-backed OpenTelemetry span exporter for local development.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import threading +from typing import Any +from typing import Iterable +from typing import Optional +from typing import Sequence + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter +from opentelemetry.sdk.trace.export import SpanExportResult +from opentelemetry.trace import SpanContext +from opentelemetry.trace import TraceFlags +from opentelemetry.trace import TraceState + +logger = logging.getLogger("google_adk." 
+ __name__) + +_CREATE_SPANS_TABLE = """ +CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT NOT NULL, + parent_span_id TEXT, + name TEXT NOT NULL, + start_time_unix_nano INTEGER, + end_time_unix_nano INTEGER, + session_id TEXT, + invocation_id TEXT, + attributes_json TEXT +); +""" + +_CREATE_SESSION_INDEX = """ +CREATE INDEX IF NOT EXISTS spans_session_id_idx ON spans(session_id); +""" + +_CREATE_TRACE_INDEX = """ +CREATE INDEX IF NOT EXISTS spans_trace_id_idx ON spans(trace_id); +""" + +_INSERT_SPAN = """ +INSERT OR REPLACE INTO spans ( + span_id, + trace_id, + parent_span_id, + name, + start_time_unix_nano, + end_time_unix_nano, + session_id, + invocation_id, + attributes_json +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); +""" + +_DEFAULT_TIMEOUT_SECONDS = 30.0 + + +class SqliteSpanExporter(SpanExporter): + """Exports spans to a local SQLite database. + + This is intended for local development (e.g. `adk web`) to allow reloading + traces for older sessions after process restart. 
+ """ + + def __init__(self, *, db_path: str): + self._db_path = db_path + self._lock = threading.Lock() + self._conn: Optional[sqlite3.Connection] = None + self._ensure_schema() + + def _get_connection(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect( + self._db_path, + timeout=_DEFAULT_TIMEOUT_SECONDS, + check_same_thread=False, + ) + self._conn.row_factory = sqlite3.Row + return self._conn + + def _ensure_schema(self) -> None: + with self._lock: + conn = self._get_connection() + conn.execute(_CREATE_SPANS_TABLE) + conn.execute(_CREATE_SESSION_INDEX) + conn.execute(_CREATE_TRACE_INDEX) + conn.commit() + + def _serialize_attributes(self, attributes: dict[str, Any]) -> str: + try: + return json.dumps( + attributes, + ensure_ascii=False, + default=lambda o: "", + ) + except (TypeError, ValueError) as e: + logger.debug("Failed to serialize span attributes: %r", e) + return "{}" + + def _deserialize_attributes(self, attributes_json: Any) -> dict[str, Any]: + if not attributes_json: + return {} + try: + attributes = json.loads(attributes_json) + except (json.JSONDecodeError, TypeError) as e: + logger.debug("Failed to deserialize span attributes: %r", e) + return {} + return attributes if isinstance(attributes, dict) else {} + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + try: + with self._lock: + conn = self._get_connection() + rows: list[tuple[Any, ...]] = [] + for span in spans: + attributes = dict(span.attributes) if span.attributes else {} + session_id = attributes.get( + "gcp.vertex.agent.session_id" + ) or attributes.get("gen_ai.conversation.id") + invocation_id = attributes.get("gcp.vertex.agent.invocation_id") + + parent_span_id = None + if span.parent is not None: + parent_span_id = format(span.parent.span_id, "016x") + + rows.append(( + format(span.context.span_id, "016x"), + format(span.context.trace_id, "032x"), + parent_span_id, + span.name, + span.start_time, + span.end_time, + session_id, + 
invocation_id, + self._serialize_attributes(attributes), + )) + conn.executemany(_INSERT_SPAN, rows) + conn.commit() + return SpanExportResult.SUCCESS + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to export spans to SQLite: %s", e) + return SpanExportResult.FAILURE + + def shutdown(self) -> None: + with self._lock: + if self._conn is not None: + self._conn.close() + self._conn = None + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def _query(self, sql: str, params: Iterable[Any]) -> list[sqlite3.Row]: + with self._lock: + conn = self._get_connection() + cur = conn.execute(sql, tuple(params)) + return list(cur.fetchall()) + + def _row_to_readable_span(self, row: sqlite3.Row) -> ReadableSpan: + trace_id_hex = row["trace_id"] + span_id_hex = row["span_id"] + trace_id = int(str(trace_id_hex), 16) + span_id = int(str(span_id_hex), 16) + trace_state = TraceState() + trace_flags = TraceFlags(TraceFlags.SAMPLED) + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=trace_flags, + trace_state=trace_state, + ) + + parent: SpanContext | None = None + parent_span_id_hex = row["parent_span_id"] + if parent_span_id_hex: + parent = SpanContext( + trace_id=trace_id, + span_id=int(str(parent_span_id_hex), 16), + is_remote=False, + trace_flags=trace_flags, + trace_state=trace_state, + ) + + attributes = self._deserialize_attributes(row["attributes_json"]) + return ReadableSpan( + name=row["name"] or "", + context=context, + parent=parent, + attributes=attributes, + start_time=row["start_time_unix_nano"], + end_time=row["end_time_unix_nano"], + ) + + def get_all_spans_for_session(self, session_id: str) -> list[ReadableSpan]: + """Returns all spans for a session (full trace trees). + + We first find trace_ids associated with the session, then return all spans + for those trace_ids. This works even if some spans are missing session_id + attributes (e.g. 
parent spans). + """ + trace_rows = self._query( + "SELECT DISTINCT trace_id FROM spans WHERE session_id = ?", + (session_id,), + ) + trace_ids = [r["trace_id"] for r in trace_rows if r["trace_id"]] + if not trace_ids: + return [] + + placeholders = ",".join("?" for _ in trace_ids) + rows = self._query( + f"SELECT * FROM spans WHERE trace_id IN ({placeholders}) " + "ORDER BY start_time_unix_nano", + trace_ids, + ) + return [self._row_to_readable_span(row) for row in rows] diff --git a/tests/unittests/telemetry/test_sqlite_span_exporter.py b/tests/unittests/telemetry/test_sqlite_span_exporter.py new file mode 100644 index 00000000..21437175 --- /dev/null +++ b/tests/unittests/telemetry/test_sqlite_span_exporter.py @@ -0,0 +1,462 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from pathlib import Path + +from google.adk.telemetry.sqlite_span_exporter import SqliteSpanExporter +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExportResult +from opentelemetry.trace import SpanContext +from opentelemetry.trace import TraceFlags +from opentelemetry.trace import TraceState + + +def _create_span( + *, + span_id: int = 0x00000000000ABC12, + trace_id: int = 0x000000000000000000000000000DEF45, + parent_span_id: int | None = None, + name: str = "test_span", + attributes: dict | None = None, + start_time: int = 1000, + end_time: int = 2000, +) -> ReadableSpan: + """Helper to create ReadableSpan instances for testing.""" + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(), + ) + + parent = None + if parent_span_id is not None: + parent = SpanContext( + trace_id=trace_id, + span_id=parent_span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(), + ) + + return ReadableSpan( + name=name, + context=context, + parent=parent, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +def test_export_single_span_returns_success(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span( + name="test_operation", + attributes={"gcp.vertex.agent.session_id": "session-123"}, + ) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + assert db_path.exists() + + +def test_export_empty_list_returns_success(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + result = exporter.export([]) + + assert result == SpanExportResult.SUCCESS + + +def test_get_all_spans_for_session_returns_matching_spans(tmp_path): + db_path = tmp_path / "test.db" + exporter = 
SqliteSpanExporter(db_path=str(db_path)) + + span1 = _create_span( + span_id=0x111, + trace_id=0xAAA111, # Different trace for session-123 + attributes={"gcp.vertex.agent.session_id": "session-123"}, + name="span1", + ) + span2 = _create_span( + span_id=0x222, + trace_id=0xAAA222, # Different trace for session-123 + attributes={"gcp.vertex.agent.session_id": "session-123"}, + name="span2", + ) + span3 = _create_span( + span_id=0x333, + trace_id=0xBBB333, # Different trace for session-456 + attributes={"gcp.vertex.agent.session_id": "session-456"}, + name="span3", + ) + + exporter.export([span1, span2, span3]) + + result = exporter.get_all_spans_for_session("session-123") + + assert len(result) == 2 + names = [span.name for span in result] + assert "span1" in names + assert "span2" in names + assert "span3" not in names + + +def test_get_all_spans_for_session_includes_sibling_spans_without_session_id( + tmp_path, +): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Parent span without session_id (e.g., invocation span) + parent_span = _create_span( + span_id=0x100, + trace_id=0xAAA, + name="invocation", + attributes={}, # No session_id + ) + + # Child span with session_id + child_span = _create_span( + span_id=0x200, + trace_id=0xAAA, # Same trace + parent_span_id=0x100, + name="call_llm", + attributes={"gcp.vertex.agent.session_id": "session-789"}, + ) + + # Sibling span without session_id (should be included) + sibling_span = _create_span( + span_id=0x300, + trace_id=0xAAA, # Same trace + parent_span_id=0x100, + name="tool_call", + attributes={}, # No session_id + ) + + # Unrelated span with different trace_id (should not be included) + unrelated_span = _create_span( + span_id=0x400, + trace_id=0xBBB, # Different trace + name="unrelated", + attributes={}, + ) + + exporter.export([parent_span, child_span, sibling_span, unrelated_span]) + + result = exporter.get_all_spans_for_session("session-789") + + assert len(result) 
== 3 + names = [span.name for span in result] + assert "invocation" in names + assert "call_llm" in names + assert "tool_call" in names + assert "unrelated" not in names + + +def test_get_all_spans_for_unknown_session_returns_empty_list(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span( + attributes={"gcp.vertex.agent.session_id": "session-123"}, + ) + exporter.export([span]) + + result = exporter.get_all_spans_for_session("unknown-session") + + assert result == [] + + +def test_round_trip_preserves_span_attributes(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + original_attributes = { + "gcp.vertex.agent.session_id": "session-123", + "gcp.vertex.agent.invocation_id": "invocation-456", + "gen_ai.conversation.id": "conv-789", + "custom.attribute": "test_value", + "numeric.value": 42, + "boolean.value": True, + "list.value": [1, 2, 3], + "dict.value": {"nested": "data"}, + } + + original_span = _create_span( + span_id=0x12345678, + trace_id=0xABCDEF123456789, + name="test_operation", + attributes=original_attributes, + start_time=1000000, + end_time=2000000, + ) + + exporter.export([original_span]) + + retrieved_spans = exporter.get_all_spans_for_session("session-123") + + assert len(retrieved_spans) == 1 + retrieved = retrieved_spans[0] + + assert retrieved.name == "test_operation" + assert retrieved.context.span_id == 0x12345678 + assert retrieved.context.trace_id == 0xABCDEF123456789 + assert retrieved.start_time == 1000000 + assert retrieved.end_time == 2000000 + assert retrieved.attributes == original_attributes + + +def test_spans_with_parent_context_exported_correctly(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + parent_span = _create_span( + span_id=0xAAA, + trace_id=0x123, + name="parent", + attributes={"gcp.vertex.agent.session_id": "session-001"}, + ) + + child_span = 
_create_span( + span_id=0xBBB, + trace_id=0x123, + parent_span_id=0xAAA, + name="child", + attributes={"gcp.vertex.agent.session_id": "session-001"}, + ) + + exporter.export([parent_span, child_span]) + + retrieved_spans = exporter.get_all_spans_for_session("session-001") + + assert len(retrieved_spans) == 2 + + # Find child span in results + child = next(s for s in retrieved_spans if s.name == "child") + assert child.parent is not None + assert child.parent.span_id == 0xAAA + assert child.parent.trace_id == 0x123 + + # Find parent span in results + parent = next(s for s in retrieved_spans if s.name == "parent") + assert parent.parent is None + + +def test_shutdown_closes_connection(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create a span to ensure connection is open + span = _create_span() + exporter.export([span]) + + # Verify connection exists + assert exporter._conn is not None + + exporter.shutdown() + + # Verify connection is closed + assert exporter._conn is None + + +def test_force_flush_returns_true(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + result = exporter.force_flush() + + assert result is True + + # Also test with timeout parameter + result_with_timeout = exporter.force_flush(timeout_millis=5000) + assert result_with_timeout is True + + +def test_export_handles_spans_with_none_attributes(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span(attributes=None) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + + # Verify the span was stored correctly + rows = exporter._query("SELECT attributes_json FROM spans", []) + assert len(rows) == 1 + attributes_json = rows[0]["attributes_json"] + assert json.loads(attributes_json) == {} + + +def test_duplicate_span_id_replaces_previous_row(tmp_path): + db_path = tmp_path / "test.db" + exporter = 
SqliteSpanExporter(db_path=str(db_path)) + + # Export first version of span + span1 = _create_span( + span_id=0x999, + name="first_version", + attributes={"version": 1, "gcp.vertex.agent.session_id": "session-dup"}, + ) + exporter.export([span1]) + + # Export second version with same span_id + span2 = _create_span( + span_id=0x999, + name="second_version", + attributes={"version": 2, "gcp.vertex.agent.session_id": "session-dup"}, + ) + exporter.export([span2]) + + # Verify only one row exists with updated data + retrieved_spans = exporter.get_all_spans_for_session("session-dup") + assert len(retrieved_spans) == 1 + assert retrieved_spans[0].name == "second_version" + assert retrieved_spans[0].attributes["version"] == 2 + + +def test_non_serializable_attributes_use_fallback(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create a non-serializable object + class NonSerializable: + pass + + attributes = { + "gcp.vertex.agent.session_id": "session-nonser", + "normal_attr": "value", + "non_serializable": NonSerializable(), + } + + span = _create_span(attributes=attributes) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + + # Verify the span was stored and non-serializable attribute has fallback + retrieved_spans = exporter.get_all_spans_for_session("session-nonser") + assert len(retrieved_spans) == 1 + assert retrieved_spans[0].attributes["normal_attr"] == "value" + assert ( + retrieved_spans[0].attributes["non_serializable"] == "" + ) + + +def test_export_multiple_spans_in_batch(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + spans = [ + _create_span( + span_id=i, + name=f"span_{i}", + attributes={"gcp.vertex.agent.session_id": "batch-session"}, + ) + for i in range(10) + ] + + result = exporter.export(spans) + + assert result == SpanExportResult.SUCCESS + + retrieved_spans = 
exporter.get_all_spans_for_session("batch-session") + assert len(retrieved_spans) == 10 + names = {span.name for span in retrieved_spans} + assert names == {f"span_{i}" for i in range(10)} + + +def test_export_with_alternative_session_id_attribute(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Test using gen_ai.conversation.id as fallback for session_id + span = _create_span( + attributes={"gen_ai.conversation.id": "conv-session-123"}, + ) + + exporter.export([span]) + + # Should be queryable by the conversation id + result = exporter.get_all_spans_for_session("conv-session-123") + + assert len(result) == 1 + assert result[0].attributes["gen_ai.conversation.id"] == "conv-session-123" + + +def test_deserialize_handles_invalid_json(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Manually insert a row with invalid JSON + conn = exporter._get_connection() + conn.execute( + "INSERT INTO spans (span_id, trace_id, name, attributes_json) VALUES (?," + " ?, ?, ?)", + ("abc123", "def456", "test", "not valid json"), + ) + conn.commit() + + # Try to retrieve the span - should not raise, but attributes should be empty + rows = exporter._query("SELECT * FROM spans", []) + span = exporter._row_to_readable_span(rows[0]) + + assert span.name == "test" + assert span.attributes == {} + + +def test_get_spans_ordered_by_start_time(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create spans with different start times + spans = [ + _create_span( + span_id=0x300, + start_time=3000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + _create_span( + span_id=0x100, + start_time=1000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + _create_span( + span_id=0x200, + start_time=2000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + ] + + exporter.export(spans) + + 
result = exporter.get_all_spans_for_session("session-order") + + # Verify spans are ordered by start_time + assert len(result) == 3 + assert result[0].context.span_id == 0x100 + assert result[1].context.span_id == 0x200 + assert result[2].context.span_id == 0x300 From bef3f117b4842ce62760328304484cd26a1ec30a Mon Sep 17 00:00:00 2001 From: Sahaja Reddy Pabbathi Reddy Date: Fri, 20 Feb 2026 09:55:04 -0800 Subject: [PATCH 004/102] feat: Bigquery ADK support for search catalog tool Merge https://github.com/google/adk-python/pull/4171 **Problem:** The BigQuery ADK tools currently lack the ability to search for and discover BigQuery assets using the Dataplex Catalog. Users cannot leverage Dataplex's search capabilities within the ADK to find relevant data assets before querying them. **Solution:** This PR integrates a new search_catalog_tool into the BigQuery ADK. This tool utilizes the dataplex catalog client library to interact with the Dataplex API, allowing users to search the catalog. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. Added the screenshots of the manual adk web UI tests - https://docs.google.com/document/d/1c_lMW7NYGKuLAvPFmSkLehbqySeNyXQIhzQlvo3ixmQ/edit?usp=sharing ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4171 from sahaajaaa:sahaajaaa-bq-adk 3dbbaa4f909cb25259e8e7d73a00a58fbe9c2f09 PiperOrigin-RevId: 872951141 --- contributing/samples/bigquery/README.md | 4 + pyproject.toml | 1 + .../tools/bigquery/bigquery_credentials.py | 8 +- .../adk/tools/bigquery/bigquery_toolset.py | 2 + src/google/adk/tools/bigquery/client.py | 45 +- src/google/adk/tools/bigquery/search_tool.py | 179 +++++++ .../tools/bigquery/test_bigquery_client.py | 75 +++ .../bigquery/test_bigquery_credentials.py | 16 +- .../bigquery/test_bigquery_search_tool.py | 448 ++++++++++++++++++ .../tools/bigquery/test_bigquery_toolset.py | 3 +- 10 files changed, 768 insertions(+), 13 deletions(-) create mode 100644 src/google/adk/tools/bigquery/search_tool.py create mode 100644 tests/unittests/tools/bigquery/test_bigquery_search_tool.py diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 3ed97432..fc3f8610 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. +11. `search_catalog` + Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. + ## How to use Set up environment variables in your `.env` file for using @@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf. * which tables exist in the ml_datasets dataset? * show more details about the penguins table * compute penguins population per island. +* are there any tables related to animals in project <PROJECT_ID>?
diff --git a/pyproject.toml b/pyproject.toml index 9bec96cb..a1f136d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery-storage>=2.0.0", "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database + "google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index fa23c74c..958ce9d7 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,6 +19,10 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" +BIGQUERY_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", +] BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_DEFAULT_SCOPE - + self.scopes = BIGQUERY_SCOPES + # Set the token cache key self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY return self diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 1a748b71..dba5f8ee 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,6 +24,7 @@ from typing_extensions import override from . import data_insights_tool from . import metadata_tool from . import query_tool +from . 
import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset): query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, + search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d57c0c80..2cb4e67c 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -14,19 +14,22 @@ from __future__ import annotations +from typing import List from typing import Optional +from typing import Union import google.api_core.client_info +from google.api_core.gapic_v1 import client_info as gapic_client_info from google.auth.credentials import Credentials from google.cloud import bigquery +from google.cloud import dataplex_v1 from ... import version -USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" - - -from typing import List -from typing import Union +USER_AGENT_BASE = f"google-adk/{version.__version__}" +BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}" +DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}" +USER_AGENT = BQ_USER_AGENT def get_bigquery_client( @@ -48,7 +51,7 @@ def get_bigquery_client( A BigQuery client. """ - user_agents = [USER_AGENT] + user_agents = [BQ_USER_AGENT] if user_agent: if isinstance(user_agent, str): user_agents.append(user_agent) @@ -67,3 +70,33 @@ def get_bigquery_client( ) return bigquery_client + + +def get_dataplex_catalog_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> dataplex_v1.CatalogServiceClient: + """Get a Dataplex CatalogServiceClient with minimal necessary arguments. + + Args: + credentials: The credentials to use for the request. + user_agent: Additional user agent string(s) to append. + + Returns: + A Dataplex Client. 
+ """ + + user_agents = [DP_USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents)) + + return dataplex_v1.CatalogServiceClient( + credentials=credentials, + client_info=client_info, + ) diff --git a/src/google/adk/tools/bigquery/search_tool.py b/src/google/adk/tools/bigquery/search_tool.py new file mode 100644 index 00000000..0bf01d5a --- /dev/null +++ b/src/google/adk/tools/bigquery/search_tool.py @@ -0,0 +1,179 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any + +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + +from . 
import client +from .config import BigQueryToolConfig + + +def _construct_search_query_helper( + predicate: str, operator: str, items: list[str] +) -> str: + """Constructs a search query part for a specific predicate and items.""" + if not items: + return "" + + clauses = [f'{predicate}{operator}"{item}"' for item in items] + return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0] + + +def search_catalog( + prompt: str, + project_id: str, + *, + credentials: Credentials, + settings: BigQueryToolConfig, + location: str | None = None, + page_size: int = 10, + project_ids_filter: list[str] | None = None, + dataset_ids_filter: list[str] | None = None, + types_filter: list[str] | None = None, +) -> dict[str, Any]: + """Searches for BigQuery assets within Dataplex. + + Args: + prompt: The base search query (natural language or keywords). + project_id: The Google Cloud project ID to scope the search. + credentials: Credentials for the request. + settings: BigQuery tool settings. + location: The Dataplex location to use. + page_size: Maximum number of results. + project_ids_filter: Specific project IDs to include in the search results. + If None, defaults to the scoping project_id. + dataset_ids_filter: BigQuery dataset IDs to filter by. + types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE, + BigQueryEntryType.DATASET). + + Returns: + Search results or error. The "results" list contains items with: + - name: The Dataplex Entry name (e.g., + "projects/p/locations/l/entryGroups/g/entries/e"). + - linked_resource: The underlying BigQuery resource name (e.g., + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t"). + - display_name, entry_type, description, location, update_time. + + Examples: + Search for tables related to customer data: + + >>> search_catalog( + ... prompt="Search for tables related to customer data", + ... project_id="my-project", + ... credentials=creds, + ... settings=settings + ... 
) + { + "status": "SUCCESS", + "results": [ + { + "name": + "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id", + "display_name": "customer_table", + "entry_type": + "projects/p/locations/l/entryTypes/bigquery-table", + "linked_resource": + "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table", + "description": "Table containing customer details.", + "location": "us", + "update_time": "2024-01-01 12:00:00+00:00" + } + ] + } + """ + + try: + if not project_id: + return { + "status": "ERROR", + "error_details": "project_id must be provided.", + } + + with client.get_dataplex_catalog_client( + credentials=credentials, + user_agent=[settings.application_name, "search_catalog"], + ) as dataplex_client: + query_parts = [] + if prompt: + query_parts.append(f"({prompt})") + + # Filter by project IDs + projects_to_filter = ( + project_ids_filter if project_ids_filter else [project_id] + ) + if projects_to_filter: + query_parts.append( + _construct_search_query_helper("projectid", "=", projects_to_filter) + ) + + # Filter by dataset IDs + if dataset_ids_filter: + dataset_resource_filters = [] + for pid in projects_to_filter: + for did in dataset_ids_filter: + dataset_resource_filters.append( + f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"' + ) + if dataset_resource_filters: + query_parts.append(f"({' OR '.join(dataset_resource_filters)})") + # Filter by entry types + if types_filter: + query_parts.append( + _construct_search_query_helper("type", "=", types_filter) + ) + + # Always scope to BigQuery system + query_parts.append("system=BIGQUERY") + + full_query = " AND ".join(filter(None, query_parts)) + + search_location = location or settings.location or "global" + search_scope = f"projects/{project_id}/locations/{search_location}" + + request = dataplex_v1.SearchEntriesRequest( + name=search_scope, + query=full_query, + page_size=page_size, + semantic_search=True, + ) + + response = 
dataplex_client.search_entries(request=request) + + results = [] + for result in response.results: + entry = result.dataplex_entry + source = entry.entry_source + results.append({ + "name": entry.name, + "display_name": source.display_name or "", + "entry_type": entry.entry_type, + "update_time": str(entry.update_time), + "linked_resource": source.resource or "", + "description": source.description or "", + "location": source.location or "", + }) + return {"status": "SUCCESS", "results": results} + + except api_exceptions.GoogleAPICallError as e: + logging.exception("search_catalog tool: API call failed") + return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"} + except Exception as e: + logging.exception("search_catalog tool: Unexpected error") + return {"status": "ERROR", "error_details": repr(e)} diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index 80a97f8f..d8d5e726 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -18,9 +18,13 @@ import os from unittest import mock import google.adk +from google.adk.tools.bigquery.client import DP_USER_AGENT from google.adk.tools.bigquery.client import get_bigquery_client +from google.adk.tools.bigquery.client import get_dataplex_catalog_client +from google.api_core.gapic_v1 import client_info as gapic_client_info import google.auth from google.auth.exceptions import DefaultCredentialsError +from google.cloud import dataplex_v1 from google.cloud.bigquery import client as bigquery_client from google.oauth2.credentials import Credentials @@ -201,3 +205,74 @@ def test_bigquery_client_location_custom(): # Verify that the client has the desired project set assert client.project == "test-gcp-project" assert client.location == "us-central1" + + +# Tests for Dataplex Catalog Client +# 
------------------------------------------------------------------------------ + + +# Mock the CatalogServiceClient class directly +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_default(mock_catalog_service_client): + """Test get_dataplex_catalog_client with default user agent.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + + client = get_dataplex_catalog_client(credentials=mock_creds) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + + assert kwargs["credentials"] == mock_creds + client_info = kwargs["client_info"] + assert isinstance(client_info, gapic_client_info.ClientInfo) + assert client_info.user_agent == DP_USER_AGENT + + # Ensure the function returns the mock instance + assert client == mock_catalog_service_client.return_value + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent string.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua = "catalog_ua/1.0" + expected_ua = f"{DP_USER_AGENT} {custom_ua}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent list.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} {' '.join(custom_ua_list)}" + + get_dataplex_catalog_client(credentials=mock_creds, 
user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list_with_none( + mock_catalog_service_client, +): + """Test get_dataplex_catalog_client with a list containing None.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", None, "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9cf8c9e4..e2066292 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -44,9 +44,11 @@ class TestBigQueryCredentials: # Verify that the credentials are properly stored and attributes are extracted assert config.credentials == auth_creds - assert config.client_id is None assert config.client_secret is None - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -86,7 +88,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == 
["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. @@ -128,7 +133,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py new file mode 100644 index 00000000..0ccdc9e1 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py @@ -0,0 +1,448 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import sys +from typing import Any +import unittest +from unittest import mock + +from absl.testing import parameterized + +# Mock google.genai and pydantic if not available, before importing google.adk modules +try: + import google.genai +except ImportError: + m = mock.MagicMock() + m.__path__ = [] + sys.modules["google.genai"] = m + sys.modules["google.genai.types"] = mock.MagicMock() + sys.modules["google.genai.errors"] = mock.MagicMock() + +try: + import pydantic +except ImportError: + m_pydantic = mock.MagicMock() + + class MockBaseModel: + pass + + m_pydantic.BaseModel = MockBaseModel + sys.modules["pydantic"] = m_pydantic + +try: + import fastapi + import fastapi.openapi.models +except ImportError: + m_fastapi = mock.MagicMock() + m_fastapi.openapi.models = mock.MagicMock() + sys.modules["fastapi"] = m_fastapi + sys.modules["fastapi.openapi"] = mock.MagicMock() + sys.modules["fastapi.openapi.models"] = mock.MagicMock() + + +from google.adk.tools.bigquery import search_tool +from google.adk.tools.bigquery.config import BigQueryToolConfig +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + + +def _mock_creds(): + return mock.create_autospec(Credentials, instance=True) + + +def _mock_settings(app_name: str | None = "test-app"): + return BigQueryToolConfig(application_name=app_name) + + +def _mock_search_entries_response(results: list[dict[str, Any]]): + mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse) + mock_results = [] + for r in results: + mock_result = mock.create_autospec( + dataplex_v1.SearchEntriesResult, instance=True + ) + # Manually attach dataplex_entry since it's not visible in dir() of the proto class + mock_entry = mock.create_autospec(dataplex_v1.Entry, instance=True) + mock_result.dataplex_entry = mock_entry + + mock_entry.name = r.get("name") + mock_entry.entry_type = r.get("entry_type") + 
mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z") + + # Manually attach entry_source since it's not visible in dir() of the proto class + mock_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True) + mock_entry.entry_source = mock_source + + mock_source.display_name = r.get("display_name") + mock_source.resource = r.get("linked_resource") + mock_source.description = r.get("description") + mock_source.location = r.get("location") + mock_results.append(mock_result) + mock_response.results = mock_results + return mock_response + + +class TestSearchCatalog(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_dataplex_client = mock.create_autospec( + dataplex_v1.CatalogServiceClient, instance=True + ) + + # Patch get_dataplex_catalog_client + self.mock_get_dataplex_client = self.enter_context( + mock.patch( + "google.adk.tools.bigquery.client.get_dataplex_catalog_client", + autospec=True, + ) + ) + self.mock_get_dataplex_client.return_value = self.mock_dataplex_client + self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client + + # Patch SearchEntriesRequest + self.mock_search_request = self.enter_context( + mock.patch( + "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True + ) + ) + + def test_search_catalog_success(self): + """Test the successful path of search_catalog.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "customer data" + project_id = "test-project" + location = "us" + + mock_api_results = [{ + "name": "entry1", + "entry_type": "TABLE", + "display_name": "Cust Table", + "linked_resource": ( + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1" + ), + "description": "Table 1", + "location": "us", + }] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + 
location=location, + ) + + with self.subTest("Test result content"): + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 1) + self.assertEqual(result["results"][0]["name"], "entry1") + self.assertEqual(result["results"][0]["display_name"], "Cust Table") + + with self.subTest("Test mock calls"): + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=["test-app", "search_catalog"] + ) + + expected_query = ( + '(customer data) AND projectid="test-project" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/us", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once_with( + request=self.mock_search_request.return_value + ) + + def test_search_catalog_no_project_id(self): + """Test search_catalog with missing project_id.""" + result = search_tool.search_catalog( + prompt="test", + project_id="", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("project_id must be provided", result["error_details"]) + self.mock_get_dataplex_client.assert_not_called() + + def test_search_catalog_api_error(self): + """Test search_catalog handling API exceptions.""" + self.mock_dataplex_client.search_entries.side_effect = ( + api_exceptions.BadRequest("Invalid query") + ) + + result = search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn( + "Dataplex API Error: 400 Invalid query", result["error_details"] + ) + + def test_search_catalog_other_exception(self): + """Test search_catalog handling unexpected exceptions.""" + self.mock_get_dataplex_client.side_effect = Exception( + "Something went wrong" + ) + + result = 
search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("Something went wrong", result["error_details"]) + + @parameterized.named_parameters( + ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'), + ( + "multi_project_filter", + "p", + ["p1", "p2"], + None, + None, + '(projectid="p1" OR projectid="p2")', + ), + ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'), + ( + "multi_type_filter", + "p", + None, + None, + ["TABLE", "DATASET"], + '(type="TABLE" OR type="DATASET")', + ), + ( + "project_and_dataset_filters", + "inventory", + ["proj1", "proj2"], + ["dsetA"], + None, + ( + '(projectid="proj1" OR projectid="proj2") AND' + ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"' + ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")' + ), + ), + ) + def test_search_catalog_query_construction( + self, prompt, project_ids, dataset_ids, types, expected_query_part + ): + """Test different query constructions based on filters.""" + search_tool.search_catalog( + prompt=prompt, + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + project_ids_filter=project_ids, + dataset_ids_filter=dataset_ids, + types_filter=types, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + query = kwargs["query"] + + if prompt: + assert f"({prompt})" in query + assert "system=BIGQUERY" in query + assert expected_query_part in query + + def test_search_catalog_no_app_name(self): + """Test search_catalog when settings.application_name is None.""" + creds = _mock_creds() + settings = _mock_settings(app_name=None) + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + location="us", + ) + + 
self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=[None, "search_catalog"] + ) + + def test_search_catalog_multi_project_filter_semantic(self): + """Test semantic search with a multi-project filter.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "What datasets store user profiles?" + project_id = "main-project" + project_filters = ["user-data-proj", "shared-infra-proj"] + location = "global" + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + project_ids_filter=project_filters, + types_filter=["DATASET"], + ) + + expected_query = ( + f"({prompt}) AND " + '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND ' + 'type="DATASET" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + def test_search_catalog_natural_language_semantic(self): + """Test natural language prompts with semantic search enabled and check output.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "Find tables about football matches" + project_id = "sports-analytics" + location = "europe-west1" + + # Mock the results that the API would return for this semantic query + mock_api_results = [ + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1" + ), + "display_name": "uk_football_premiership", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership" + ), + "description": "Stats for UK Premier League matches.", + "location": 
"europe-west1", + }, + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2" + ), + "display_name": "serie_a_matches", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a" + ), + "description": "Italian Serie A football results.", + "location": "europe-west1", + }, + ] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + ) + + with self.subTest("Query Construction"): + # Assert the request was made as expected + expected_query = ( + f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + with self.subTest("Response Processing"): + # Assert the output is processed correctly + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 2) + self.assertEqual( + result["results"][0]["display_name"], "uk_football_premiership" + ) + self.assertEqual(result["results"][1]["display_name"], "serie_a_matches") + self.assertIn("UK Premier League", result["results"][0]["description"]) + + def test_search_catalog_default_location(self): + """Test search_catalog fallback to global location when None is provided.""" + creds = _mock_creds() + settings = _mock_settings() + # settings.location is None by default + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + 
settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/global", name_arg) + + def test_search_catalog_settings_location(self): + """Test search_catalog uses settings.location when provided.""" + creds = _mock_creds() + settings = BigQueryToolConfig(location="eu") + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/eu", name_arg) diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index f1f73aa6..0eced4b1 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 10 + assert len(tools) == 11 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default(): "forecast", "analyze_contribution", "detect_anomalies", + "search_catalog", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names From 7478bdaa9817b0285b4119e8c739d7520373f719 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Fri, 20 Feb 2026 10:44:51 -0800 Subject: [PATCH 005/102] fix: Parallelize tool resolution in LlmAgent.canonical_tools() Previously we resolved tools sequentially by awaiting _convert_tool_union_to_tools() in a loop -- reduce the latency by resolving tools concurrently. 
Co-authored-by: Kathy Wu PiperOrigin-RevId: 872979105 --- src/google/adk/agents/llm_agent.py | 20 +++++++++++------- .../unittests/agents/test_llm_agent_fields.py | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5294e056..4e07651c 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import importlib import inspect import logging @@ -589,24 +590,27 @@ class LlmAgent(BaseAgent): return global_instruction, True async def canonical_tools( - self, ctx: ReadonlyContext = None + self, ctx: Optional[ReadonlyContext] = None ) -> list[BaseTool]: """The resolved self.tools field as a list of BaseTool based on the context. This method is only for use by Agent Development Kit. """ - resolved_tools = [] # We may need to wrap some built-in tools if there are other tools # because the built-in tools cannot be used together with other tools. # TODO(b/448114567): Remove once the workaround is no longer needed. 
multiple_tools = len(self.tools) > 1 model = self.canonical_model - for tool_union in self.tools: - resolved_tools.extend( - await _convert_tool_union_to_tools( - tool_union, ctx, model, multiple_tools - ) - ) + + results = await asyncio.gather(*( + _convert_tool_union_to_tools(tool_union, ctx, model, multiple_tools) + for tool_union in self.tools + )) + + resolved_tools = [] + for tools in results: + resolved_tools.extend(tools) + return resolved_tools @property diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 8a3623cb..df543db9 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -451,6 +451,27 @@ class TestCanonicalTools: assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + async def test_multiple_tools_resolution(self): + """Test that multiple tools are resolved correctly.""" + + def _tool_1(): + pass + + def _tool_2(): + pass + + agent = LlmAgent( + name='test_agent', + model='gemini-pro', + tools=[_tool_1, _tool_2], + ) + ctx = await _create_readonly_context(agent) + tools = await agent.canonical_tools(ctx) + + assert len(tools) == 2 + assert tools[0].name == '_tool_1' + assert tools[1].name == '_tool_2' + # Tests for multi-provider model support via string model names @pytest.mark.parametrize( From e6b601a2ab71b7e2df0240fd55550dca1eba8397 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 11:20:31 -0800 Subject: [PATCH 006/102] fix: Invoke on_tool_error_callback for missing tools in live mode In live mode, when the model calls an unregistered tool, ADK now runs on_tool_error_callback before failing. 
If the callback returns a response, ADK emits that function response and continues; otherwise it keeps the old ValueError Co-authored-by: George Weale PiperOrigin-RevId: 872996178 --- src/google/adk/flows/llm_flows/functions.py | 57 +++++++++++++- .../llm_flows/test_live_tool_callbacks.py | 77 +++++++++++++++++++ 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6f34e8fe..4d045fac 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -660,14 +660,65 @@ async def _execute_single_function_call_live( streaming_lock: asyncio.Lock, ) -> Optional[Event]: """Execute a single function call for live mode with thread safety.""" - tool, tool_context = _get_tool_and_context( - invocation_context, function_call, tools_dict - ) + async def _run_on_tool_error_callbacks( + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict[str, Any]]: + """Runs the on_tool_error_callbacks for the given tool.""" + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_tool_error_callbacks: + error_response = callback( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. 
function_args = ( copy.deepcopy(function_call.args) if function_call.args else {} ) + tool_context = _create_tool_context(invocation_context, function_call) + + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description='Tool not found') + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + raise tool_error + async def _run_with_trace(): nonlocal function_args diff --git a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py index caab8f3f..016e9b49 100644 --- a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py @@ -386,3 +386,80 @@ async def test_live_callback_compatibility_with_async(): async_response = async_result.content.parts[0].function_response.response live_response = live_result.content.parts[0].function_response.response assert async_response == live_response == {"bypassed": "by_before_callback"} + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_noop(): + """Test that on_tool_error_callback is a no-op when the tool is not found.""" + + def noop_on_tool_error_callback(tool, args, tool_context, error): + return None + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=noop_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + 
content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + with pytest.raises(ValueError): + await handle_function_calls_live(invocation_context, event, tools_dict) + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response(): + """Test that on_tool_error_callback modifies tool response when tool is not found.""" + + def mock_on_tool_error_callback(tool, args, tool_context, error): + return {"result": "on_tool_error_callback_response"} + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=mock_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result_event = await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == { + "result": "on_tool_error_callback_response" + } From a7b509763c1732f0363e90952bb4c2672572d542 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 11:28:29 -0800 Subject: [PATCH 007/102] feat: Use --memory_service_uri in ADK CLI run command Co-authored-by: George Weale PiperOrigin-RevId: 873000092 --- src/google/adk/cli/cli.py | 20 ++- src/google/adk/cli/cli_tools_click.py | 9 +- 
src/google/adk/cli/service_registry.py | 6 + tests/unittests/cli/test_service_registry.py | 7 + tests/unittests/cli/utils/test_cli.py | 136 ++++++++++++++++++ .../cli/utils/test_cli_tools_click.py | 13 +- .../cli/utils/test_service_factory.py | 9 ++ 7 files changed, 187 insertions(+), 13 deletions(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 16eba88b..1d49f50d 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -28,6 +28,7 @@ from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from ..memory.base_memory_service import BaseMemoryService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -37,6 +38,7 @@ from .service_registry import load_services_module from .utils import envs from .utils.agent_loader import AgentLoader from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_memory_service_from_options from .utils.service_factory import create_session_service_from_options @@ -53,6 +55,7 @@ async def run_input_file( session_service: BaseSessionService, credential_service: BaseCredentialService, input_path: str, + memory_service: Optional[BaseMemoryService] = None, ) -> Session: app = ( agent_or_app @@ -63,6 +66,7 @@ async def run_input_file( app=app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: @@ -93,6 +97,7 @@ async def run_interactively( session: Session, session_service: BaseSessionService, credential_service: BaseCredentialService, + memory_service: Optional[BaseMemoryService] = None, ) -> None: app = ( 
root_agent_or_app @@ -103,6 +108,7 @@ async def run_interactively( app=app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, ) while True: @@ -137,6 +143,7 @@ async def run_cli( session_id: Optional[str] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, use_local_storage: bool = True, ) -> None: """Runs an interactive CLI for a certain agent. @@ -154,6 +161,7 @@ async def run_cli( session_id: Optional[str], the session ID to save the session to on exit. session_service_uri: Optional[str], custom session service URI. artifact_service_uri: Optional[str], custom artifact service URI. + memory_service_uri: Optional[str], custom memory service URI. use_local_storage: bool, whether to use local .adk storage by default. """ agent_parent_path = Path(agent_parent_dir).resolve() @@ -171,6 +179,9 @@ async def run_cli( if isinstance(agent_or_app, App) and agent_or_app.name != agent_folder_name: app_name_to_dir = {agent_or_app.name: agent_folder_name} + if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): + envs.load_dotenv_for_agent(agent_folder_name, agents_dir) + # Create session and artifact services using factory functions. # Sessions persist under //.adk/session.db when enabled. 
session_service = create_session_service_from_options( @@ -185,10 +196,12 @@ async def run_cli( artifact_service_uri=artifact_service_uri, use_local_storage=use_local_storage, ) + memory_service = create_memory_service_from_options( + base_dir=agent_parent_path, + memory_service_uri=memory_service_uri, + ) credential_service = InMemoryCredentialService() - if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): - envs.load_dotenv_for_agent(agent_folder_name, agents_dir) # Helper function for printing events def _print_event(event) -> None: @@ -208,6 +221,7 @@ async def run_cli( agent_or_app=agent_or_app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, input_path=input_file, ) @@ -235,6 +249,7 @@ async def run_cli( session, session_service, credential_service, + memory_service=memory_service, ) else: session = await session_service.create_session( @@ -247,6 +262,7 @@ async def run_cli( session, session_service, credential_service, + memory_service=memory_service, ) if save_session: diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 5b5d3e5c..f55a8f10 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -645,14 +645,6 @@ def cli_run( """ logs.log_to_tmp_folder() - # Validation warning for memory_service_uri (not supported for adk run) - if memory_service_uri: - click.secho( - "WARNING: --memory_service_uri is not supported for adk run.", - fg="yellow", - err=True, - ) - agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) @@ -666,6 +658,7 @@ def cli_run( session_id=session_id, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, use_local_storage=use_local_storage, ) ) diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 2ea286ef..b1328958 
100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -301,6 +301,11 @@ def _register_builtin_services(registry: ServiceRegistry) -> None: registry.register_artifact_service("file", file_artifact_factory) # -- Memory Services -- + def memory_memory_factory(_uri: str, **_): + from ..memory.in_memory_memory_service import InMemoryMemoryService + + return InMemoryMemoryService() + def rag_memory_factory(uri: str, **kwargs): from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService @@ -324,6 +329,7 @@ def _register_builtin_services(registry: ServiceRegistry) -> None: ) return VertexAiMemoryBankService(**params) + registry.register_memory_service("memory", memory_memory_factory) registry.register_memory_service("rag", rag_memory_factory) registry.register_memory_service("agentengine", agentengine_memory_factory) diff --git a/tests/unittests/cli/test_service_registry.py b/tests/unittests/cli/test_service_registry.py index 37c6e7c2..dd33e006 100644 --- a/tests/unittests/cli/test_service_registry.py +++ b/tests/unittests/cli/test_service_registry.py @@ -165,6 +165,13 @@ def test_create_memory_service_agentengine_full(registry, mock_services): ) +def test_create_memory_service_memory(registry): + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + memory_service = registry.create_memory_service("memory://") + assert isinstance(memory_service, InMemoryMemoryService) + + # General Tests def test_unsupported_scheme(registry, mock_services): session_service = registry.create_session_service("unsupported://foo") diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 6814ef97..f7df1bf1 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -354,9 +354,145 @@ async def test_run_cli_accepts_memory_scheme( save_session=False, session_service_uri="memory://", artifact_service_uri="memory://", + 
memory_service_uri="memory://", ) +@pytest.mark.asyncio +async def test_run_cli_invalid_memory_uri_surfaces_value_error( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should let ValueError propagate for invalid memory service URIs.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "invalid_memory_uri.json" + input_path.write_text(json.dumps(input_json)) + + def _raise_invalid_memory_uri( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + del base_dir, memory_service_uri + raise ValueError("Unsupported memory service URI: unknown://x") + + monkeypatch.setattr( + cli, "create_memory_service_from_options", _raise_invalid_memory_uri + ) + + with pytest.raises(ValueError, match="Unsupported memory service URI"): + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="unknown://x", + ) + + +@pytest.mark.asyncio +async def test_run_cli_passes_memory_service_to_input_file( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should construct and pass the configured memory service.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "memory_input.json" + input_path.write_text(json.dumps(input_json)) + + memory_service_sentinel = object() + captured_factory_args: dict[str, Any] = {} + captured_memory_service: dict[str, Any] = {} + + def _memory_factory( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + captured_factory_args["base_dir"] = base_dir + captured_factory_args["memory_service_uri"] = memory_service_uri + return memory_service_sentinel + + async def _run_input_file( + app_name: str, + user_id: str, + agent_or_app: BaseAgent | App, + artifact_service: Any, + 
session_service: Any, + credential_service: InMemoryCredentialService, + input_path: str, + memory_service: Any = None, + ) -> object: + del app_name, user_id, agent_or_app, artifact_service + del session_service, credential_service, input_path + captured_memory_service["value"] = memory_service + return object() + + monkeypatch.setattr( + cli, "create_memory_service_from_options", _memory_factory + ) + monkeypatch.setattr(cli, "run_input_file", _run_input_file) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="memory://", + ) + + assert Path(captured_factory_args["base_dir"]) == parent_dir.resolve() + assert captured_factory_args["memory_service_uri"] == "memory://" + assert captured_memory_service["value"] is memory_service_sentinel + + +@pytest.mark.asyncio +async def test_run_cli_loads_dotenv_before_memory_service_creation( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should load agent .env values before creating memory service.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "dotenv_order_input.json" + input_path.write_text(json.dumps(input_json)) + + call_order: list[str] = [] + + def _load_dotenv_for_agent(agent_name: str, agents_dir: str) -> None: + del agent_name, agents_dir + call_order.append("load_dotenv") + + def _memory_factory( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + del base_dir, memory_service_uri + call_order.append("create_memory") + return object() + + monkeypatch.setenv("ADK_DISABLE_LOAD_DOTENV", "0") + monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", _load_dotenv_for_agent) + monkeypatch.setattr( + cli, "create_memory_service_from_options", _memory_factory + ) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + 
agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="memory://", + ) + + assert "create_memory" in call_order + assert "load_dotenv" in call_order + assert call_order.index("load_dotenv") < call_order.index("create_memory") + + @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 61b1468c..7c642dbb 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -23,6 +23,7 @@ from types import SimpleNamespace from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Tuple from unittest import mock @@ -129,7 +130,7 @@ def test_cli_create_cmd_invokes_run_cmd( # cli run @pytest.mark.parametrize( - "cli_args,expected_session_uri,expected_artifact_uri", + "cli_args,expected_session_uri,expected_artifact_uri,expected_memory_uri", [ pytest.param( [ @@ -137,15 +138,19 @@ def test_cli_create_cmd_invokes_run_cmd( "memory://", "--artifact_service_uri", "memory://", + "--memory_service_uri", + "memory://", ], "memory://", "memory://", + "memory://", id="memory_scheme_uris", ), pytest.param( [], None, None, + None, id="default_uris_none", ), ], @@ -154,8 +159,9 @@ def test_cli_run_service_uris( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cli_args: list, - expected_session_uri: str, - expected_artifact_uri: str, + expected_session_uri: Optional[str], + expected_artifact_uri: Optional[str], + expected_memory_uri: Optional[str], ) -> None: """`adk run` should forward service URIs correctly to run_cli.""" agent_dir = tmp_path / "agent" @@ -186,6 +192,7 @@ def test_cli_run_service_uris( coro_locals = captured_locals[0] assert coro_locals.get("session_service_uri") == expected_session_uri assert 
coro_locals.get("artifact_service_uri") == expected_artifact_uri + assert coro_locals.get("memory_service_uri") == expected_memory_uri assert coro_locals["agent_folder_name"] == "agent" diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 910bf906..d6f1426a 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -252,6 +252,15 @@ def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): assert isinstance(service, InMemoryMemoryService) +def test_create_memory_service_supports_memory_uri(tmp_path: Path): + service = service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="memory://", + ) + + assert isinstance(service, InMemoryMemoryService) + + def test_create_memory_service_raises_on_unknown_scheme( tmp_path: Path, monkeypatch ): From 09ee3c3695b420a30518364d9f3cb18ce0c1f6b6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Feb 2026 11:59:20 -0800 Subject: [PATCH 008/102] ADK changes PiperOrigin-RevId: 873013637 --- contributing/samples/bigquery/README.md | 4 - pyproject.toml | 1 - .../tools/bigquery/bigquery_credentials.py | 8 +- .../adk/tools/bigquery/bigquery_toolset.py | 2 - src/google/adk/tools/bigquery/client.py | 45 +- src/google/adk/tools/bigquery/search_tool.py | 179 ------- .../tools/bigquery/test_bigquery_client.py | 75 --- .../bigquery/test_bigquery_credentials.py | 16 +- .../bigquery/test_bigquery_search_tool.py | 448 ------------------ .../tools/bigquery/test_bigquery_toolset.py | 3 +- 10 files changed, 13 insertions(+), 768 deletions(-) delete mode 100644 src/google/adk/tools/bigquery/search_tool.py delete mode 100644 tests/unittests/tools/bigquery/test_bigquery_search_tool.py diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index fc3f8610..3ed97432 100644 --- a/contributing/samples/bigquery/README.md +++ 
b/contributing/samples/bigquery/README.md @@ -55,9 +55,6 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. -11. `search_catalog` - Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. - ## How to use Set up environment variables in your `.env` file for using @@ -162,4 +159,3 @@ the necessary access tokens to call BigQuery APIs on their behalf. * which tables exist in the ml_datasets dataset? * show more details about the penguins table * compute penguins population per island. -* are there any tables related to animals in project ? diff --git a/pyproject.toml b/pyproject.toml index a1f136d5..9bec96cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ "google-cloud-bigquery-storage>=2.0.0", "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database - "google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 958ce9d7..fa23c74c 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,10 +19,6 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" -BIGQUERY_SCOPES = [ - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/dataplex", -] BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @@ -38,8 +34,8 @@ class 
BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_SCOPES - # Set the token cache key + self.scopes = BIGQUERY_DEFAULT_SCOPE + self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY return self diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index dba5f8ee..1a748b71 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,7 +24,6 @@ from typing_extensions import override from . import data_insights_tool from . import metadata_tool from . import query_tool -from . import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -88,7 +87,6 @@ class BigQueryToolset(BaseToolset): query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, - search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index 2cb4e67c..d57c0c80 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -14,22 +14,19 @@ from __future__ import annotations -from typing import List from typing import Optional -from typing import Union import google.api_core.client_info -from google.api_core.gapic_v1 import client_info as gapic_client_info from google.auth.credentials import Credentials from google.cloud import bigquery -from google.cloud import dataplex_v1 from ... import version -USER_AGENT_BASE = f"google-adk/{version.__version__}" -BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}" -DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}" -USER_AGENT = BQ_USER_AGENT +USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" + + +from typing import List +from typing import Union def get_bigquery_client( @@ -51,7 +48,7 @@ def get_bigquery_client( A BigQuery client. 
""" - user_agents = [BQ_USER_AGENT] + user_agents = [USER_AGENT] if user_agent: if isinstance(user_agent, str): user_agents.append(user_agent) @@ -70,33 +67,3 @@ def get_bigquery_client( ) return bigquery_client - - -def get_dataplex_catalog_client( - *, - credentials: Credentials, - user_agent: Optional[Union[str, List[str]]] = None, -) -> dataplex_v1.CatalogServiceClient: - """Get a Dataplex CatalogServiceClient with minimal necessary arguments. - - Args: - credentials: The credentials to use for the request. - user_agent: Additional user agent string(s) to append. - - Returns: - A Dataplex Client. - """ - - user_agents = [DP_USER_AGENT] - if user_agent: - if isinstance(user_agent, str): - user_agents.append(user_agent) - else: - user_agents.extend([ua for ua in user_agent if ua]) - - client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents)) - - return dataplex_v1.CatalogServiceClient( - credentials=credentials, - client_info=client_info, - ) diff --git a/src/google/adk/tools/bigquery/search_tool.py b/src/google/adk/tools/bigquery/search_tool.py deleted file mode 100644 index 0bf01d5a..00000000 --- a/src/google/adk/tools/bigquery/search_tool.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import logging -from typing import Any - -from google.api_core import exceptions as api_exceptions -from google.auth.credentials import Credentials -from google.cloud import dataplex_v1 - -from . import client -from .config import BigQueryToolConfig - - -def _construct_search_query_helper( - predicate: str, operator: str, items: list[str] -) -> str: - """Constructs a search query part for a specific predicate and items.""" - if not items: - return "" - - clauses = [f'{predicate}{operator}"{item}"' for item in items] - return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0] - - -def search_catalog( - prompt: str, - project_id: str, - *, - credentials: Credentials, - settings: BigQueryToolConfig, - location: str | None = None, - page_size: int = 10, - project_ids_filter: list[str] | None = None, - dataset_ids_filter: list[str] | None = None, - types_filter: list[str] | None = None, -) -> dict[str, Any]: - """Searches for BigQuery assets within Dataplex. - - Args: - prompt: The base search query (natural language or keywords). - project_id: The Google Cloud project ID to scope the search. - credentials: Credentials for the request. - settings: BigQuery tool settings. - location: The Dataplex location to use. - page_size: Maximum number of results. - project_ids_filter: Specific project IDs to include in the search results. - If None, defaults to the scoping project_id. - dataset_ids_filter: BigQuery dataset IDs to filter by. - types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE, - BigQueryEntryType.DATASET). - - Returns: - Search results or error. The "results" list contains items with: - - name: The Dataplex Entry name (e.g., - "projects/p/locations/l/entryGroups/g/entries/e"). - - linked_resource: The underlying BigQuery resource name (e.g., - "//bigquery.googleapis.com/projects/p/datasets/d/tables/t"). - - display_name, entry_type, description, location, update_time. 
- - Examples: - Search for tables related to customer data: - - >>> search_catalog( - ... prompt="Search for tables related to customer data", - ... project_id="my-project", - ... credentials=creds, - ... settings=settings - ... ) - { - "status": "SUCCESS", - "results": [ - { - "name": - "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id", - "display_name": "customer_table", - "entry_type": - "projects/p/locations/l/entryTypes/bigquery-table", - "linked_resource": - "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table", - "description": "Table containing customer details.", - "location": "us", - "update_time": "2024-01-01 12:00:00+00:00" - } - ] - } - """ - - try: - if not project_id: - return { - "status": "ERROR", - "error_details": "project_id must be provided.", - } - - with client.get_dataplex_catalog_client( - credentials=credentials, - user_agent=[settings.application_name, "search_catalog"], - ) as dataplex_client: - query_parts = [] - if prompt: - query_parts.append(f"({prompt})") - - # Filter by project IDs - projects_to_filter = ( - project_ids_filter if project_ids_filter else [project_id] - ) - if projects_to_filter: - query_parts.append( - _construct_search_query_helper("projectid", "=", projects_to_filter) - ) - - # Filter by dataset IDs - if dataset_ids_filter: - dataset_resource_filters = [] - for pid in projects_to_filter: - for did in dataset_ids_filter: - dataset_resource_filters.append( - f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"' - ) - if dataset_resource_filters: - query_parts.append(f"({' OR '.join(dataset_resource_filters)})") - # Filter by entry types - if types_filter: - query_parts.append( - _construct_search_query_helper("type", "=", types_filter) - ) - - # Always scope to BigQuery system - query_parts.append("system=BIGQUERY") - - full_query = " AND ".join(filter(None, query_parts)) - - search_location = location or settings.location or "global" - 
search_scope = f"projects/{project_id}/locations/{search_location}" - - request = dataplex_v1.SearchEntriesRequest( - name=search_scope, - query=full_query, - page_size=page_size, - semantic_search=True, - ) - - response = dataplex_client.search_entries(request=request) - - results = [] - for result in response.results: - entry = result.dataplex_entry - source = entry.entry_source - results.append({ - "name": entry.name, - "display_name": source.display_name or "", - "entry_type": entry.entry_type, - "update_time": str(entry.update_time), - "linked_resource": source.resource or "", - "description": source.description or "", - "location": source.location or "", - }) - return {"status": "SUCCESS", "results": results} - - except api_exceptions.GoogleAPICallError as e: - logging.exception("search_catalog tool: API call failed") - return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"} - except Exception as e: - logging.exception("search_catalog tool: Unexpected error") - return {"status": "ERROR", "error_details": repr(e)} diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index d8d5e726..80a97f8f 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -18,13 +18,9 @@ import os from unittest import mock import google.adk -from google.adk.tools.bigquery.client import DP_USER_AGENT from google.adk.tools.bigquery.client import get_bigquery_client -from google.adk.tools.bigquery.client import get_dataplex_catalog_client -from google.api_core.gapic_v1 import client_info as gapic_client_info import google.auth from google.auth.exceptions import DefaultCredentialsError -from google.cloud import dataplex_v1 from google.cloud.bigquery import client as bigquery_client from google.oauth2.credentials import Credentials @@ -205,74 +201,3 @@ def test_bigquery_client_location_custom(): # Verify that the client has the 
desired project set assert client.project == "test-gcp-project" assert client.location == "us-central1" - - -# Tests for Dataplex Catalog Client -# ------------------------------------------------------------------------------ - - -# Mock the CatalogServiceClient class directly -@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) -def test_dataplex_client_default(mock_catalog_service_client): - """Test get_dataplex_catalog_client with default user agent.""" - mock_creds = mock.create_autospec(Credentials, instance=True) - - client = get_dataplex_catalog_client(credentials=mock_creds) - - mock_catalog_service_client.assert_called_once() - _, kwargs = mock_catalog_service_client.call_args - - assert kwargs["credentials"] == mock_creds - client_info = kwargs["client_info"] - assert isinstance(client_info, gapic_client_info.ClientInfo) - assert client_info.user_agent == DP_USER_AGENT - - # Ensure the function returns the mock instance - assert client == mock_catalog_service_client.return_value - - -@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) -def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client): - """Test get_dataplex_catalog_client with a custom user agent string.""" - mock_creds = mock.create_autospec(Credentials, instance=True) - custom_ua = "catalog_ua/1.0" - expected_ua = f"{DP_USER_AGENT} {custom_ua}" - - get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua) - - mock_catalog_service_client.assert_called_once() - _, kwargs = mock_catalog_service_client.call_args - client_info = kwargs["client_info"] - assert client_info.user_agent == expected_ua - - -@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) -def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client): - """Test get_dataplex_catalog_client with a custom user agent list.""" - mock_creds = mock.create_autospec(Credentials, instance=True) - custom_ua_list = ["catalog_ua", 
"catalog_ua_2.0"] - expected_ua = f"{DP_USER_AGENT} {' '.join(custom_ua_list)}" - - get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) - - mock_catalog_service_client.assert_called_once() - _, kwargs = mock_catalog_service_client.call_args - client_info = kwargs["client_info"] - assert client_info.user_agent == expected_ua - - -@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) -def test_dataplex_client_custom_user_agent_list_with_none( - mock_catalog_service_client, -): - """Test get_dataplex_catalog_client with a list containing None.""" - mock_creds = mock.create_autospec(Credentials, instance=True) - custom_ua_list = ["catalog_ua", None, "catalog_ua_2.0"] - expected_ua = f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0" - - get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) - - mock_catalog_service_client.assert_called_once() - _, kwargs = mock_catalog_service_client.call_args - client_info = kwargs["client_info"] - assert client_info.user_agent == expected_ua diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index e2066292..9cf8c9e4 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -44,11 +44,9 @@ class TestBigQueryCredentials: # Verify that the credentials are properly stored and attributes are extracted assert config.credentials == auth_creds + assert config.client_id is None assert config.client_secret is None - assert config.scopes == [ - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/dataplex", - ] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -88,10 +86,7 @@ class TestBigQueryCredentials: assert config.credentials is None 
assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == [ - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/dataplex", - ] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. @@ -133,10 +128,7 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == [ - "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/dataplex", - ] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py deleted file mode 100644 index 0ccdc9e1..00000000 --- a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import sys -from typing import Any -import unittest -from unittest import mock - -from absl.testing import parameterized - -# Mock google.genai and pydantic if not available, before importing google.adk modules -try: - import google.genai -except ImportError: - m = mock.MagicMock() - m.__path__ = [] - sys.modules["google.genai"] = m - sys.modules["google.genai.types"] = mock.MagicMock() - sys.modules["google.genai.errors"] = mock.MagicMock() - -try: - import pydantic -except ImportError: - m_pydantic = mock.MagicMock() - - class MockBaseModel: - pass - - m_pydantic.BaseModel = MockBaseModel - sys.modules["pydantic"] = m_pydantic - -try: - import fastapi - import fastapi.openapi.models -except ImportError: - m_fastapi = mock.MagicMock() - m_fastapi.openapi.models = mock.MagicMock() - sys.modules["fastapi"] = m_fastapi - sys.modules["fastapi.openapi"] = mock.MagicMock() - sys.modules["fastapi.openapi.models"] = mock.MagicMock() - - -from google.adk.tools.bigquery import search_tool -from google.adk.tools.bigquery.config import BigQueryToolConfig -from google.api_core import exceptions as api_exceptions -from google.auth.credentials import Credentials -from google.cloud import dataplex_v1 - - -def _mock_creds(): - return mock.create_autospec(Credentials, instance=True) - - -def _mock_settings(app_name: str | None = "test-app"): - return BigQueryToolConfig(application_name=app_name) - - -def _mock_search_entries_response(results: list[dict[str, Any]]): - mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse) - mock_results = [] - for r in results: - mock_result = mock.create_autospec( - dataplex_v1.SearchEntriesResult, instance=True - ) - # Manually attach dataplex_entry since it's not visible in dir() of the proto class - mock_entry = mock.create_autospec(dataplex_v1.Entry, instance=True) - mock_result.dataplex_entry = mock_entry - - mock_entry.name = r.get("name") - mock_entry.entry_type = r.get("entry_type") - 
mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z") - - # Manually attach entry_source since it's not visible in dir() of the proto class - mock_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True) - mock_entry.entry_source = mock_source - - mock_source.display_name = r.get("display_name") - mock_source.resource = r.get("linked_resource") - mock_source.description = r.get("description") - mock_source.location = r.get("location") - mock_results.append(mock_result) - mock_response.results = mock_results - return mock_response - - -class TestSearchCatalog(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.mock_dataplex_client = mock.create_autospec( - dataplex_v1.CatalogServiceClient, instance=True - ) - - # Patch get_dataplex_catalog_client - self.mock_get_dataplex_client = self.enter_context( - mock.patch( - "google.adk.tools.bigquery.client.get_dataplex_catalog_client", - autospec=True, - ) - ) - self.mock_get_dataplex_client.return_value = self.mock_dataplex_client - self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client - - # Patch SearchEntriesRequest - self.mock_search_request = self.enter_context( - mock.patch( - "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True - ) - ) - - def test_search_catalog_success(self): - """Test the successful path of search_catalog.""" - creds = _mock_creds() - settings = _mock_settings() - prompt = "customer data" - project_id = "test-project" - location = "us" - - mock_api_results = [{ - "name": "entry1", - "entry_type": "TABLE", - "display_name": "Cust Table", - "linked_resource": ( - "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1" - ), - "description": "Table 1", - "location": "us", - }] - self.mock_dataplex_client.search_entries.return_value = ( - _mock_search_entries_response(mock_api_results) - ) - - result = search_tool.search_catalog( - prompt=prompt, - project_id=project_id, - credentials=creds, - settings=settings, - 
location=location, - ) - - with self.subTest("Test result content"): - self.assertEqual(result["status"], "SUCCESS") - self.assertLen(result["results"], 1) - self.assertEqual(result["results"][0]["name"], "entry1") - self.assertEqual(result["results"][0]["display_name"], "Cust Table") - - with self.subTest("Test mock calls"): - self.mock_get_dataplex_client.assert_called_once_with( - credentials=creds, user_agent=["test-app", "search_catalog"] - ) - - expected_query = ( - '(customer data) AND projectid="test-project" AND system=BIGQUERY' - ) - self.mock_search_request.assert_called_once_with( - name=f"projects/{project_id}/locations/us", - query=expected_query, - page_size=10, - semantic_search=True, - ) - self.mock_dataplex_client.search_entries.assert_called_once_with( - request=self.mock_search_request.return_value - ) - - def test_search_catalog_no_project_id(self): - """Test search_catalog with missing project_id.""" - result = search_tool.search_catalog( - prompt="test", - project_id="", - credentials=_mock_creds(), - settings=_mock_settings(), - location="us", - ) - self.assertEqual(result["status"], "ERROR") - self.assertIn("project_id must be provided", result["error_details"]) - self.mock_get_dataplex_client.assert_not_called() - - def test_search_catalog_api_error(self): - """Test search_catalog handling API exceptions.""" - self.mock_dataplex_client.search_entries.side_effect = ( - api_exceptions.BadRequest("Invalid query") - ) - - result = search_tool.search_catalog( - prompt="test", - project_id="test-project", - credentials=_mock_creds(), - settings=_mock_settings(), - location="us", - ) - self.assertEqual(result["status"], "ERROR") - self.assertIn( - "Dataplex API Error: 400 Invalid query", result["error_details"] - ) - - def test_search_catalog_other_exception(self): - """Test search_catalog handling unexpected exceptions.""" - self.mock_get_dataplex_client.side_effect = Exception( - "Something went wrong" - ) - - result = 
search_tool.search_catalog( - prompt="test", - project_id="test-project", - credentials=_mock_creds(), - settings=_mock_settings(), - location="us", - ) - self.assertEqual(result["status"], "ERROR") - self.assertIn("Something went wrong", result["error_details"]) - - @parameterized.named_parameters( - ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'), - ( - "multi_project_filter", - "p", - ["p1", "p2"], - None, - None, - '(projectid="p1" OR projectid="p2")', - ), - ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'), - ( - "multi_type_filter", - "p", - None, - None, - ["TABLE", "DATASET"], - '(type="TABLE" OR type="DATASET")', - ), - ( - "project_and_dataset_filters", - "inventory", - ["proj1", "proj2"], - ["dsetA"], - None, - ( - '(projectid="proj1" OR projectid="proj2") AND' - ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"' - ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")' - ), - ), - ) - def test_search_catalog_query_construction( - self, prompt, project_ids, dataset_ids, types, expected_query_part - ): - """Test different query constructions based on filters.""" - search_tool.search_catalog( - prompt=prompt, - project_id="test-project", - credentials=_mock_creds(), - settings=_mock_settings(), - location="us", - project_ids_filter=project_ids, - dataset_ids_filter=dataset_ids, - types_filter=types, - ) - - self.mock_search_request.assert_called_once() - _, kwargs = self.mock_search_request.call_args - query = kwargs["query"] - - if prompt: - assert f"({prompt})" in query - assert "system=BIGQUERY" in query - assert expected_query_part in query - - def test_search_catalog_no_app_name(self): - """Test search_catalog when settings.application_name is None.""" - creds = _mock_creds() - settings = _mock_settings(app_name=None) - search_tool.search_catalog( - prompt="test", - project_id="test-project", - credentials=creds, - settings=settings, - location="us", - ) - - 
self.mock_get_dataplex_client.assert_called_once_with( - credentials=creds, user_agent=[None, "search_catalog"] - ) - - def test_search_catalog_multi_project_filter_semantic(self): - """Test semantic search with a multi-project filter.""" - creds = _mock_creds() - settings = _mock_settings() - prompt = "What datasets store user profiles?" - project_id = "main-project" - project_filters = ["user-data-proj", "shared-infra-proj"] - location = "global" - - self.mock_dataplex_client.search_entries.return_value = ( - _mock_search_entries_response([]) - ) - - search_tool.search_catalog( - prompt=prompt, - project_id=project_id, - credentials=creds, - settings=settings, - location=location, - project_ids_filter=project_filters, - types_filter=["DATASET"], - ) - - expected_query = ( - f"({prompt}) AND " - '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND ' - 'type="DATASET" AND system=BIGQUERY' - ) - self.mock_search_request.assert_called_once_with( - name=f"projects/{project_id}/locations/{location}", - query=expected_query, - page_size=10, - semantic_search=True, - ) - self.mock_dataplex_client.search_entries.assert_called_once() - - def test_search_catalog_natural_language_semantic(self): - """Test natural language prompts with semantic search enabled and check output.""" - creds = _mock_creds() - settings = _mock_settings() - prompt = "Find tables about football matches" - project_id = "sports-analytics" - location = "europe-west1" - - # Mock the results that the API would return for this semantic query - mock_api_results = [ - { - "name": ( - "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1" - ), - "display_name": "uk_football_premiership", - "entry_type": ( - "projects/655216118709/locations/global/entryTypes/bigquery-table" - ), - "linked_resource": ( - "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership" - ), - "description": "Stats for UK Premier League matches.", - "location": 
"europe-west1", - }, - { - "name": ( - "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2" - ), - "display_name": "serie_a_matches", - "entry_type": ( - "projects/655216118709/locations/global/entryTypes/bigquery-table" - ), - "linked_resource": ( - "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a" - ), - "description": "Italian Serie A football results.", - "location": "europe-west1", - }, - ] - self.mock_dataplex_client.search_entries.return_value = ( - _mock_search_entries_response(mock_api_results) - ) - - result = search_tool.search_catalog( - prompt=prompt, - project_id=project_id, - credentials=creds, - settings=settings, - location=location, - ) - - with self.subTest("Query Construction"): - # Assert the request was made as expected - expected_query = ( - f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY' - ) - self.mock_search_request.assert_called_once_with( - name=f"projects/{project_id}/locations/{location}", - query=expected_query, - page_size=10, - semantic_search=True, - ) - self.mock_dataplex_client.search_entries.assert_called_once() - - with self.subTest("Response Processing"): - # Assert the output is processed correctly - self.assertEqual(result["status"], "SUCCESS") - self.assertLen(result["results"], 2) - self.assertEqual( - result["results"][0]["display_name"], "uk_football_premiership" - ) - self.assertEqual(result["results"][1]["display_name"], "serie_a_matches") - self.assertIn("UK Premier League", result["results"][0]["description"]) - - def test_search_catalog_default_location(self): - """Test search_catalog fallback to global location when None is provided.""" - creds = _mock_creds() - settings = _mock_settings() - # settings.location is None by default - - self.mock_dataplex_client.search_entries.return_value = ( - _mock_search_entries_response([]) - ) - - search_tool.search_catalog( - prompt="test", - project_id="test-project", - credentials=creds, - 
settings=settings, - ) - - self.mock_search_request.assert_called_once() - _, kwargs = self.mock_search_request.call_args - name_arg = kwargs["name"] - self.assertIn("locations/global", name_arg) - - def test_search_catalog_settings_location(self): - """Test search_catalog uses settings.location when provided.""" - creds = _mock_creds() - settings = BigQueryToolConfig(location="eu") - - self.mock_dataplex_client.search_entries.return_value = ( - _mock_search_entries_response([]) - ) - - search_tool.search_catalog( - prompt="test", - project_id="test-project", - credentials=creds, - settings=settings, - ) - - self.mock_search_request.assert_called_once() - _, kwargs = self.mock_search_request.call_args - name_arg = kwargs["name"] - self.assertIn("locations/eu", name_arg) diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index 0eced4b1..f1f73aa6 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 11 + assert len(tools) == 10 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -55,7 +55,6 @@ async def test_bigquery_toolset_tools_default(): "forecast", "analyze_contribution", "detect_anomalies", - "search_catalog", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names From 9c4c44536904f5cf3301a5abb910a5666344a8c5 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Feb 2026 13:26:59 -0800 Subject: [PATCH 009/102] feat: Add /chat/completions integration to ApigeeLlm PiperOrigin-RevId: 873049983 --- src/google/adk/models/apigee_llm.py | 634 +++++++++++++++++- tests/unittests/models/test_apigee_llm.py | 174 ++++- .../models/test_completions_http_client.py | 440 
++++++++++++ 3 files changed, 1235 insertions(+), 13 deletions(-) create mode 100644 tests/unittests/models/test_completions_http_client.py diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 92f94c75..90a91f32 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -12,21 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations +import asyncio +import atexit +import base64 +import collections.abc +import enum from functools import cached_property +import json import logging import os +from typing import Any +from typing import AsyncGenerator from typing import Optional from typing import TYPE_CHECKING from google.adk import version as adk_version from google.genai import types +import httpx +import tenacity from typing_extensions import override from ..utils.env_utils import is_env_enabled from .google_llm import Gemini +from .llm_response import LlmResponse if TYPE_CHECKING: from google.genai import Client @@ -49,6 +59,20 @@ class ApigeeLlm(Gemini): model: The name of the Gemini model. """ + class ApiType(str, enum.Enum): + """The supported API types for Apigee LLM.""" + + UNKNOWN = 'unknown' + CHAT_COMPLETIONS = 'chat_completions' + GENAI = 'genai' + + @classmethod + def _missing_(cls, value): + # Empty string or None should return UNKNOWN. + if not value: + return cls.UNKNOWN + return super()._missing_(value) + def __init__( self, *, @@ -56,6 +80,7 @@ class ApigeeLlm(Gemini): proxy_url: str | None = None, custom_headers: dict[str, str] | None = None, retry_options: Optional[types.HttpRetryOptions] = None, + api_type: ApiType | str = ApiType.UNKNOWN, ): """Initializes the Apigee LLM backend. 
@@ -80,19 +105,31 @@ class ApigeeLlm(Gemini): - `apigee/vertex_ai/gemini-2.5-flash` - `apigee/gemini/v1/gemini-2.5-flash` - `apigee/vertex_ai/v1beta/gemini-2.5-flash` - proxy_url: The URL of the Apigee proxy. custom_headers: A dictionary of headers to be sent with the request. + If needed, you can add authorization headers here, for example: + {'Authorization': f'Bearer {API_KEY}'}. ApigeeLlm already handles + authorization headers in Vertex AI and Gemini API calls. retry_options: Allow google-genai to retry failed responses. - """ + api_type: The type of API to use. One of `ApiType` or string. + """ # fmt: skip super().__init__(model=model, retry_options=retry_options) # Validate the model string. Create a helper method to validate the model # string. if not _validate_model_string(model): raise ValueError(f'Invalid model string: {model}') - - self._isvertexai = _identify_vertexai(model) + if isinstance(api_type, str): + api_type = ApigeeLlm.ApiType(api_type) + if api_type and api_type != ApigeeLlm.ApiType.UNKNOWN: + self._api_type = api_type + elif model.startswith(('apigee/gemini/', 'apigee/vertex_ai/')): + self._api_type = ApigeeLlm.ApiType.GENAI + elif model.startswith('apigee/openai/'): + self._api_type = ApigeeLlm.ApiType.CHAT_COMPLETIONS + else: + self._api_type = ApigeeLlm.ApiType.GENAI + self._isvertexai = _identify_vertexai(model, self._api_type) # Set the project and location for Vertex AI. 
if self._isvertexai: @@ -131,6 +168,42 @@ class ApigeeLlm(Gemini): r'apigee\/.*', ] + @cached_property + def _completions_http_client(self) -> CompletionsHTTPClient: + """Provides the completions HTTP client.""" + return CompletionsHTTPClient( + base_url=self._proxy_url, + headers=self._merge_tracking_headers(self._custom_headers), + retry_options=self.retry_options, + ) + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + if self._api_type == ApigeeLlm.ApiType.CHAT_COMPLETIONS: + await self._preprocess_other_requests(llm_request) + async for ( + response + ) in self._completions_http_client.generate_content_async( + llm_request, stream + ): + yield response + else: + async for response in super().generate_content_async(llm_request, stream): + yield response + + async def _preprocess_other_requests(self, llm_request: LlmRequest) -> None: + """Preprocesses the request for non-Gemini/Vertex AI models.""" + llm_request.model = _get_model_id(llm_request.model) + if llm_request.config and llm_request.config.tools: + # Check if computer use is configured + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool) and tool.computer_use: + llm_request.config.system_instruction = None + await self._adapt_computer_use_tool(llm_request) + self._maybe_append_user_content(llm_request) + @cached_property def api_client(self) -> Client: """Provides the api client. @@ -167,11 +240,25 @@ class ApigeeLlm(Gemini): await super()._preprocess_request(llm_request) -def _identify_vertexai(model: str) -> bool: - """Returns True if the model spec starts with apigee/vertex_ai.""" - return not model.startswith('apigee/gemini/') and ( - model.startswith('apigee/vertex_ai/') - or is_env_enabled(_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME) +def _identify_vertexai(model: str, api_type: ApigeeLlm.ApiType) -> bool: + """Returns if a model is Vertex AI. + + 1. 
The api_type is GENAI or UNKNOWN. + 2. The model's provider is Vertex AI, or the + GOOGLE_GENAI_USE_VERTEXAI environment variable is set to TRUE or 1. + + Args: + model: The model string. + api_type: The type of API to use. + """ + if api_type not in (ApigeeLlm.ApiType.GENAI, ApigeeLlm.ApiType.UNKNOWN): + return False + if model.startswith('apigee/gemini/'): + return False + if model.startswith('apigee/openai/'): + return False + return model.startswith('apigee/vertex_ai/') or is_env_enabled( + _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME ) @@ -240,7 +327,7 @@ def _validate_model_string(model: str) -> bool: # and model_id are present. This is a valid format. if len(components) == 3: # Format: // - if components[0] not in ('vertex_ai', 'gemini'): + if components[0] not in ('vertex_ai', 'gemini', 'openai'): return False if not components[1].startswith('v'): return False @@ -249,10 +336,533 @@ def _validate_model_string(model: str) -> bool: # If the model string has 2 components, it means either the provider or the # version (but not both), and model_id are present.
if len(components) == 2: - if components[0] in ['vertex_ai', 'gemini']: + if components[0] in ['vertex_ai', 'gemini', 'openai']: return True if components[0].startswith('v'): return True return False return False + + +class CompletionsHTTPClient: + """A generic HTTP client for completions, compatible with OpenAI API.""" + + def __init__( + self, + base_url: str, + headers: dict[str, str] | None = None, + retry_options: Optional[types.HttpRetryOptions] = None, + ): + self._base_url = base_url + self._headers = headers or {} + self.retry_options = retry_options + + def __del__(self) -> None: + self.close() + + @cached_property + def _client(self) -> httpx.AsyncClient: + """Provides the httpx client.""" + client = httpx.AsyncClient( + base_url=self._base_url, + headers=self._headers, + timeout=None, + follow_redirects=True, + ) + atexit.register(self._cleanup_client, client) + return client + + @staticmethod + def _cleanup_client(client: httpx.AsyncClient) -> None: + """Cleans up the httpx client.""" + if client.is_closed: + return + try: + loop = asyncio.get_running_loop() + loop.create_task(client.aclose()) + except RuntimeError: + try: + # This fails if asyncio.run is already called in main and is being closed.
+ asyncio.run(client.aclose()) + except RuntimeError: + pass + + def close(self) -> None: + if '_client' not in self.__dict__: + return + self._cleanup_client(self._client) + + async def aclose(self) -> None: + if '_client' not in self.__dict__: + return + if self._client.is_closed: + return + await self._client.aclose() + + def _get_retry_kwargs(self) -> dict[str, Any]: + """Returns the retry kwargs for tenacity.""" + if not self.retry_options: + return {'stop': tenacity.stop_after_attempt(1), 'reraise': True} + + default_attempts = 5 + default_initial_delay = 1.0 + default_max_delay = 60.0 + default_exp_base = 2 + default_jitter = 1 + default_status_codes = (408, 429, 500, 502, 503, 504) + + opts = self.retry_options + stop = tenacity.stop_after_attempt( + opts.attempts if opts.attempts is not None else default_attempts + ) + + retriable_codes = ( + opts.http_status_codes + if opts.http_status_codes is not None + else default_status_codes + ) + + retry_network = tenacity.retry_if_exception_type(httpx.NetworkError) + + def is_retriable(e: Exception) -> bool: + if isinstance(e, httpx.HTTPStatusError): + return e.response.status_code in retriable_codes + return False + + retry_status = tenacity.retry_if_exception(is_retriable) + + wait = tenacity.wait_exponential_jitter( + initial=( + opts.initial_delay + if opts.initial_delay is not None + else default_initial_delay + ), + max=( + opts.max_delay if opts.max_delay is not None else default_max_delay + ), + exp_base=( + opts.exp_base if opts.exp_base is not None else default_exp_base + ), + jitter=opts.jitter if opts.jitter is not None else default_jitter, + ) + + return { + 'stop': stop, + 'retry': tenacity.retry_any(retry_network, retry_status), + 'reraise': True, + 'wait': wait, + } + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool + ) -> AsyncGenerator[LlmResponse, None]: + """Generates content using the OpenAI-compatible HTTP API.""" + payload = 
self._construct_payload(llm_request, stream) + headers = self._headers.copy() + headers['Content-Type'] = 'application/json' + + url = self._base_url + if not url: + raise ValueError('Base URL is not set.') + + if not url.endswith('/chat/completions'): + url = f"{url.rstrip('/')}/chat/completions" + + if stream: + raise NotImplementedError('Streaming is not supported yet.') + else: + response = await self._httpx_post_with_retry(url, payload, headers) + data = response.json() + yield self._parse_response(data) + + async def _httpx_post_with_retry( + self, url: str, payload: dict[str, Any], headers: dict[str, str] + ) -> httpx.Response: + """Sends a POST request and handles retries.""" + retry_kwargs = self._get_retry_kwargs() + async for attempt in tenacity.AsyncRetrying(**retry_kwargs): + with attempt: + response = await self._client.post(url, json=payload, headers=headers) + response.raise_for_status() + return response + + async def _handle_streaming_response( + self, response: httpx.Response + ) -> AsyncGenerator[LlmResponse, None]: + """Handles streaming response from OpenAI-compatible API.""" + raise NotImplementedError('Streaming is not supported yet.') + + def _construct_payload( + self, llm_request: LlmRequest, stream: bool + ) -> dict[str, Any]: + """Constructs the payload from the LlmRequest.""" + messages = [] + if llm_request.config and llm_request.config.system_instruction: + content = self._serialize_system_instruction( + llm_request.config.system_instruction + ) + if content: + messages.append({ + 'role': 'system', + 'content': content, + }) + + for content in llm_request.contents: + messages += self._content_to_messages(content) + + payload = { + 'model': _get_model_id(llm_request.model), + 'messages': messages, + 'stream': stream, + } + + if llm_request.config: + self._map_config_parameters(llm_request.config, payload) + self._map_tools(llm_request.config, payload) + + return payload + + def _map_config_parameters( + self, config: 
types.GenerateContentConfig, payload: dict[str, Any] + ) -> None: + """Maps configuration parameters to the payload.""" + if config.temperature is not None: + payload['temperature'] = config.temperature + if config.top_p is not None: + payload['top_p'] = config.top_p + if config.max_output_tokens is not None: + payload['max_tokens'] = config.max_output_tokens + if config.stop_sequences: + payload['stop'] = config.stop_sequences + if config.frequency_penalty is not None: + payload['frequency_penalty'] = config.frequency_penalty + if config.presence_penalty is not None: + payload['presence_penalty'] = config.presence_penalty + if config.seed is not None: + payload['seed'] = config.seed + if config.candidate_count is not None: + payload['n'] = config.candidate_count + if config.response_logprobs: + payload['logprobs'] = True + if config.logprobs is not None: + payload['top_logprobs'] = config.logprobs + + if config.response_json_schema: + payload['response_format'] = { + 'type': 'json_schema', + 'json_schema': config.response_json_schema, + } + elif config.response_mime_type == 'application/json': + payload['response_format'] = {'type': 'json_object'} + + def _map_tools( + self, config: types.GenerateContentConfig, payload: dict[str, Any] + ) -> None: + """Maps tools and tool configuration to the payload.""" + if config.tools: + tools = [] + for tool in config.tools: + if tool.function_declarations: + for func in tool.function_declarations: + tools.append(self._function_declaration_to_tool(func)) + if tools: + payload['tools'] = tools + if config.tool_config and config.tool_config.function_calling_config: + mode = config.tool_config.function_calling_config.mode + if mode == types.FunctionCallingConfigMode.ANY: + payload['tool_choice'] = 'required' + elif mode == types.FunctionCallingConfigMode.NONE: + payload['tool_choice'] = 'none' + elif mode == types.FunctionCallingConfigMode.AUTO: + payload['tool_choice'] = 'auto' + + def _content_to_messages( + self, content: 
types.Content + ) -> list[dict[str, Any]]: + """Converts a Content object to /chat/completions messages.""" + role = content.role + if role == 'model': + role = 'assistant' + + tool_calls = [] + content_parts = [] + + function_responses = [] + + for part in content.parts or []: + self._process_content_part(content, part, tool_calls, content_parts) + if part.function_response: + function_responses.append({ + 'role': 'tool', + 'tool_call_id': part.function_response.id, + 'content': json.dumps(part.function_response.response), + }) + if function_responses: + return function_responses + + message = {'role': role} + if tool_calls: + message['tool_calls'] = tool_calls + if not content_parts: + message['content'] = None + + if content_parts: + if len(content_parts) == 1 and content_parts[0]['type'] == 'text': + message['content'] = content_parts[0]['text'] + else: + message['content'] = content_parts + return [message] + + def _process_content_part( + self, + content: types.Content, + part: types.Part, + tool_calls: list[dict[str, Any]], + content_parts: list[dict[str, Any]], + ) -> None: + """Processes a single Part and updates tool_calls or content_parts.""" + if content.role != 'user' and ( + part.inline_data + or ( + part.file_data + and part.file_data.mime_type + and part.file_data.mime_type.startswith('image') + ) + ): + logger.warning('Image data is not supported for assistant turns.') + return + + if part.function_call: + tool_call = { + 'id': part.function_call.id or 'call_' + part.function_call.name, + 'type': 'function', + 'function': { + 'name': part.function_call.name, + 'arguments': ( + json.dumps(part.function_call.args) + if part.function_call.args + else '{}' + ), + }, + } + if part.thought_signature: + sig = part.thought_signature + if isinstance(sig, bytes): + sig = base64.b64encode(sig).decode('utf-8') + tool_call['extra_content'] = { + 'google': { + 'thought_signature': sig, + }, + } + tool_calls.append(tool_call) + elif part.function_response: + # 
Handled in the loop to return immediately + pass + elif part.text: + content_parts.append({'type': 'text', 'text': part.text}) + elif part.inline_data: + mime_type = part.inline_data.mime_type + data = base64.b64encode(part.inline_data.data).decode('utf-8') + url = f'data:{mime_type};base64,{data}' + content_parts.append({'type': 'image_url', 'image_url': {'url': url}}) + elif part.file_data: + if part.file_data.file_uri: + content_parts.append({ + 'type': 'image_url', + 'image_url': {'url': part.file_data.file_uri}, + }) + elif part.executable_code: + logger.warning( + 'Executable code is not supported in the standard Chat Completions' + ' API.' + ) + elif part.code_execution_result: + logger.warning( + 'Code execution result is not supported in the standard Chat' + ' Completions API.' + ) + + def _function_declaration_to_tool( + self, func: types.FunctionDeclaration + ) -> dict[str, Any]: + """Converts a FunctionDeclaration to an OpenAI tool dictionary.""" + parameters = {} + if func.parameters_json_schema: + parameters = func.parameters_json_schema + elif func.parameters: + parameters = func.parameters.model_dump(exclude_none=True) + + return { + 'type': 'function', + 'function': { + 'name': func.name, + 'description': func.description, + 'parameters': parameters, + }, + } + + def _serialize_system_instruction( + self, system_instruction: Optional[types.ContentUnion] + ) -> str | None: + """Serializes system instruction to a string from ContentUnion type.""" + if not system_instruction: + return None + if isinstance(system_instruction, str): + return system_instruction + if isinstance(system_instruction, types.Part): + return system_instruction.text + if isinstance(system_instruction, types.Content): + return ''.join( + part.text for part in system_instruction.parts if part.text + ) + if isinstance(system_instruction, dict): + part = types.Part(**system_instruction) + return part.text + if isinstance(system_instruction, collections.abc.Iterable): + parts = [] + 
for item in system_instruction: + if isinstance(item, str): + parts.append(types.Part(text=item)) + elif isinstance(item, types.Part): + parts.append(item) + elif isinstance(item, dict): + parts.append(types.Part(**item)) + return ''.join(part.text for part in parts if part.text) + return None + + def _parse_logprobs( + self, logprobs_data: dict[str, Any] | None + ) -> types.LogprobsResult | None: + """Parses OpenAI logprobs data into LogprobsResult.""" + if not logprobs_data or 'content' not in logprobs_data: + return None + + chosen_candidates = [] + top_candidates = [] + + for item in logprobs_data['content']: + chosen_candidates.append( + types.LogprobsResultCandidate( + token=item.get('token'), + log_probability=item.get('logprob'), + # OpenAI text format usually doesn't expose ID easily here + token_id=None, + ) + ) + + if 'top_logprobs' in item: + current_top_candidates = [] + for top_item in item['top_logprobs']: + current_top_candidates.append( + types.LogprobsResultCandidate( + token=top_item.get('token'), + log_probability=top_item.get('logprob'), + token_id=None, + ) + ) + top_candidates.append( + types.LogprobsResultTopCandidates(candidates=current_top_candidates) + ) + + return types.LogprobsResult( + chosen_candidates=chosen_candidates, top_candidates=top_candidates + ) + + def _parse_response(self, response: dict[str, Any]) -> LlmResponse: + """Parses an OpenAI response dictionary into an LlmResponse.""" + choices = response.get('choices', []) + if not choices: + return LlmResponse() + + choice = choices[0] + message = choice.get('message', {}) + role = message.get('role', 'model') + if role == 'assistant': + role = 'model' + + parts = [] + content_str = message.get('content') + if content_str: + parts.append(types.Part.from_text(text=content_str)) + + tool_calls = message.get('tool_calls') + if tool_calls: + for tool_call in tool_calls: + call_type = tool_call.get('type', 'unknown') + # TODO: Add support for 'custom' type. 
+ if call_type != 'function': + raise ValueError( + f'Unsupported tool_call type: {call_type} in call {tool_call}' + ) + func = tool_call.get('function', {}) + part = self._parse_function_call(func) + parts.append(part) + + function_call = message.get('function_call') + if function_call: + part = self._parse_function_call(function_call) + parts.append(part) + + usage = response.get('usage', {}) + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get('prompt_tokens', 0), + candidates_token_count=usage.get('completion_tokens', 0), + total_token_count=usage.get('total_tokens', 0), + ) + + logprobs_result = self._parse_logprobs(choice.get('logprobs')) + + custom_metadata = { + 'id': response.get('id'), + 'created': response.get('created'), + 'model': response.get('model'), + 'system_fingerprint': response.get('system_fingerprint'), + 'service_tier': response.get('service_tier'), + } + custom_metadata = { + k: v for k, v in custom_metadata.items() if v is not None + } + + return LlmResponse( + content=types.Content(role=role, parts=parts), + usage_metadata=usage_metadata, + finish_reason=self._map_finish_reason(choice.get('finish_reason')), + logprobs_result=logprobs_result, + model_version=response.get('model'), + custom_metadata=custom_metadata, + ) + + def _map_finish_reason(self, reason: str | None) -> types.FinishReason: + if reason == 'stop': + return types.FinishReason.STOP + if reason == 'length': + return types.FinishReason.MAX_TOKENS + if reason == 'tool_calls': + return types.FinishReason.STOP + if reason == 'content_filter': + return types.FinishReason.SAFETY + return types.FinishReason.FINISH_REASON_UNSPECIFIED + + def _parse_function_call(self, func: dict[str, Any]) -> types.Part: + """Parses a function call dictionary into a Part.""" + name = func.get('name') + args_str = func.get('arguments', '{}') + try: + args = json.loads(args_str) + except json.JSONDecodeError: + args = {} + tool_part = 
types.Part.from_function_call(name=name, args=args) + if tool_part.function_call: + tool_part.function_call.id = func.get('id', None) + # Add support for gemini's thought_signature. + thought_signature = ( + func.get('extra_content', {}) + .get('google', {}) + .get('thought_signature', '') + ) + if thought_signature: + if isinstance(thought_signature, str): + thought_signature = base64.b64decode(thought_signature) + tool_part.thought_signature = thought_signature + return tool_part diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index 67894ea8..c57bc9fc 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -19,6 +19,7 @@ from unittest import mock from unittest.mock import AsyncMock from google.adk.models.apigee_llm import ApigeeLlm +from google.adk.models.apigee_llm import CompletionsHTTPClient from google.adk.models.llm_request import LlmRequest from google.genai import types from google.genai.types import Content @@ -441,7 +442,6 @@ async def test_model_string_parsing_and_client_initialization( @pytest.mark.parametrize( 'invalid_model_string', [ - 'apigee/openai/v1/gpt', 'apigee/', # Missing model_id 'apigee', # Invalid format 'gemini-pro', # Invalid format @@ -455,3 +455,175 @@ async def test_invalid_model_strings_raise_value_error(invalid_model_string): ValueError, match=f'Invalid model string: {invalid_model_string}' ): ApigeeLlm(model=invalid_model_string, proxy_url=PROXY_URL) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'model', + [ + 'apigee/openai/gpt-4o', + 'apigee/openai/v1/gpt-4o', + 'apigee/openai/v1/gpt-3.5-turbo', + ], +) +async def test_validate_model_for_chat_completion_providers(model): + """Tests that new providers like OpenAI are accepted.""" + # Should not raise ValueError + ApigeeLlm(model=model, proxy_url=PROXY_URL) + + +@pytest.mark.parametrize( + ('model', 'api_type', 'expected_api_type'), + [ + # Default case (input defaults to 
UNKNOWN) + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/v1/gpt-3.5-turbo', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/v1/gemini-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/vertex_ai/gemini-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/vertex_ai/v1beta/gemini-1.5-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + # Override by setting the ApiType + ( + 'apigee/gemini/pro', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/pro', + ApigeeLlm.ApiType.GENAI, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.GENAI, + ApigeeLlm.ApiType.GENAI, + ), + # Override by setting the ApiType as a string + ( + 'apigee/gemini/pro', + 'chat_completions', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/pro', + 'genai', + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/openai/gpt-4o', + 'chat_completions', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/gpt-4o', + 'genai', + ApigeeLlm.ApiType.GENAI, + ), + ], +) +def test_api_type_resolution(model, api_type, expected_api_type): + """Tests that api_type is resolved correctly.""" + llm = ApigeeLlm( + model=model, + proxy_url=PROXY_URL, + api_type=api_type, + ) + assert llm._api_type == expected_api_type + + +@pytest.mark.parametrize( + ('input_value', 'expected_type'), + [ + ('chat_completions', ApigeeLlm.ApiType.CHAT_COMPLETIONS), + ('genai', ApigeeLlm.ApiType.GENAI), + ('unknown', ApigeeLlm.ApiType.UNKNOWN), + ('', ApigeeLlm.ApiType.UNKNOWN), + (None, ApigeeLlm.ApiType.UNKNOWN), + ], +) +def test_apitype_creation(input_value, expected_type): + """Tests the creation of ApiType enum members.""" + assert 
ApigeeLlm.ApiType(input_value) == expected_type + + +def test_apitype_creation_invalid(): + """Tests that invalid ApiType raises ValueError.""" + with pytest.raises(ValueError): + ApigeeLlm.ApiType('invalid') + + +def test_invalid_api_type_raises_error(): + """Tests that invalid string for api_type raises ValueError.""" + with pytest.raises(ValueError): + ApigeeLlm( + model='apigee/gemini-pro', + proxy_url=PROXY_URL, + api_type='invalid_type', + ) + + +@pytest.mark.asyncio +async def test_generate_content_async_dispatch_to_completions_client( + llm_request, +): + """Tests that generate_content_async uses CompletionsHTTPClient for OpenAI models.""" + llm_request.model = 'apigee/openai/gpt-4o' + with ( + mock.patch.object( + CompletionsHTTPClient, + 'generate_content_async', + ) as mock_completions_generate_content, + mock.patch('google.genai.Client') as mock_genai_client, + ): + apigee_llm = ApigeeLlm(model='apigee/openai/gpt-4o', proxy_url=PROXY_URL) + _ = [ + r + async for r in apigee_llm.generate_content_async( + llm_request, stream=False + ) + ] + mock_completions_generate_content.assert_called_once() + mock_genai_client.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'model', + [ + 'apigee/openai/gpt-4o', + 'apigee/openai/v1/gpt-3.5-turbo', + ], +) +async def test_api_key_injection_openai(model): + """Tests that api_key is injected for OpenAI models.""" + apigee_llm = ApigeeLlm( + model=model, + proxy_url=PROXY_URL, + custom_headers={'Authorization': 'Bearer sk-test-key'}, + ) + client = apigee_llm._completions_http_client + assert client._headers['Authorization'] == 'Bearer sk-test-key' diff --git a/tests/unittests/models/test_completions_http_client.py b/tests/unittests/models/test_completions_http_client.py new file mode 100644 index 00000000..f16376d7 --- /dev/null +++ b/tests/unittests/models/test_completions_http_client.py @@ -0,0 +1,440 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock +from unittest.mock import AsyncMock + +from google.adk.models.apigee_llm import CompletionsHTTPClient +from google.adk.models.llm_request import LlmRequest +from google.genai import types +import httpx +import pytest + + +@pytest.fixture +def client(): + return CompletionsHTTPClient(base_url='https://example.com') + + +@pytest.fixture(name='llm_request') +def fixture_llm_request(): + return LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='Hello')]) + ], + ) + + +@pytest.mark.asyncio +async def test_construct_payload_basic_payload(client, llm_request): + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + call_args = mock_post.call_args + url = call_args[0][0] + kwargs = call_args[1] + + assert url == 'https://example.com/chat/completions' + payload = kwargs['json'] + assert payload['model'] == 'open_llama' + assert payload['stream'] is False + assert len(payload['messages']) == 1 + assert payload['messages'][0]['role'] == 'user' + assert payload['messages'][0]['content'] == 'Hello' + + +@pytest.mark.asyncio +async def 
test_construct_payload_with_config(client, llm_request): + llm_request.config = types.GenerateContentConfig( + temperature=0.7, + top_p=0.9, + max_output_tokens=100, + stop_sequences=['STOP'], + frequency_penalty=0.5, + presence_penalty=0.5, + seed=42, + candidate_count=2, + response_mime_type='application/json', + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + + assert payload['temperature'] == 0.7 + assert payload['top_p'] == 0.9 + assert payload['max_tokens'] == 100 + assert payload['stop'] == ['STOP'] + assert payload['frequency_penalty'] == 0.5 + assert payload['presence_penalty'] == 0.5 + assert payload['seed'] == 42 + assert payload['n'] == 2 + assert payload['response_format'] == {'type': 'json_object'} + + +@pytest.mark.asyncio +async def test_construct_payload_with_tools(client, llm_request): + tool = types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name='get_weather', + description='Get weather', + parameters=types.Schema( + type=types.Type.OBJECT, + properties={'location': types.Schema(type=types.Type.STRING)}, + ), + ) + ] + ) + llm_request.config = types.GenerateContentConfig(tools=[tool]) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = 
mock_post.call_args[1]['json'] + assert 'tools' in payload + assert payload['tools'][0]['function']['name'] == 'get_weather' + + +@pytest.mark.asyncio +async def test_construct_payload_system_instruction(client, llm_request): + llm_request.config = types.GenerateContentConfig( + system_instruction='You are a helpful assistant.' + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + payload = mock_post.call_args[1]['json'] + assert payload['messages'][0]['role'] == 'system' + assert payload['messages'][0]['content'] == 'You are a helpful assistant.' + # Ensure user message follows system + assert payload['messages'][1]['role'] == 'user' + + +@pytest.mark.asyncio +async def test_construct_payload_multimodal_content(client): + # Mock inline_data for image + image_data = b'fake_image_bytes' + llm_request = LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_text(text='What is this?'), + types.Part.from_bytes( + data=image_data, mime_type='image/jpeg' + ), + ], + ) + ], + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [ + {'message': {'role': 'assistant', 'content': 'It is an image'}} + ] + } + + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert len(payload['messages']) == 1 + message = payload['messages'][0] + assert message['role'] == 'user' + assert 
isinstance(message['content'], list) + assert len(message['content']) == 2 + assert message['content'][0] == {'type': 'text', 'text': 'What is this?'} + assert message['content'][1]['type'] == 'image_url' + # Base64 encoding of b'fake_image_bytes' is 'ZmFrZV9pbWFnZV9ieXRlcw==' + assert message['content'][1]['image_url']['url'] == ( + 'data:image/jpeg;base64,ZmFrZV9pbWFnZV9ieXRlcw==' + ) + + +@pytest.mark.asyncio +async def test_construct_payload_image_file_uri(client): + llm_request = LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_uri( + file_uri='https://example.com/image.jpg', + mime_type='image/jpeg', + ) + ], + ) + ], + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [ + {'message': {'role': 'assistant', 'content': 'It is an image'}} + ] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert len(payload['messages']) == 1 + message = payload['messages'][0] + assert message['role'] == 'user' + assert isinstance(message['content'], list) + assert message['content'][0] == { + 'type': 'image_url', + 'image_url': {'url': 'https://example.com/image.jpg'}, + } + + +@pytest.mark.asyncio +async def test_generate_content_async_function_call_response( + client, llm_request +): + # Mock response with tool call + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'role': 'assistant', + 'content': None, + 'tool_calls': [{ + 'id': 'call_123', + 'type': 'function', + 'function': { + 'name': 'get_weather', + 'arguments': '{"location": "London"}', + }, + }], + } + }] + } + mock_response.status_code = 200 + + with 
mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response): + responses = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + part = responses[0].content.parts[0] + assert part.function_call + assert part.function_call.name == 'get_weather' + assert part.function_call.args == {'location': 'London'} + assert part.function_call.id == 'call_123' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('response_json_schema', 'response_mime_type', 'expected_response_format'), + [ + # Case 1: Only response_json_schema is provided + ( + {'type': 'object', 'properties': {'name': {'type': 'string'}}}, + None, + { + 'type': 'json_schema', + 'json_schema': { + 'type': 'object', + 'properties': {'name': {'type': 'string'}}, + }, + }, + ), + # Case 2: Both provided, schema takes precedence + ( + {'type': 'object', 'properties': {'name': {'type': 'string'}}}, + 'application/json', + { + 'type': 'json_schema', + 'json_schema': { + 'type': 'object', + 'properties': {'name': {'type': 'string'}}, + }, + }, + ), + # Case 3: Only response_mime_type is provided + ( + None, + 'application/json', + {'type': 'json_object'}, + ), + ], +) +async def test_construct_payload_response_format( + client, + llm_request, + response_json_schema, + response_mime_type, + expected_response_format, +): + llm_request.config = types.GenerateContentConfig( + response_json_schema=response_json_schema, + response_mime_type=response_mime_type, + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': '{}'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert 
payload['response_format'] == expected_response_format
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_invalid_tool_call_type_raises_error(
+    client, llm_request
+):
+  # Mock response with invalid tool call type
+  mock_response = AsyncMock(spec=httpx.Response)
+  mock_response.json.return_value = {
+      'choices': [{
+          'message': {
+              'role': 'assistant',
+              'content': None,
+              'tool_calls': [{
+                  'id': 'call_123',
+                  # Invalid type
+                  'type': 'custom',
+                  'custom': {
+                      'name': 'read_string',
+                      'input': 'Hi! This is a custom tool call!',
+                  },
+              }],
+          }
+      }]
+  }
+  mock_response.status_code = 200
+
+  with mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response):
+    with pytest.raises(ValueError, match='Unsupported tool_call type: custom'):
+      _ = [
+          r
+          async for r in client.generate_content_async(
+              llm_request, stream=False
+          )
+      ]
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_deprecated_function_call_response(
+    client, llm_request
+):
+  # Mock response with deprecated function call
+  mock_response = AsyncMock(spec=httpx.Response)
+  mock_response.json.return_value = {
+      'choices': [{
+          'message': {
+              'role': 'assistant',
+              'content': None,
+              'function_call': {
+                  'name': 'get_weather',
+                  'arguments': '{"location": "London"}',
+              },
+          }
+      }]
+  }
+  mock_response.status_code = 200
+
+  with mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response):
+    responses = [
+        r
+        async for r in client.generate_content_async(llm_request, stream=False)
+    ]
+
+  assert len(responses) == 1
+  part = responses[0].content.parts[0]
+  assert part.function_call
+  assert part.function_call.name == 'get_weather'
+  assert part.function_call.args == {'location': 'London'}
+  assert part.function_call.id is None

From 77df6d8db77d1dbaf730723bb5fa5f83fa36bcb8 Mon Sep 17 00:00:00 2001
From: Liang Wu
Date: Fri, 20 Feb 2026 14:22:25 -0800
Subject: [PATCH 010/102] ci: only keep `--extra test` in GitHub unit test
 workflow

Co-authored-by: Liang Wu
PiperOrigin-RevId: 873072872 --- .github/workflows/python-unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 8689f1a1..866ba8b3 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -43,7 +43,7 @@ jobs: run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval --extra a2a + uv sync --extra test - name: Run unit tests with pytest run: | From abaa92944c4cd43d206e2986d405d4ee07d45afe Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Fri, 20 Feb 2026 14:24:36 -0800 Subject: [PATCH 011/102] feat: Agent Registry in ADK Client library for the Agent Registry API that allows users to discover, look up, and connect to agents and MCP servers cataloged in the registry. Co-authored-by: Kathy Wu PiperOrigin-RevId: 873073675 --- .../samples/agent_registry_agent/README.md | 49 ++++ .../samples/agent_registry_agent/__init__.py | 15 ++ .../samples/agent_registry_agent/agent.py | 63 +++++ .../integrations/agent_registry/__init__.py | 13 + .../agent_registry/test_agent_registry.py | 236 ++++++++++++++++++ 5 files changed, 376 insertions(+) create mode 100644 contributing/samples/agent_registry_agent/README.md create mode 100644 contributing/samples/agent_registry_agent/__init__.py create mode 100644 contributing/samples/agent_registry_agent/agent.py create mode 100644 tests/unittests/integrations/agent_registry/__init__.py create mode 100644 tests/unittests/integrations/agent_registry/test_agent_registry.py diff --git a/contributing/samples/agent_registry_agent/README.md b/contributing/samples/agent_registry_agent/README.md new file mode 100644 index 00000000..b9370b64 --- /dev/null +++ b/contributing/samples/agent_registry_agent/README.md @@ -0,0 +1,49 @@ +# Agent Registry Sample + +This sample demonstrates how to use the `AgentRegistry` client to discover agents and MCP servers registered in Google 
Cloud. + +## Setup + +1. Ensure you have Google Cloud credentials configured (e.g., `gcloud auth application-default login`). +2. Set the following environment variables: + +```bash +export GOOGLE_CLOUD_PROJECT=your-project-id +export GOOGLE_CLOUD_LOCATION=global # or your specific region +``` + +3. Obtain the full resource names for the agents and MCP servers you want to use. You can do this by running the sample script once to list them: + + ```bash + python3 agent.py + ``` + + Alternatively, use `gcloud` to list them: + + ```bash + # For agents + gcloud alpha agent-registry agents list --project=$GOOGLE_CLOUD_PROJECT --location=$GOOGLE_CLOUD_LOCATION + + # For MCP servers + gcloud alpha agent-registry mcp-servers list --project=$GOOGLE_CLOUD_PROJECT --location=$GOOGLE_CLOUD_LOCATION + ``` + +4. Replace `AGENT_NAME` and `MCP_SERVER_NAME` in `agent.py` with the last part of the resource names (e.g., if the name is `projects/.../agents/my-agent`, use `my-agent`). + +## Running the Sample + +Run the sample script to list available agents and MCP servers: + +```bash +python3 agent.py +``` + +## How it Works + +The sample uses `AgentRegistry` to: +- List registered agents using `list_agents()`. +- List registered MCP servers using `list_mcp_servers()`. + +It also shows (in comments) how to: +- Get a `RemoteA2aAgent` instance using `get_remote_a2a_agent(name)`. +- Get an `McpToolset` instance using `get_mcp_toolset(name)`. diff --git a/contributing/samples/agent_registry_agent/__init__.py b/contributing/samples/agent_registry_agent/__init__.py new file mode 100644 index 00000000..4015e47d --- /dev/null +++ b/contributing/samples/agent_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/agent_registry_agent/agent.py b/contributing/samples/agent_registry_agent/agent.py new file mode 100644 index 00000000..38036dea --- /dev/null +++ b/contributing/samples/agent_registry_agent/agent.py @@ -0,0 +1,63 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sample agent demonstrating Agent Registry discovery.""" + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.integrations.agent_registry import AgentRegistry + +# Project and location can be set via environment variables: +# GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION +project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") +location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global") + +# Initialize Agent Registry client +registry = AgentRegistry(project_id=project_id, location=location) + +print(f"Listing agents in {project_id}/{location}...") +agents = registry.list_agents() +for agent in agents.get("agents", []): + print(f"- Agent: {agent.get('displayName')} ({agent.get('name')})") + +print(f"\nListing MCP servers in {project_id}/{location}...") +mcp_servers = registry.list_mcp_servers() +for server in mcp_servers.get("mcpServers", []): + print(f"- MCP Server: {server.get('displayName')} ({server.get('name')})") + +# Example of using a specific agent or MCP server from the registry: +# (Note: These names should be full resource names as returned by list methods) + +# 1. Using a Remote A2A Agent as a sub-agent +# TODO: Replace AGENT_NAME with your agent name +remote_agent = registry.get_remote_a2a_agent( + f"projects/{project_id}/locations/{location}/agents/AGENT_NAME" +) + +# 2. Using an MCP Server in a toolset +# TODO: Replace MCP_SERVER_NAME with your MCP server name +mcp_toolset = registry.get_mcp_toolset( + f"projects/{project_id}/locations/{location}/mcpServers/MCP_SERVER_NAME" +) + +root_agent = LlmAgent( + model="gemini-2.5-flash", + name="discovery_agent", + instruction=( + "You have access to tools and sub-agents discovered via Registry." 
+ ), + tools=[mcp_toolset], + sub_agents=[remote_agent], +) diff --git a/tests/unittests/integrations/agent_registry/__init__.py b/tests/unittests/integrations/agent_registry/__init__.py new file mode 100644 index 00000000..58d482ea --- /dev/null +++ b/tests/unittests/integrations/agent_registry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py new file mode 100644 index 00000000..f54cdb67 --- /dev/null +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -0,0 +1,236 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from google.adk.integrations.agent_registry import AgentRegistry +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import httpx +import pytest + + +class TestAgentRegistry: + + @pytest.fixture + def registry(self): + with patch("google.auth.default", return_value=(MagicMock(), "project-id")): + return AgentRegistry(project_id="test-project", location="global") + + def test_init_raises_value_error_if_params_missing(self): + with pytest.raises( + ValueError, match="project_id and location must be provided" + ): + AgentRegistry(project_id=None, location=None) + + def test_get_connection_uri_mcp_interfaces_top_level(self, registry): + resource_details = { + "interfaces": [ + {"url": "https://mcp-v1main.com", "protocolBinding": "JSONRPC"} + ] + } + uri = registry._get_connection_uri( + resource_details, protocol_binding="JSONRPC" + ) + assert uri == "https://mcp-v1main.com" + + def test_get_connection_uri_agent_nested_protocols(self, registry): + resource_details = { + "protocols": [{ + "type": "A2A_AGENT", + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": "JSONRPC", + }], + }] + } + uri = registry._get_connection_uri( + resource_details, protocol_type="A2A_AGENT" + ) + assert uri == "https://my-agent.com" + + def test_get_connection_uri_filtering(self, registry): + resource_details = { + "protocols": [ + { + "type": "CUSTOM", + "interfaces": [{"url": "https://custom.com"}], + }, + { + "type": "A2A_AGENT", + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": "HTTP_JSON", + }], + }, + ] + } + # Filter by type + uri = registry._get_connection_uri( + resource_details, protocol_type="A2A_AGENT" + ) + assert uri == "https://my-agent.com" + + # Filter by binding + uri = registry._get_connection_uri( + resource_details, protocol_binding="HTTP_JSON" + ) + assert uri == 
"https://my-agent.com" + + # No match + uri = registry._get_connection_uri( + resource_details, protocol_type="A2A_AGENT", protocol_binding="JSONRPC" + ) + assert uri is None + + def test_get_connection_uri_returns_none_if_no_interfaces(self, registry): + resource_details = {} + uri = registry._get_connection_uri(resource_details) + assert uri is None + + def test_get_connection_uri_returns_none_if_no_url_in_interfaces( + self, registry + ): + resource_details = {"interfaces": [{"protocolBinding": "HTTP"}]} + uri = registry._get_connection_uri(resource_details) + assert uri is None + + @patch("httpx.Client") + def test_list_agents(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = {"agents": []} + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + # Mock auth refresh + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + agents = registry.list_agents() + assert agents == {"agents": []} + + @patch("httpx.Client") + def test_get_mcp_server(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = {"name": "test-mcp"} + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + server = registry.get_mcp_server("test-mcp") + assert server == {"name": "test-mcp"} + + @patch("httpx.Client") + def test_get_mcp_toolset(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = { + "displayName": "TestPrefix", + "interfaces": [ + {"url": "https://mcp.com", "protocolBinding": "JSONRPC"} + ], + } + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + 
registry._credentials.refresh = MagicMock() + + toolset = registry.get_mcp_toolset("test-mcp") + assert isinstance(toolset, McpToolset) + assert toolset.tool_name_prefix == "TestPrefix" + + @patch("httpx.Client") + def test_get_remote_a2a_agent(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = { + "displayName": "TestAgent", + "description": "Test Desc", + "agentSpec": { + "a2aAgentCardUrl": "https://my-agent.com/agent-card.json" + }, + } + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + agent = registry.get_remote_a2a_agent("test-agent") + assert isinstance(agent, RemoteA2aAgent) + assert agent.name == "TestAgent" + assert agent.description == "Test Desc" + assert agent._agent_card_source == "https://my-agent.com/agent-card.json" + + def test_get_auth_headers(self, registry): + registry._credentials.token = "fake-token" + registry._credentials.refresh = MagicMock() + registry._credentials.quota_project_id = "quota-project" + + headers = registry._get_auth_headers() + assert headers["Authorization"] == "Bearer fake-token" + assert headers["x-goog-user-project"] == "quota-project" + + @patch("httpx.Client") + def test_make_request_raises_http_status_error(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + error = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=mock_response + ) + mock_httpx.return_value.__enter__.return_value.get.side_effect = error + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises( + RuntimeError, match="API request failed with status 404" + ): + registry._make_request("test-path") + + @patch("httpx.Client") + def test_make_request_raises_request_error(self, mock_httpx, 
registry): + error = httpx.RequestError("Connection failed", request=MagicMock()) + mock_httpx.return_value.__enter__.return_value.get.side_effect = error + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises( + RuntimeError, match=r"API request failed \(network error\)" + ): + registry._make_request("test-path") + + @patch("httpx.Client") + def test_make_request_raises_generic_exception(self, mock_httpx, registry): + mock_httpx.return_value.__enter__.return_value.get.side_effect = Exception( + "Generic error" + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises(RuntimeError, match="API request failed: Generic error"): + registry._make_request("test-path") From bbdf0ea2571e88c939de911206b80cd765619380 Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 14:41:33 -0800 Subject: [PATCH 012/102] chore: Update OpenTelemetry dependency upper bounds Co-authored-by: George Weale PiperOrigin-RevId: 873080640 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9bec96cb..5a65ec61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,13 @@ dependencies = [ "httpx>=0.27.0, <1.0.0", # HTTP client library "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation "mcp>=1.23.0, <2.0.0", # For MCP Toolset - "opentelemetry-api>=1.36.0, <1.40.0", # OpenTelemetry - keep below 1.40.0 to reduce risk of breaking changes around log-signal APIs. + "opentelemetry-api>=1.36.0, <1.39.0", # OpenTelemetry - keep below 1.39.0 due to current agent_engines exporter constraints.
"opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-trace>=1.9.0, <2.0.0", "opentelemetry-exporter-otlp-proto-http>=1.36.0", "opentelemetry-resourcedetector-gcp>=1.9.0a0, <2.0.0", - "opentelemetry-sdk>=1.36.0, <1.40.0", + "opentelemetry-sdk>=1.36.0, <1.39.0", "pyarrow>=14.0.0", "pydantic>=2.7.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service From 485fcb84e3ca351f83416c012edcafcec479c1db Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 15:19:11 -0800 Subject: [PATCH 013/102] feat: Add intra-invocation compaction and token compaction pre-request Compact session events before LLM calls when token threshold is exceeded Co-authored-by: George Weale PiperOrigin-RevId: 873095899 --- src/google/adk/agents/invocation_context.py | 7 + src/google/adk/apps/compaction.py | 397 ++++++++++++------ src/google/adk/flows/llm_flows/compaction.py | 58 +++ src/google/adk/flows/llm_flows/single_flow.py | 4 + src/google/adk/runners.py | 8 +- tests/unittests/apps/test_compaction.py | 296 ++++++++++++- .../llm_flows/test_compaction_processor.py | 346 +++++++++++++++ 7 files changed, 983 insertions(+), 133 deletions(-) create mode 100644 src/google/adk/flows/llm_flows/compaction.py create mode 100644 tests/unittests/flows/llm_flows/test_compaction_processor.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7a23a6cc..4c75e1c4 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -24,6 +24,7 @@ from pydantic import ConfigDict from pydantic import Field from pydantic import PrivateAttr +from ..apps.app import EventsCompactionConfig from ..apps.app import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import 
BaseCredentialService @@ -200,6 +201,12 @@ class InvocationContext(BaseModel): resumability_config: Optional[ResumabilityConfig] = None """The resumability config that applies to all agents under this invocation.""" + events_compaction_config: Optional[EventsCompactionConfig] = None + """The compaction config for this invocation.""" + + token_compaction_checked: bool = False + """Whether token-threshold compaction ran during this invocation.""" + plugin_manager: PluginManager = Field(default_factory=PluginManager) """The manager for keeping track of plugins in this invocation.""" diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index 4af7b512..61941bff 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -16,25 +16,53 @@ from __future__ import annotations import logging +from google.genai import types + +from ..agents.base_agent import BaseAgent from ..events.event import Event from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session from .app import App +from .app import EventsCompactionConfig from .llm_event_summarizer import LlmEventSummarizer logger = logging.getLogger('google_adk.' 
+ __name__) -def _count_text_chars_in_event(event: Event) -> int: - """Returns the number of text characters in an event's content.""" +def _count_text_chars_in_content(content: types.Content | None) -> int: + """Returns the number of text characters in a content object.""" total_chars = 0 - if event.content and event.content.parts: - for part in event.content.parts: + if content and content.parts: + for part in content.parts: if part.text: total_chars += len(part.text) return total_chars +def _valid_compactions( + events: list[Event], +) -> list[tuple[int, float, float, Event]]: + """Returns compaction events with fully-defined compaction ranges.""" + compactions: list[tuple[int, float, float, Event]] = [] + for i, event in enumerate(events): + if not (event.actions and event.actions.compaction): + continue + compaction = event.actions.compaction + if ( + compaction.start_timestamp is None + or compaction.end_timestamp is None + or compaction.compacted_content is None + ): + continue + compactions.append(( + i, + compaction.start_timestamp, + compaction.end_timestamp, + event, + )) + return compactions + + def _is_compaction_subsumed( *, start_timestamp: float, @@ -60,67 +88,29 @@ def _is_compaction_subsumed( return False -def _estimate_prompt_token_count(events: list[Event]) -> int | None: +def _estimate_prompt_token_count( + *, + events: list[Event], + current_branch: str | None, + agent_name: str, +) -> int | None: """Returns an approximate prompt token count from session events. - This estimate is compaction-aware: it counts compaction summaries and only - counts raw events that would remain visible after applying compaction ranges. + This estimate mirrors the effective content-building path used by the + contents request processor. 
""" - compactions: list[tuple[int, float, float, Event]] = [] - for i, event in enumerate(events): - if not (event.actions and event.actions.compaction): - continue - compaction = event.actions.compaction - if ( - compaction.start_timestamp is None - or compaction.end_timestamp is None - or compaction.compacted_content is None - ): - continue - compactions.append(( - i, - compaction.start_timestamp, - compaction.end_timestamp, - Event( - timestamp=compaction.end_timestamp, - author='model', - content=compaction.compacted_content, - branch=event.branch, - invocation_id=event.invocation_id, - actions=event.actions, - ), - )) - - effective_compactions = [ - (i, start, end, summary_event) - for i, start, end, summary_event in compactions - if not _is_compaction_subsumed( - start_timestamp=start, - end_timestamp=end, - event_index=i, - compactions=compactions, - ) - ] - compaction_ranges = [ - (start, end) for _, start, end, _ in effective_compactions - ] - - def _is_timestamp_compacted(ts: float) -> bool: - for start_ts, end_ts in compaction_ranges: - if start_ts <= ts <= end_ts: - return True - return False + # Deferred import: contents depends on agents.invocation_context which + # imports from apps, so a top-level import would create a circular dependency. 
+ from ..flows.llm_flows import contents + effective_contents = contents._get_contents( + current_branch=current_branch, + events=events, + agent_name=agent_name, + ) total_chars = 0 - for _, _, _, summary_event in effective_compactions: - total_chars += _count_text_chars_in_event(summary_event) - - for event in events: - if event.actions and event.actions.compaction: - continue - if _is_timestamp_compacted(event.timestamp): - continue - total_chars += _count_text_chars_in_event(event) + for content in effective_contents: + total_chars += _count_text_chars_in_content(content) if total_chars <= 0: return None @@ -129,7 +119,12 @@ def _estimate_prompt_token_count(events: list[Event]) -> int | None: return total_chars // 4 -def _latest_prompt_token_count(events: list[Event]) -> int | None: +def _latest_prompt_token_count( + events: list[Event], + *, + current_branch: str | None = None, + agent_name: str = '', +) -> int | None: """Returns the most recently observed prompt token count, if available.""" for event in reversed(events): if ( @@ -137,23 +132,29 @@ def _latest_prompt_token_count(events: list[Event]) -> int | None: and event.usage_metadata.prompt_token_count is not None ): return event.usage_metadata.prompt_token_count - return _estimate_prompt_token_count(events) + return _estimate_prompt_token_count( + events=events, + current_branch=current_branch, + agent_name=agent_name, + ) def _latest_compaction_event(events: list[Event]) -> Event | None: - """Returns the compaction event with the greatest covered end timestamp.""" + """Returns the latest non-subsumed compaction event by stream order.""" + compactions = _valid_compactions(events) latest_event = None - latest_end = 0.0 - for event in events: - if ( - event.actions - and event.actions.compaction - and event.actions.compaction.end_timestamp is not None + latest_index = -1 + for event_index, start_ts, end_ts, event in compactions: + if _is_compaction_subsumed( + start_timestamp=start_ts, + 
end_timestamp=end_ts, + event_index=event_index, + compactions=compactions, ): - end_ts = event.actions.compaction.end_timestamp - if end_ts is not None and end_ts >= latest_end: - latest_end = end_ts - latest_event = event + continue + if event_index > latest_index: + latest_index = event_index + latest_event = event return latest_event @@ -167,55 +168,73 @@ def _latest_compaction_end_timestamp(events: list[Event]) -> float: return latest_event.actions.compaction.end_timestamp -async def _run_compaction_for_token_threshold( - app: App, session: Session, session_service: BaseSessionService -): - """Runs post-invocation compaction based on a token threshold. +def _has_token_threshold_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether token-threshold compaction is fully configured.""" + return bool( + config + and config.token_threshold is not None + and config.event_retention_size is not None + ) - If triggered, this compacts older raw events and keeps the last - `event_retention_size` raw events un-compacted. 
- """ - config = app.events_compaction_config - if not config: - return False - if config.token_threshold is None or config.event_retention_size is None: - return False - prompt_token_count = _latest_prompt_token_count(session.events) - if prompt_token_count is None or prompt_token_count < config.token_threshold: - return False +def _has_sliding_window_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether sliding-window compaction is fully configured.""" + return bool( + config + and config.compaction_interval is not None + and config.overlap_size is not None + ) - latest_compaction_event = _latest_compaction_event(session.events) - last_compacted_end_timestamp = 0.0 - if ( - latest_compaction_event - and latest_compaction_event.actions - and latest_compaction_event.actions.compaction - and latest_compaction_event.actions.compaction.end_timestamp is not None - ): - last_compacted_end_timestamp = ( - latest_compaction_event.actions.compaction.end_timestamp + +def _ensure_compaction_summarizer( + *, config: EventsCompactionConfig, agent: BaseAgent +) -> None: + """Ensures compaction config has a summarizer initialized.""" + if config.summarizer is not None: + return + + from ..agents.llm_agent import LlmAgent + + if not isinstance(agent, LlmAgent): + raise ValueError( + 'No LlmAgent model available for event compaction summarizer.' ) + config.summarizer = LlmEventSummarizer(llm=agent.canonical_model) + + +def _events_to_compact_for_token_threshold( + *, + events: list[Event], + event_retention_size: int, +) -> list[Event]: + """Collects token-threshold compaction candidates with rolling-summary seed. + + If a previous compaction exists, include its summary as the first event so + the next summary can supersede it. 
+ """ + latest_compaction_event = _latest_compaction_event(events) + last_compacted_end_timestamp = _latest_compaction_end_timestamp(events) + candidate_events = [ - e - for e in session.events - if not (e.actions and e.actions.compaction) - and e.timestamp > last_compacted_end_timestamp + event + for event in events + if not (event.actions and event.actions.compaction) + and event.timestamp > last_compacted_end_timestamp ] + if len(candidate_events) <= event_retention_size: + return [] - if len(candidate_events) <= config.event_retention_size: - return False - - if config.event_retention_size == 0: + if event_retention_size == 0: events_to_compact = candidate_events else: - events_to_compact = candidate_events[: -config.event_retention_size] + split_index = _safe_token_compaction_split_index( + candidate_events=candidate_events, + event_retention_size=event_retention_size, + ) + events_to_compact = candidate_events[:split_index] if not events_to_compact: - return False + return [] - # Rolling summary: if a previous compaction exists, seed the next summary with - # the previous compaction summary content so new compactions can subsume older - # ones while still keeping `event_retention_size` raw events visible. 
if ( latest_compaction_event and latest_compaction_event.actions @@ -231,10 +250,101 @@ async def _run_compaction_for_token_threshold( branch=latest_compaction_event.branch, invocation_id=Event.new_id(), ) - events_to_compact = [seed_event] + events_to_compact + return [seed_event] + events_to_compact - if not config.summarizer: - config.summarizer = LlmEventSummarizer(llm=app.root_agent.canonical_model) + return events_to_compact + + +def _event_function_call_ids(event: Event) -> set[str]: + """Returns function call ids found in an event.""" + function_call_ids: set[str] = set() + for function_call in event.get_function_calls(): + if function_call.id: + function_call_ids.add(function_call.id) + return function_call_ids + + +def _event_function_response_ids(event: Event) -> set[str]: + """Returns function response ids found in an event.""" + function_response_ids: set[str] = set() + for function_response in event.get_function_responses(): + if function_response.id: + function_response_ids.add(function_response.id) + return function_response_ids + + +def _safe_token_compaction_split_index( + *, + candidate_events: list[Event], + event_retention_size: int, +) -> int: + """Returns a split index that avoids orphaning retained tool responses. + + Retained events (tail of candidate events) may contain function responses. + If their matching function call events are in the compacted prefix, contents + assembly can fail. This method shifts the split earlier so matching function + call events are retained together with their responses. + + Iterates backwards through candidate_events once, maintaining a running set + of unmatched response IDs. The latest valid split point where no unmatched + responses remain is returned. 
+ """ + initial_split = len(candidate_events) - event_retention_size + if initial_split <= 0: + return 0 + + unmatched_response_ids: set[str] = set() + best_split = 0 + + for i in range(len(candidate_events) - 1, -1, -1): + event = candidate_events[i] + unmatched_response_ids.update(_event_function_response_ids(event)) + call_ids = _event_function_call_ids(event) + unmatched_response_ids -= call_ids + + if not unmatched_response_ids and i <= initial_split: + best_split = i + break + + return best_split + + +async def _run_compaction_for_token_threshold_config( + *, + config: EventsCompactionConfig | None, + session: Session, + session_service: BaseSessionService, + agent: BaseAgent, + agent_name: str = '', + current_branch: str | None = None, +) -> bool: + """Runs token-threshold compaction for a provided compaction config.""" + if not _has_token_threshold_config(config): + return False + if config is None: + return False + + if config.token_threshold is None or config.event_retention_size is None: + return False + + prompt_token_count = _latest_prompt_token_count( + session.events, + current_branch=current_branch, + agent_name=agent_name, + ) + if prompt_token_count is None or prompt_token_count < config.token_threshold: + return False + + events_to_compact = _events_to_compact_for_token_threshold( + events=session.events, + event_retention_size=config.event_retention_size, + ) + if not events_to_compact: + return False + + _ensure_compaction_summarizer(config=config, agent=agent) + if config.summarizer is None: + return False compaction_event = await config.summarizer.maybe_summarize_events( events=events_to_compact @@ -246,8 +356,30 @@ async def _run_compaction_for_token_threshold( return False -async def _run_compaction_for_sliding_window( +async def _run_compaction_for_token_threshold( app: App, session: Session, session_service: BaseSessionService +): + """Runs post-invocation compaction based on a token threshold. 
+ + If triggered, this compacts older raw events and keeps the last + `event_retention_size` raw events un-compacted. + """ + return await _run_compaction_for_token_threshold_config( + config=app.events_compaction_config, + session=session, + session_service=session_service, + agent=app.root_agent, + agent_name='', + current_branch=None, + ) + + +async def _run_compaction_for_sliding_window( + app: App, + session: Session, + session_service: BaseSessionService, + *, + skip_token_compaction: bool = False, ): """Runs compaction for SlidingWindowCompactor. @@ -327,22 +459,30 @@ async def _run_compaction_for_sliding_window( app: The application instance. session: The session containing events to compact. session_service: The session service for appending events. + skip_token_compaction: Whether to skip token-threshold compaction. """ events = session.events if not events: return None + config = app.events_compaction_config + if config is None: + return None + # Prefer token-threshold compaction if configured and triggered. - if ( - app.events_compaction_config - and app.events_compaction_config.token_threshold is not None - ): + if not skip_token_compaction and _has_token_threshold_config(config): token_compacted = await _run_compaction_for_token_threshold( app, session, session_service ) if token_compacted: return None + if not _has_sliding_window_config(config): + return None + + if config.compaction_interval is None or config.overlap_size is None: + return None + # Find the last compaction event and its range. last_compacted_end_timestamp = 0.0 for event in reversed(events): @@ -373,7 +513,7 @@ async def _run_compaction_for_sliding_window( if invocation_latest_timestamps[inv_id] > last_compacted_end_timestamp ] - if len(new_invocation_ids) < app.events_compaction_config.compaction_interval: + if len(new_invocation_ids) < config.compaction_interval: return None # Not enough new invocations to trigger compaction. # Determine the range of invocations to compact. 
@@ -385,9 +525,7 @@ async def _run_compaction_for_sliding_window( first_new_inv_id = new_invocation_ids[0] first_new_inv_idx = unique_invocation_ids.index(first_new_inv_id) - start_idx = max( - 0, first_new_inv_idx - app.events_compaction_config.overlap_size - ) + start_idx = max(0, first_new_inv_idx - config.overlap_size) start_inv_id = unique_invocation_ids[start_idx] # Find the index of the last event with end_inv_id. @@ -419,15 +557,12 @@ async def _run_compaction_for_sliding_window( if not events_to_compact: return None - if not app.events_compaction_config.summarizer: - app.events_compaction_config.summarizer = LlmEventSummarizer( - llm=app.root_agent.canonical_model - ) + _ensure_compaction_summarizer(config=config, agent=app.root_agent) + if config.summarizer is None: + return None - compaction_event = ( - await app.events_compaction_config.summarizer.maybe_summarize_events( - events=events_to_compact - ) + compaction_event = await config.summarizer.maybe_summarize_events( + events=events_to_compact ) if compaction_event: await session_service.append_event(session=session, event=compaction_event) diff --git a/src/google/adk/flows/llm_flows/compaction.py b/src/google/adk/flows/llm_flows/compaction.py new file mode 100644 index 00000000..f4b60ba9 --- /dev/null +++ b/src/google/adk/flows/llm_flows/compaction.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Request processor that runs token-threshold event compaction.""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import TYPE_CHECKING + +from ...apps.compaction import _has_token_threshold_config +from ...apps.compaction import _run_compaction_for_token_threshold_config +from ...events.event import Event +from ._base_llm_processor import BaseLlmRequestProcessor + +if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext + from ...models.llm_request import LlmRequest + + +class CompactionRequestProcessor(BaseLlmRequestProcessor): + """Compacts session events before contents are prepared for model calls.""" + + async def run_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + del llm_request + config = invocation_context.events_compaction_config + if not _has_token_threshold_config(config): + return + yield # Required for AsyncGenerator. + + token_compacted = await _run_compaction_for_token_threshold_config( + config=config, + session=invocation_context.session, + session_service=invocation_context.session_service, + agent=invocation_context.agent, + agent_name=invocation_context.agent.name, + current_branch=invocation_context.branch, + ) + if token_compacted: + invocation_context.token_compaction_checked = True + return + yield # Required for AsyncGenerator. + + +request_processor = CompactionRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 0a26cdce..e0bd00ff 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,6 +22,7 @@ from . import _code_execution from . import _nl_planning from . import _output_schema_processor from . import basic +from . import compaction from . import contents from . import context_cache_processor from . 
import identity @@ -42,6 +43,9 @@ def _create_request_processors(): request_confirmation.request_processor, instructions.request_processor, identity.request_processor, + # Compaction should run before contents so compacted events are reflected + # in the model request context. + compaction.request_processor, contents.request_processor, # Context cache processor sets up cache config and finds # existing cache metadata. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a8..cdb878cf 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -553,7 +553,10 @@ class Runner: if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') await _run_compaction_for_sliding_window( - self.app, session, self.session_service + self.app, + session, + self.session_service, + skip_token_compaction=invocation_context.token_compaction_checked, ) async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: @@ -1362,6 +1365,9 @@ class Runner: credential_service=self.credential_service, plugin_manager=self.plugin_manager, context_cache_config=self.context_cache_config, + events_compaction_config=( + self.app.events_compaction_config if self.app else None + ), invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/tests/unittests/apps/test_compaction.py b/tests/unittests/apps/test_compaction.py index fadcd39d..6960c8d4 100644 --- a/tests/unittests/apps/test_compaction.py +++ b/tests/unittests/apps/test_compaction.py @@ -50,6 +50,7 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): invocation_id: str, text: str, prompt_token_count: int | None = None, + thought: bool = False, ) -> Event: usage_metadata = None if prompt_token_count is not None: @@ -60,7 +61,60 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): timestamp=timestamp, invocation_id=invocation_id, author='user', - content=Content(role='user', parts=[Part(text=text)]), + content=Content(role='user', 
parts=[Part(text=text, thought=thought)]), + usage_metadata=usage_metadata, + ) + + def _create_function_call_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id=function_call_id, name='tool', args={} + ) + ) + ], + ), + ) + + def _create_function_response_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + prompt_token_count: int | None = None, + ) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), usage_metadata=usage_metadata, ) @@ -249,9 +303,21 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, event_retention_size=5, ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) self.assertEqual(config.token_threshold, 50_000) self.assertEqual(config.event_retention_size, 5) + def test_events_compaction_config_accepts_sliding_window_fields(self): + config = EventsCompactionConfig( + compaction_interval=2, + overlap_size=1, + ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) + self.assertIsNone(config.token_threshold) + self.assertIsNone(config.event_retention_size) + def test_events_compaction_config_rejects_partial_token_fields( self, ): @@ -262,6 +328,23 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, ) + def test_events_compaction_config_rejects_partial_sliding_fields( + 
self, + ): + with pytest.raises(ValidationError): + EventsCompactionConfig( + compaction_interval=2, + ) + + with pytest.raises(ValidationError): + EventsCompactionConfig( + overlap_size=0, + ) + + def test_events_compaction_config_rejects_missing_modes(self): + with pytest.raises(ValidationError): + EventsCompactionConfig() + def test_latest_prompt_token_count_fallback_applies_compaction(self): events = [ self._create_event(1.0, 'inv1', 'a' * 40), @@ -275,6 +358,25 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): # Visible text after compaction is: 'S' + ('c' * 20) = 21 chars. self.assertEqual(estimated_token_count, 21 // 4) + def test_latest_prompt_token_count_fallback_uses_effective_contents(self): + events = [ + self._create_event(1.0, 'inv1', 'visible'), + Event( + timestamp=2.0, + invocation_id='inv2', + author='model', + content=Content( + role='model', + parts=[Part(text='hidden-thought', thought=True)], + ), + ), + ] + + estimated_token_count = compaction_module._latest_prompt_token_count(events) + + # Thought-only events are filtered by contents processing. 
+ self.assertEqual(estimated_token_count, len('visible') // 4) + async def test_run_compaction_for_token_threshold_keeps_retention_events( self, ): @@ -324,6 +426,136 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_keeps_tool_call_pair( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'tool-call-1'), + self._create_function_response_event( + 3.0, + 'inv2', + 'tool-call-1', + prompt_token_count=100, + ), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_for_token_threshold_equal_threshold_compacts( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=100, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', 
prompt_token_count=100), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_skip_token_compaction(self): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', prompt_token_count=100), + ], + ) + + await _run_compaction_for_sliding_window( + app, + session, + self.mock_session_service, + skip_token_compaction=True, + ) + + self.mock_compactor.maybe_summarize_events.assert_not_called() + self.mock_session_service.append_event.assert_not_called() + async def test_run_compaction_for_token_threshold_seeds_previous_compaction( self, ): @@ -482,6 +714,68 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_uses_latest_ordered_seed( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + 
self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_event(3.0, 'inv3', 'e3'), + self._create_event(4.0, 'inv4', 'e4'), + self._create_event(5.0, 'inv5', 'e5'), + self._create_event(15.0, 'inv6', 'e6'), + self._create_event(20.0, 'inv7', 'e7'), + self._create_compacted_event( + 15.0, 20.0, 'Summary 15-20', appended_ts=21.0 + ), + self._create_compacted_event( + 1.0, 5.0, 'Summary 1-5', appended_ts=22.0 + ), + self._create_event(23.0, 'inv8', 'e8'), + self._create_event(24.0, 'inv9', 'e9', prompt_token_count=120), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 23.0, 'Summary 1-23' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + compacted_events_arg[0].content.parts[0].text, 'Summary 1-5' + ) + self.assertEqual( + [e.invocation_id for e in compacted_events_arg[1:]], + ['inv6', 'inv7', 'inv8'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + def test_get_contents_with_multiple_compactions(self): # Event timestamps: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 diff --git a/tests/unittests/flows/llm_flows/test_compaction_processor.py b/tests/unittests/flows/llm_flows/test_compaction_processor.py new file mode 100644 index 00000000..9f747c4b --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_compaction_processor.py @@ -0,0 +1,346 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for request-phase token compaction processor.""" + +from unittest.mock import AsyncMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.apps.app import EventsCompactionConfig +from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.events.event import Event +from google.adk.flows.llm_flows import compaction +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.single_flow import SingleFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.session import Session +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +def _create_event( + *, + timestamp: float, + invocation_id: str, + text: str, + prompt_token_count: int | None = None, +) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='user', + content=Content(role='user', parts=[Part(text=text)]), + usage_metadata=usage_metadata, + ) + + +def test_single_flow_includes_compaction_before_contents(): + flow = SingleFlow() + + compaction_index = flow.request_processors.index(compaction.request_processor) + contents_index = 
flow.request_processors.index(contents.request_processor) + + assert compaction_index < contents_index + + +@pytest.mark.asyncio +async def test_compaction_request_processor_no_token_config(): + session = Session(app_name='app', user_id='user', id='session', events=[]) + session_service = AsyncMock(spec=BaseSessionService) + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + compaction_interval=2, + overlap_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + session_service.append_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_compaction_request_processor_runs_token_compaction(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + _create_event( + timestamp=3.0, + invocation_id='inv3', + text='e3', + prompt_token_count=100, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = 
compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_compacts_with_latest_tool_response(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + Event( + timestamp=3.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id='call-1', name='tool', args={} + ) + ) + ], + ), + ), + Event( + timestamp=4.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id='call-1', + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + 
event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_can_compact_current_user_event(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + Event( + timestamp=2.0, + invocation_id='current-inv', + author='user', + content=Content( + role='user', + parts=[Part(text='latest user message')], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + 
compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'current-inv', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_not_marked_when_not_compacted(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event( + timestamp=2.0, + invocation_id='inv2', + text='e2', + prompt_token_count=40, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + mock_summarizer.maybe_summarize_events.return_value = Event( + author='compactor', + invocation_id=Event.new_id(), + ) + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + mock_summarizer.maybe_summarize_events.assert_not_called() + session_service.append_event.assert_not_called() From 6ea3696bcc243a7c0c1cc6b33219f54482612d14 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Fri, 20 Feb 2026 15:23:11 -0800 Subject: [PATCH 014/102] chore: Migrate /agents to use the new feature decorator Co-authored-by: Xuan Yang PiperOrigin-RevId: 873097395 --- src/google/adk/agents/agent_config.py | 6 +++--- src/google/adk/agents/base_agent.py | 9 +++++---- 
src/google/adk/agents/base_agent_config.py | 5 +++-- src/google/adk/agents/common_configs.py | 9 +++++---- src/google/adk/agents/config_agent_utils.py | 13 +++++++------ src/google/adk/agents/context_cache_config.py | 5 +++-- src/google/adk/agents/llm_agent.py | 7 ++++--- src/google/adk/agents/loop_agent.py | 7 ++++--- src/google/adk/agents/loop_agent_config.py | 5 +++-- src/google/adk/agents/parallel_agent_config.py | 5 +++-- src/google/adk/agents/sequential_agent.py | 5 +++-- src/google/adk/agents/sequential_agent_config.py | 5 +++-- src/google/adk/features/_feature_registry.py | 8 ++++++++ 13 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/google/adk/agents/agent_config.py b/src/google/adk/agents/agent_config.py index add31f4b..2d3c6270 100644 --- a/src/google/adk/agents/agent_config.py +++ b/src/google/adk/agents/agent_config.py @@ -16,14 +16,14 @@ from __future__ import annotations from typing import Annotated from typing import Any -from typing import get_args from typing import Union from pydantic import Discriminator from pydantic import RootModel from pydantic import Tag -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig from .llm_agent_config import LlmAgentConfig from .loop_agent_config import LoopAgentConfig @@ -68,6 +68,6 @@ ConfigsUnion = Annotated[ # Use a RootModel to represent the agent directly at the top level. # The `discriminator` is applied to the union within the RootModel. 
-@experimental +@experimental(FeatureName.AGENT_CONFIG) class AgentConfig(RootModel[ConfigsUnion]): """The config for the YAML schema to create an agent.""" diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 7e46436a..3d0a14d4 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -40,10 +40,11 @@ from typing_extensions import TypeAlias from ..events.event import Event from ..events.event_actions import EventActions +from ..features import experimental +from ..features import FeatureName from ..telemetry import tracing from ..telemetry.tracing import tracer from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext @@ -70,7 +71,7 @@ AfterAgentCallback: TypeAlias = Union[ SelfAgent = TypeVar('SelfAgent', bound='BaseAgent') -@experimental +@experimental(FeatureName.AGENT_STATE) class BaseAgentState(BaseModel): """Base class for all agent states.""" @@ -618,7 +619,7 @@ class BaseAgent(BaseModel): @final @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def from_config( cls: Type[SelfAgent], config: BaseAgentConfig, @@ -642,7 +643,7 @@ class BaseAgent(BaseModel): return cls(**kwargs) @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: Type[SelfAgent], config: BaseAgentConfig, diff --git a/src/google/adk/agents/base_agent_config.py b/src/google/adk/agents/base_agent_config.py index 9f1f5566..3859cb35 100644 --- a/src/google/adk/agents/base_agent_config.py +++ b/src/google/adk/agents/base_agent_config.py @@ -26,14 +26,15 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .common_configs import AgentRefConfig from .common_configs 
import CodeConfig TBaseAgentConfig = TypeVar('TBaseAgentConfig', bound='BaseAgentConfig') -@experimental +@experimental(FeatureName.AGENT_CONFIG) class BaseAgentConfig(BaseModel): """The config for the YAML schema of a BaseAgent. diff --git a/src/google/adk/agents/common_configs.py b/src/google/adk/agents/common_configs.py index 4e4c49f3..49baa8a4 100644 --- a/src/google/adk/agents/common_configs.py +++ b/src/google/adk/agents/common_configs.py @@ -24,10 +24,11 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import model_validator -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ArgumentConfig(BaseModel): """An argument passed to a function or a class's constructor.""" @@ -43,7 +44,7 @@ class ArgumentConfig(BaseModel): """The argument value.""" -@experimental +@experimental(FeatureName.AGENT_CONFIG) class CodeConfig(BaseModel): """Code reference config for a variable, a function, or a class. 
@@ -81,7 +82,7 @@ class CodeConfig(BaseModel): """ -@experimental +@experimental(FeatureName.AGENT_CONFIG) class AgentRefConfig(BaseModel): """The config for the reference to another agent.""" diff --git a/src/google/adk/agents/config_agent_utils.py b/src/google/adk/agents/config_agent_utils.py index 446eac88..2c1c9bd9 100644 --- a/src/google/adk/agents/config_agent_utils.py +++ b/src/google/adk/agents/config_agent_utils.py @@ -22,7 +22,8 @@ from typing import List import yaml -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .agent_config import AgentConfig from .base_agent import BaseAgent from .base_agent_config import BaseAgentConfig @@ -30,7 +31,7 @@ from .common_configs import AgentRefConfig from .common_configs import CodeConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) def from_config(config_path: str) -> BaseAgent: """Build agent from a configfile path. @@ -102,7 +103,7 @@ def _load_config_from_path(config_path: str) -> AgentConfig: return AgentConfig.model_validate(config_data) -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_fully_qualified_name(name: str) -> Any: try: module_path, obj_name = name.rsplit(".", 1) @@ -112,7 +113,7 @@ def resolve_fully_qualified_name(name: str) -> Any: raise ValueError(f"Invalid fully qualified name: {name}") from e -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_agent_reference( ref_config: AgentRefConfig, referencing_agent_config_abs_path: str ) -> BaseAgent: @@ -170,7 +171,7 @@ def _resolve_agent_code_reference(code: str) -> Any: return obj -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_code_reference(code_config: CodeConfig) -> Any: """Resolve a code reference to actual Python object. 
@@ -199,7 +200,7 @@ def resolve_code_reference(code_config: CodeConfig) -> Any: return obj -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_callbacks(callbacks_config: List[CodeConfig]) -> Any: """Resolve callbacks from configuration. diff --git a/src/google/adk/agents/context_cache_config.py b/src/google/adk/agents/context_cache_config.py index 855e28c3..9e6d19ca 100644 --- a/src/google/adk/agents/context_cache_config.py +++ b/src/google/adk/agents/context_cache_config.py @@ -18,10 +18,11 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ContextCacheConfig(BaseModel): """Configuration for context caching across all agents in an app. diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 4e07651c..8b555b74 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -41,6 +41,8 @@ from typing_extensions import TypeAlias from ..code_executors.base_code_executor import BaseCodeExecutor from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..flows.llm_flows.auto_flow import AutoFlow from ..flows.llm_flows.base_llm_flow import BaseLlmFlow from ..flows.llm_flows.single_flow import SingleFlow @@ -55,7 +57,6 @@ from ..tools.function_tool import FunctionTool from ..tools.tool_configs import ToolConfig from ..tools.tool_context import ToolContext from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -883,7 +884,7 @@ class LlmAgent(BaseAgent): ) @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _resolve_tools( cls, 
tool_configs: list[ToolConfig], config_abs_path: str ) -> list[Any]: @@ -942,7 +943,7 @@ class LlmAgent(BaseAgent): @override @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: Type[LlmAgent], config: LlmAgentConfig, diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 9296714f..2980f68a 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -26,8 +26,9 @@ from typing import Optional from typing_extensions import override from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -37,7 +38,7 @@ from .loop_agent_config import LoopAgentConfig logger = logging.getLogger('google_adk.' + __name__) -@experimental +@experimental(FeatureName.AGENT_STATE) class LoopAgentState(BaseAgentState): """State for LoopAgent.""" @@ -153,7 +154,7 @@ class LoopAgent(BaseAgent): @override @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: type[LoopAgent], config: LoopAgentConfig, diff --git a/src/google/adk/agents/loop_agent_config.py b/src/google/adk/agents/loop_agent_config.py index 1aaa0ef9..78fc790b 100644 --- a/src/google/adk/agents/loop_agent_config.py +++ b/src/google/adk/agents/loop_agent_config.py @@ -21,11 +21,12 @@ from typing import Optional from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) class LoopAgentConfig(BaseAgentConfig): """The config for the YAML schema of a LoopAgent.""" diff --git 
a/src/google/adk/agents/parallel_agent_config.py b/src/google/adk/agents/parallel_agent_config.py index 77eb1a68..96a75b65 100644 --- a/src/google/adk/agents/parallel_agent_config.py +++ b/src/google/adk/agents/parallel_agent_config.py @@ -19,11 +19,12 @@ from __future__ import annotations from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ParallelAgentConfig(BaseAgentConfig): """The config for the YAML schema of a ParallelAgent.""" diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index eec1dea9..06a2377b 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -24,8 +24,9 @@ from typing import Type from typing_extensions import override from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -36,7 +37,7 @@ from .sequential_agent_config import SequentialAgentConfig logger = logging.getLogger('google_adk.' 
+ __name__) -@experimental +@experimental(FeatureName.AGENT_STATE) class SequentialAgentState(BaseAgentState): """State for SequentialAgent.""" diff --git a/src/google/adk/agents/sequential_agent_config.py b/src/google/adk/agents/sequential_agent_config.py index 763527e9..44551c42 100644 --- a/src/google/adk/agents/sequential_agent_config.py +++ b/src/google/adk/agents/sequential_agent_config.py @@ -19,11 +19,12 @@ from __future__ import annotations from pydantic import ConfigDict from pydantic import Field -from ..agents.base_agent import experimental from ..agents.base_agent_config import BaseAgentConfig +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class SequentialAgentConfig(BaseAgentConfig): """The config for the YAML schema of a SequentialAgent.""" diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 81089162..9b633c2d 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -26,6 +26,8 @@ from ..utils.env_utils import is_env_enabled class FeatureName(str, Enum): """Feature names.""" + AGENT_CONFIG = "AGENT_CONFIG" + AGENT_STATE = "AGENT_STATE" AUTHENTICATED_FUNCTION_TOOL = "AUTHENTICATED_FUNCTION_TOOL" BASE_AUTHENTICATED_TOOL = "BASE_AUTHENTICATED_TOOL" BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET" @@ -79,6 +81,12 @@ class FeatureConfig: # Central registry: FeatureName -> FeatureConfig _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = { + FeatureName.AGENT_CONFIG: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.AGENT_STATE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.AUTHENTICATED_FUNCTION_TOOL: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), From e8019b1b1b0b43dcc5fa23075942b31db502ffdd Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 15:23:26 -0800 Subject: [PATCH 015/102] 
fix: Refactor LiteLLM streaming response parsing for compatibility with LiteLLM 1.81+ Updates _model_response_to_chunk to better handle LiteLLM's streaming delta/message structure, including prioritizing delta when it contains meaningful content and preserving reasoning_content Close #4225 Co-authored-by: George Weale PiperOrigin-RevId: 873097502 --- pyproject.toml | 4 +- src/google/adk/models/lite_llm.py | 236 +++++++++++++++++-------- tests/unittests/models/test_litellm.py | 135 ++++++++++---- 3 files changed, 263 insertions(+), 112 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a65ec61..46064446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,7 @@ test = [ "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <1.80.17", # For LiteLLM tests + "litellm>=1.75.5, <2.0.0", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM "opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0", @@ -156,7 +156,7 @@ extensions = [ "docker>=7.0.0", # For ContainerCodeExecutor "kubernetes>=29.0.0", # For GkeCodeExecutor "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it + "litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. "lxml>=5.3.0", # For load_web_page tool. 
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index b954d8a0..e85772c5 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -70,7 +70,9 @@ if TYPE_CHECKING: from litellm import Function from litellm import Message from litellm import ModelResponse + from litellm import ModelResponseStream from litellm import OpenAIMessageContent + from litellm.types.utils import Delta else: litellm = None acompletion = None @@ -85,7 +87,9 @@ else: Function = None Message = None ModelResponse = None + Delta = None OpenAIMessageContent = None + ModelResponseStream = None logger = logging.getLogger("google_adk." + __name__) @@ -151,6 +155,7 @@ _LITELLM_GLOBAL_SYMBOLS = ( "Function", "Message", "ModelResponse", + "ModelResponseStream", "OpenAIMessageContent", "acompletion", "completion", @@ -382,15 +387,11 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: ] -def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: - """Fetches the reasoning payload from a LiteLLM message or dict.""" +def _extract_reasoning_value(message: Message | Delta | None) -> Any: + """Fetches the reasoning payload from a LiteLLM message.""" if message is None: return None - if hasattr(message, "reasoning_content"): - return getattr(message, "reasoning_content") - if isinstance(message, dict): - return message.get("reasoning_content") - return None + return message.get("reasoning_content") class ChatCompletionFileUrlObject(TypedDict, total=False): @@ -1264,7 +1265,7 @@ def _function_declaration_to_tool_param( def _model_response_to_chunk( - response: ModelResponse, + response: ModelResponse | ModelResponseStream, ) -> Generator[ Tuple[ Optional[ @@ -1282,6 +1283,9 @@ def _model_response_to_chunk( ]: """Converts a litellm message to text, function or usage metadata chunk. + LiteLLM streaming chunks carry `delta`, while non-streaming chunks carry + `message`. 
+ Args: response: The response from the model. @@ -1290,18 +1294,45 @@ def _model_response_to_chunk( """ _ensure_litellm_imported() - message = None - if response.get("choices", None): - message = response["choices"][0].get("message", None) - finish_reason = response["choices"][0].get("finish_reason", None) - # check streaming delta - if message is None and response["choices"][0].get("delta", None): - message = response["choices"][0]["delta"] + def _has_meaningful_signal(message: Message | Delta | None) -> bool: + if message is None: + return False + return bool( + message.get("content") + or message.get("tool_calls") + or message.get("function_call") + or message.get("reasoning_content") + ) + + if isinstance(response, ModelResponseStream): + message_field = "delta" + elif isinstance(response, ModelResponse): + message_field = "message" + else: + raise TypeError( + "Unexpected response type from LiteLLM: %r" % (type(response),) + ) + + choices = response.get("choices") + if not choices: + yield None, None + else: + choice = choices[0] + finish_reason = choice.get("finish_reason") + if message_field == "delta": + message = choice.get("delta") + else: + message = choice.get("message") + + if message is not None and not _has_meaningful_signal(message): + message = None message_content: Optional[OpenAIMessageContent] = None tool_calls: list[ChatCompletionMessageToolCall] = [] reasoning_parts: List[types.Part] = [] + if message is not None: + # Both Delta and Message support dict-like .get() access ( message_content, tool_calls, @@ -1318,39 +1349,46 @@ def _model_response_to_chunk( if tool_calls: for idx, tool_call in enumerate(tool_calls): - # aggregate tool_call - if tool_call.type == "function": - func_name = tool_call.function.name - func_args = tool_call.function.arguments - func_index = getattr(tool_call, "index", idx) + # LiteLLM tool call objects support dict-like .get() access + if tool_call.get("type") == "function": + function_obj = tool_call.get("function") 
+ if not function_obj: + continue + func_name = function_obj.get("name") + func_args = function_obj.get("arguments") + func_index = tool_call.get("index", idx) + tool_call_id = tool_call.get("id") # Ignore empty chunks that don't carry any information. if not func_name and not func_args: continue yield FunctionChunk( - id=tool_call.id, + id=tool_call_id, name=func_name, args=func_args, index=func_index, ), finish_reason - if finish_reason and not (message_content or tool_calls): + if finish_reason and not (message_content or tool_calls or reasoning_parts): yield None, finish_reason - if not message: - yield None, None - # Ideally usage would be expected with the last ModelResponseStream with a # finish_reason set. But this is not the case we are observing from litellm. # So we are sending it as a separate chunk to be set on the llm_response. - if response.get("usage", None): - yield UsageMetadataChunk( - prompt_tokens=response["usage"].get("prompt_tokens", 0), - completion_tokens=response["usage"].get("completion_tokens", 0), - total_tokens=response["usage"].get("total_tokens", 0), - cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]), - ), None + usage = response.get("usage") + if usage: + try: + yield UsageMetadataChunk( + prompt_tokens=usage.get("prompt_tokens", 0) or 0, + completion_tokens=usage.get("completion_tokens", 0) or 0, + total_tokens=usage.get("total_tokens", 0) or 0, + cached_prompt_tokens=_extract_cached_prompt_tokens(usage), + ), None + except AttributeError as e: + raise TypeError( + "Unexpected LiteLLM usage type: %r" % (type(usage),) + ) from e def _model_response_to_generate_content_response( @@ -1902,6 +1940,57 @@ class LiteLlm(BaseLlm): aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 + + def _finalize_tool_call_response( + *, model_version: str, finish_reason: str + ) -> LlmResponse: + tool_calls = [] + for index, func_data in function_calls.items(): + if func_data["id"]: + 
tool_calls.append( + ChatCompletionMessageToolCall( + type="function", + id=func_data["id"], + function=Function( + name=func_data["name"], + arguments=func_data["args"], + index=index, + ), + ) + ) + llm_response = _message_to_generate_content_response( + ChatCompletionAssistantMessage( + role="assistant", + content=text, + tool_calls=tool_calls, + ), + model_version=model_version, + thought_parts=list(reasoning_parts) if reasoning_parts else None, + ) + llm_response.finish_reason = _map_finish_reason(finish_reason) + return llm_response + + def _finalize_text_response( + *, model_version: str, finish_reason: str + ) -> LlmResponse: + message_content = text if text else None + llm_response = _message_to_generate_content_response( + ChatCompletionAssistantMessage( + role="assistant", + content=message_content, + ), + model_version=model_version, + thought_parts=list(reasoning_parts) if reasoning_parts else None, + ) + llm_response.finish_reason = _map_finish_reason(finish_reason) + return llm_response + + def _reset_stream_buffers() -> None: + nonlocal text, reasoning_parts + text = "" + reasoning_parts = [] + function_calls.clear() + async for part in await self.llm_client.acompletion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): @@ -1951,58 +2040,49 @@ class LiteLlm(BaseLlm): cached_content_token_count=chunk.cached_prompt_tokens, ) - if ( - finish_reason == "tool_calls" or finish_reason == "stop" - ) and function_calls: - tool_calls = [] - for index, func_data in function_calls.items(): - if func_data["id"]: - tool_calls.append( - ChatCompletionMessageToolCall( - type="function", - id=func_data["id"], - function=Function( - name=func_data["name"], - arguments=func_data["args"], - index=index, - ), - ) - ) + # LiteLLM 1.81+ can set finish_reason="stop" on partial chunks. Only + # finalize tool calls on an explicit tool_calls finish_reason, or on a + # stop-only chunk (no content/tool deltas). 
+ if function_calls and ( + finish_reason == "tool_calls" + or (finish_reason == "stop" and chunk is None) + ): aggregated_llm_response_with_tool_call = ( - _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", - content=text, - tool_calls=tool_calls, - ), + _finalize_tool_call_response( model_version=part.model, - thought_parts=list(reasoning_parts) - if reasoning_parts - else None, + finish_reason=finish_reason, ) ) - aggregated_llm_response_with_tool_call.finish_reason = ( - _map_finish_reason(finish_reason) - ) - text = "" - reasoning_parts = [] - function_calls.clear() - elif finish_reason == "stop" and (text or reasoning_parts): - message_content = text if text else None - aggregated_llm_response = _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", content=message_content - ), + _reset_stream_buffers() + elif ( + finish_reason == "stop" + and (text or reasoning_parts) + and chunk is None + and not function_calls + ): + # Only aggregate text response when we have a true stop signal + # chunk is None means no content in this chunk, just finish signal. + # LiteLLM 1.81+ sets finish_reason="stop" on partial chunks with + # content. 
+ aggregated_llm_response = _finalize_text_response( model_version=part.model, - thought_parts=list(reasoning_parts) - if reasoning_parts - else None, + finish_reason=finish_reason, ) - aggregated_llm_response.finish_reason = _map_finish_reason( - finish_reason - ) - text = "" - reasoning_parts = [] + _reset_stream_buffers() + + if function_calls and not aggregated_llm_response_with_tool_call: + aggregated_llm_response_with_tool_call = _finalize_tool_call_response( + model_version=part.model, + finish_reason="tool_calls", + ) + _reset_stream_buffers() + + if (text or reasoning_parts) and not aggregated_llm_response: + aggregated_llm_response = _finalize_text_response( + model_version=part.model, + finish_reason="stop", + ) + _reset_stream_buffers() # waiting until streaming ends to yield the llm_response as litellm tends # to send chunk that contains usage_metadata after the chunk with diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 3e3ecce0..39f6b540 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -45,6 +45,7 @@ from google.adk.models.lite_llm import _to_litellm_role from google.adk.models.lite_llm import FunctionChunk from google.adk.models.lite_llm import LiteLlm from google.adk.models.lite_llm import LiteLLMClient +from google.adk.models.lite_llm import ReasoningChunk from google.adk.models.lite_llm import TextChunk from google.adk.models.lite_llm import UsageMetadataChunk from google.adk.models.llm_request import LlmRequest @@ -57,6 +58,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall from litellm.types.utils import Choices from litellm.types.utils import Delta from litellm.types.utils import ModelResponse +from litellm.types.utils import ModelResponseStream from litellm.types.utils import StreamingChoices from pydantic import BaseModel from pydantic import Field @@ -129,7 +131,7 @@ FILE_BYTES_TEST_CASES = [ ] STREAMING_MODEL_RESPONSE = [ - 
ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -141,7 +143,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -153,7 +155,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -165,7 +167,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -187,7 +189,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -209,7 +211,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -532,7 +534,7 @@ def test_schema_to_dict_filters_none_enum_values(): MULTIPLE_FUNCTION_CALLS_STREAM = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -553,7 +555,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -574,7 +576,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -595,7 +597,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -616,7 +618,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason="tool_calls", @@ -627,7 +629,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ STREAM_WITH_EMPTY_CHUNK = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -648,7 +650,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -670,7 +672,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ] ), # This is the 
problematic empty chunk that should be ignored. - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -691,7 +693,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[StreamingChoices(finish_reason="tool_calls", delta=Delta())] ), ] @@ -727,7 +729,7 @@ def mock_response(): # indices all 0 # finish_reason stop instead of tool_calls NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -748,7 +750,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -769,7 +771,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -790,7 +792,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -811,7 +813,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason="stop", @@ -2707,7 +2709,7 @@ def test_to_litellm_role(): "stop", ), ( - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -2729,10 +2731,10 @@ def test_to_litellm_role(): ] ), [FunctionChunk(id="1", name="test_function", args='{"key": "va')], - UsageMetadataChunk( - prompt_tokens=0, completion_tokens=0, total_tokens=0 - ), None, + # LiteLLM 1.81+ defaults finish_reason to "stop" for partial chunks, + # older versions return None. Both are valid for streaming chunks. 
+ (None, "stop"), ), ( ModelResponse(choices=[{"finish_reason": "tool_calls"}]), @@ -2813,6 +2815,38 @@ def test_to_litellm_role(): ), "tool_calls", ), + ( + ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="Hello"), + ) + ] + ), + [TextChunk(text="Hello")], + None, + (None, "stop"), + ), + ( + ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason="stop", + delta=Delta( + role="assistant", reasoning_content="thinking..." + ), + ) + ] + ), + [ + ReasoningChunk( + parts=[types.Part(text="thinking...", thought=True)] + ) + ], + None, + "stop", + ), ], ) def test_model_response_to_chunk( @@ -2836,7 +2870,10 @@ def test_model_response_to_chunk( else: assert isinstance(chunk, type(expected_chunk)) assert chunk == expected_chunk - assert finished == expected_finished + if isinstance(expected_finished, tuple): + assert finished in expected_finished + else: + assert finished == expected_finished if expected_usage_chunk is None: assert usage_chunk is None @@ -2845,6 +2882,38 @@ def test_model_response_to_chunk( assert usage_chunk == expected_usage_chunk +def test_model_response_to_chunk_does_not_mutate_delta_object(): + """Verify that _model_response_to_chunk doesn't mutate the Delta object. + + In real streaming responses, LiteLLM's StreamingChoices only has 'delta' + (message is explicitly popped in StreamingChoices constructor). The delta + object itself carries reasoning_content when present. + """ + delta = Delta( + role="assistant", content="Hello", reasoning_content="thinking..." + ) + response = ModelResponseStream( + choices=[StreamingChoices(delta=delta, finish_reason=None)] + ) + + chunks = [chunk for chunk, _ in _model_response_to_chunk(response) if chunk] + + assert ( + ReasoningChunk(parts=[types.Part(text="thinking...", thought=True)]) + in chunks + ) + assert TextChunk(text="Hello") in chunks + + # Verify we don't accidentally mutate the original delta object. 
+ assert delta.content == "Hello" + assert delta.reasoning_content == "thinking..." + + +def test_model_response_to_chunk_rejects_dict_response(): + with pytest.raises(TypeError): + list(_model_response_to_chunk({"choices": []})) + + @pytest.mark.asyncio async def test_acompletion_additional_args(mock_acompletion, mock_client): lite_llm_instance = LiteLlm( @@ -3056,7 +3125,7 @@ async def test_generate_content_async_stream_sets_finish_reason( mock_completion, lite_llm_instance ): mock_completion.return_value = iter([ - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -3065,7 +3134,7 @@ async def test_generate_content_async_stream_sets_finish_reason( ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -3074,7 +3143,7 @@ async def test_generate_content_async_stream_sets_finish_reason( ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[StreamingChoices(finish_reason="stop", delta=Delta())], ), @@ -3107,7 +3176,7 @@ async def test_generate_content_async_stream_with_usage_metadata( streaming_model_response_with_usage_metadata = [ *STREAMING_MODEL_RESPONSE, - ModelResponse( + ModelResponseStream( usage={ "prompt_tokens": 10, "completion_tokens": 5, @@ -3176,7 +3245,7 @@ async def test_generate_content_async_stream_with_usage_metadata( """Tests that cached prompt tokens are propagated in streaming mode.""" streaming_model_response_with_usage_metadata = [ *STREAMING_MODEL_RESPONSE, - ModelResponse( + ModelResponseStream( usage={ "prompt_tokens": 10, "completion_tokens": 5, @@ -3657,7 +3726,7 @@ async def test_finish_reason_propagation( async def test_finish_reason_unknown_maps_to_other( mock_acompletion, lite_llm_instance ): - """Test that unknown finish_reason values map to FinishReason.OTHER.""" + """Test that unmapped finish_reason values map to FinishReason.OTHER.""" mock_response = ModelResponse( choices=[ Choices( @@ -3665,7 +3734,9 @@ async def 
test_finish_reason_unknown_maps_to_other( role="assistant", content="Test response", ), - finish_reason="unknown_reason_type", + # LiteLLM validates finish_reason to a known set. Use a value that + # LiteLLM accepts but ADK does not explicitly map. + finish_reason="eos", ) ] ) From 4260ef0c7c37ecdfea295fb0e1a933bb0df78bea Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Fri, 20 Feb 2026 16:11:31 -0800 Subject: [PATCH 016/102] feat: Add schema auto-upgrade, tool provenance, HITL tracing, and span hierarchy fix to BigQuery Agent Analytics plugin This CL adds four enhancements to the BigQuery Agent Analytics plugin and fixes a span hierarchy corruption bug. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Schema Auto-Upgrade:** Additive-only schema migration that automatically adds missing columns to existing BQ tables on startup. A `adk_schema_version` label on the table (starting at `"1"`, bumped with each schema change) makes the check idempotent — the diff runs at most once per version. Enabled by default (`auto_schema_upgrade=True`) because upgrades are additive-only and fail-safe. Pre-versioning tables (no label) are treated as outdated, diffed, and stamped. No previous schema versions need to be stored; the logic diffs actual columns against the current canonical schema. - **Tool Provenance:** Adds `tool_origin` to TOOL_* event content, distinguishing six origin types — `LOCAL` (FunctionTool), `MCP` (McpTool), `A2A` (AgentTool wrapping RemoteA2aAgent), `SUB_AGENT` (AgentTool), `TRANSFER_AGENT` (TransferToAgentTool), and `UNKNOWN` (fallback) — via `isinstance()` checks with lazy imports to avoid circular dependencies. - **HITL Tracing:** Emits dedicated HITL event types (`HITL_CONFIRMATION_REQUEST`, `HITL_CREDENTIAL_REQUEST`, `HITL_INPUT_REQUEST` + `_COMPLETED` variants) for human-in-the-loop interactions. 
Detection lives in `on_event_callback` (for synthetic `adk_request_*` FunctionCall events emitted by the framework) and `on_user_message_callback` (for `adk_request_*` FunctionResponse completions sent by the user), not in tool callbacks — because `adk_request_*` names are synthetic function calls that bypass `before_tool_callback`/`after_tool_callback` entirely. - **Span Hierarchy Fix (#4561):** Removes `context.attach()`/`context.detach()` calls from `TraceManager.push_span()`, `attach_current_span()`, and `pop_span()`. The plugin was injecting its spans into the shared OTel context, which corrupted the framework's span hierarchy when an external exporter (e.g. `opentelemetry-instrumentation-vertexai`) was active — causing `call_llm` to be re-parented under `llm_request` and parent spans to show shorter durations than children. The plugin now tracks span_id/parent_span_id via its internal contextvar stack without mutating ambient OTel context. Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 873114688 --- .../bigquery_agent_analytics_plugin.py | 285 ++++++- .../test_bigquery_agent_analytics_plugin.py | 716 ++++++++++++++++++ 2 files changed, 970 insertions(+), 31 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 5b0fcf55..97a25496 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -51,7 +51,6 @@ from google.cloud.bigquery import schema as bq_schema from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types -from opentelemetry import context from opentelemetry import trace import pyarrow as pa @@ -71,6 +70,24 @@ tracer = trace.get_tracer( "google.adk.plugins.bigquery_agent_analytics", __version__ ) +# Bumped when the schema changes (1 
→ 2 → 3 …). Used as a table +# label for governance and to decide whether auto-upgrade should run. +_SCHEMA_VERSION = "1" +_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version" + +# Human-in-the-loop (HITL) tool names that receive additional +# dedicated event types alongside the normal TOOL_* events. +_HITL_TOOL_NAMES = frozenset({ + "adk_request_credential", + "adk_request_confirmation", + "adk_request_input", +}) +_HITL_EVENT_MAP = MappingProxyType({ + "adk_request_credential": "HITL_CREDENTIAL_REQUEST", + "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST", + "adk_request_input": "HITL_INPUT_REQUEST", +}) + def _safe_callback(func): """Decorator that catches and logs exceptions in plugin callbacks. @@ -132,6 +149,47 @@ def _format_content( return " | ".join(parts), truncated +def _get_tool_origin(tool: "BaseTool") -> str: + """Returns the provenance category of a tool. + + Uses lazy imports to avoid circular dependencies. + + Args: + tool: The tool instance. + + Returns: + One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN. + """ + # Import lazily to avoid circular dependencies. + # pylint: disable=g-import-not-at-top + from ..tools.agent_tool import AgentTool # pytype: disable=import-error + from ..tools.function_tool import FunctionTool # pytype: disable=import-error + from ..tools.transfer_to_agent_tool import TransferToAgentTool # pytype: disable=import-error + + try: + from ..tools.mcp_tool.mcp_tool import McpTool # pytype: disable=import-error + except ImportError: + McpTool = None + + try: + from ..agents.remote_a2a_agent import RemoteA2aAgent # pytype: disable=import-error + except ImportError: + RemoteA2aAgent = None + + # Order matters: TransferToAgentTool is a subclass of FunctionTool. 
+ if McpTool is not None and isinstance(tool, McpTool): + return "MCP" + if isinstance(tool, TransferToAgentTool): + return "TRANSFER_AGENT" + if isinstance(tool, AgentTool): + if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent): + return "A2A" + return "SUB_AGENT" + if isinstance(tool, FunctionTool): + return "LOCAL" + return "UNKNOWN" + + def _recursive_smart_truncate( obj: Any, max_len: int, seen: Optional[set[int]] = None ) -> tuple[Any, bool]: @@ -435,6 +493,11 @@ class BigQueryLoggerConfig: log_session_metadata: bool = True # Static custom tags (e.g. {"agent_role": "sales"}) custom_tags: dict[str, Any] = field(default_factory=dict) + # Automatically add new columns to existing tables when the plugin + # schema evolves. Only additive changes are made (columns are never + # dropped or altered). Safe to leave enabled; a version label on the + # table ensures the diff runs at most once per schema version. + auto_schema_upgrade: bool = True # ============================================================================== @@ -450,12 +513,17 @@ _root_agent_name_ctx = contextvars.ContextVar( class _SpanRecord: """A single record on the unified span stack. - Consolidates span, token, id, ownership, and timing into one object + Consolidates span, id, ownership, and timing into one object so all stacks stay in sync by construction. + + Note: The plugin intentionally does NOT attach its spans to the + ambient OTel context (no ``context.attach``). This prevents the + plugin from corrupting the framework's span hierarchy when an + external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``) + is active. See https://github.com/google/adk-python/issues/4561. 
""" span: trace.Span - token: Any # opentelemetry context token span_id: str owns_span: bool start_time_ns: int @@ -513,17 +581,26 @@ class TraceManager: @staticmethod def push_span( - callback_context: CallbackContext, span_name: Optional[str] = "adk-span" + callback_context: CallbackContext, + span_name: Optional[str] = "adk-span", ) -> str: """Starts a new span and pushes it onto the stack. - If OTel is not configured (returning non-recording spans), a UUID fallback - is generated to ensure span_id and parent_span_id are populated in logs. + The span is created but NOT attached to the ambient OTel context, + so it cannot corrupt the framework's own span hierarchy. The + plugin tracks span_id / parent_span_id internally via its own + contextvar stack. + + If OTel is not configured (returning non-recording spans), a UUID + fallback is generated to ensure span_id and parent_span_id are + populated in BigQuery logs. """ TraceManager.init_trace(callback_context) + # Create the span without attaching it to the ambient context. + # This avoids re-parenting framework spans like ``call_llm`` + # or ``execute_tool``. See #4561. span = tracer.start_span(span_name) - token = context.attach(trace.set_span_in_context(span)) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -532,7 +609,6 @@ class TraceManager: record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=True, start_time_ns=time.time_ns(), @@ -548,11 +624,14 @@ class TraceManager: def attach_current_span( callback_context: CallbackContext, ) -> str: - """Attaches the current OTEL span to the stack without owning it.""" + """Records the current OTel span on the stack without owning it. + + The span is NOT re-attached to the ambient context; it is only + tracked internally for span_id / parent_span_id resolution. 
+ """ TraceManager.init_trace(callback_context) span = trace.get_current_span() - token = context.attach(trace.set_span_in_context(span)) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -561,7 +640,6 @@ class TraceManager: record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=False, start_time_ns=time.time_ns(), @@ -575,7 +653,11 @@ class TraceManager: @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - """Ends the current span and pops it from the stack.""" + """Ends the current span and pops it from the stack. + + No ambient OTel context is detached because we never attached + one in the first place (see ``push_span``). + """ records = _span_records_ctx.get() if not records: return None, None @@ -595,8 +677,6 @@ class TraceManager: if record.owns_span: record.span.end() - context.detach(record.token) - return record.span_id, duration_ms @staticmethod @@ -1822,16 +1902,25 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) def _ensure_schema_exists(self) -> None: - """Ensures the BigQuery table exists with the correct schema.""" + """Ensures the BigQuery table exists with the correct schema. + + When ``config.auto_schema_upgrade`` is True and the table already + exists, missing columns are added automatically (additive only). + A ``adk_schema_version`` label is written for governance. 
+ """ try: - self.client.get_table(self.full_table_id) + existing_table = self.client.get_table(self.full_table_id) + if self.config.auto_schema_upgrade: + self._maybe_upgrade_schema(existing_table) except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) tbl.time_partitioning = bigquery.TimePartitioning( - type_=bigquery.TimePartitioningType.DAY, field="timestamp" + type_=bigquery.TimePartitioningType.DAY, + field="timestamp", ) tbl.clustering_fields = self.config.clustering_fields + tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} try: self.client.create_table(tbl) except cloud_exceptions.Conflict: @@ -1851,6 +1940,46 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): exc_info=True, ) + def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: + """Adds missing columns to an existing table (additive only). + + Args: + existing_table: The current BigQuery table object. + """ + stored_version = (existing_table.labels or {}).get( + _SCHEMA_VERSION_LABEL_KEY + ) + if stored_version == _SCHEMA_VERSION: + return + + existing_names = {f.name for f in existing_table.schema} + new_fields = [f for f in self._schema if f.name not in existing_names] + + if new_fields: + merged = list(existing_table.schema) + new_fields + existing_table.schema = merged + logger.info( + "Auto-upgrading table %s: adding columns %s", + self.full_table_id, + [f.name for f in new_fields], + ) + + # Always stamp the version label so we skip on next run. 
+ labels = dict(existing_table.labels or {}) + labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION + existing_table.labels = labels + + try: + update_fields = ["schema", "labels"] + self.client.update_table(existing_table, update_fields) + except Exception as e: + logger.error( + "Schema auto-upgrade failed for %s: %s", + self.full_table_id, + e, + exc_info=True, + ) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -2123,16 +2252,42 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) -> None: """Parity with V1: Logs USER_MESSAGE_RECEIVED event. + Also detects HITL completion responses (user-sent + ``FunctionResponse`` parts with ``adk_request_*`` names) and emits + dedicated ``HITL_*_COMPLETED`` events. + Args: invocation_context: The context of the current invocation. user_message: The message content received from the user. """ + callback_ctx = CallbackContext(invocation_context) await self._log_event( "USER_MESSAGE_RECEIVED", - CallbackContext(invocation_context), + callback_ctx, raw_content=user_message, ) + # Detect HITL completion responses in the user message. + if user_message and user_message.parts: + for part in user_message.parts: + if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + @_safe_callback async def on_event_callback( self, @@ -2140,24 +2295,76 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): invocation_context: InvocationContext, event: "Event", ) -> None: - """Logs state changes from events to BigQuery. 
+ """Logs state changes and HITL events from the event stream. - Checks each event for a non-empty state_delta and logs it as a - STATE_DELTA event. This captures state changes from all sources - (tools, agents, LLM, manual), not just tool callbacks. + - Checks each event for a non-empty state_delta and logs it as a + STATE_DELTA event. + - Detects synthetic ``adk_request_*`` function calls (HITL pause + events) and their corresponding function responses (HITL + completions) and emits dedicated HITL event types. + + The HITL detection must happen here (not in tool callbacks) because + ``adk_request_credential``, ``adk_request_confirmation``, and + ``adk_request_input`` are synthetic function calls injected by the + framework — they never go through ``before_tool_callback`` / + ``after_tool_callback``. Args: invocation_context: The context for the current invocation. event: The event raised by the runner. """ + callback_ctx = CallbackContext(invocation_context) + + # --- State delta logging --- if event.actions and event.actions.state_delta: await self._log_event( "STATE_DELTA", - CallbackContext(invocation_context), + callback_ctx, event_data=EventData( extra_attributes={"state_delta": dict(event.actions.state_delta)} ), ) + + # --- HITL event logging --- + if event.content and event.content.parts: + for part in event.content.parts: + # Detect HITL function calls (request events). + if part.function_call: + hitl_event = _HITL_EVENT_MAP.get(part.function_call.name) + if hitl_event: + args_truncated, is_truncated = _recursive_smart_truncate( + part.function_call.args or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_call.name, + "args": args_truncated, + } + await self._log_event( + hitl_event, + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + # Detect HITL function responses (completion events). 
+ if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + return None async def on_state_change_callback( @@ -2460,7 +2667,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } TraceManager.push_span(tool_context, "tool") await self._log_event( "TOOL_STARTING", @@ -2489,20 +2701,26 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): resp_truncated, is_truncated = _recursive_smart_truncate( result, self.config.max_content_length ) - content_dict = {"tool": tool.name, "result": resp_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "result": resp_truncated, + "tool_origin": tool_origin, + } span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + event_data = EventData( + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ) await self._log_event( "TOOL_COMPLETED", tool_context, raw_content=content_dict, is_truncated=is_truncated, - event_data=EventData( - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, - ), + event_data=event_data, ) @_safe_callback @@ -2525,7 +2743,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): args_truncated, is_truncated = _recursive_smart_truncate( tool_args, 
self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } _, duration = TraceManager.pop_span() await self._log_event( "TOOL_ERROR", diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index e9f617c4..549263fb 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -3949,3 +3949,719 @@ class TestMultiSubagentToolLogging: # All rows share the same session for row in rows: assert row["session_id"] == "session-multi" + + +class TestSchemaAutoUpgrade: + """Tests for _ensure_schema_exists with auto_schema_upgrade.""" + + def _make_plugin(self, auto_schema_upgrade=False): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + auto_schema_upgrade=auto_schema_upgrade, + ) + with mock.patch("google.cloud.bigquery.Client"): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_create_table_sets_version_label(self): + """New tables get the schema version label.""" + plugin = self._make_plugin() + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin._ensure_schema_exists() + plugin.client.create_table.assert_called_once() + tbl = plugin.client.create_table.call_args[0][0] + assert ( + tbl.labels[bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_no_upgrade_when_disabled(self): + """Auto-upgrade 
disabled: existing table is not modified.""" + plugin = self._make_plugin(auto_schema_upgrade=False) + existing = mock.MagicMock() + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_adds_missing_columns(self): + """Auto-upgrade adds columns missing from existing table.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {"other": "label"} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = {f.name for f in updated_table.schema} + assert "event_type" in updated_names + assert "agent" in updated_names + assert "content" in updated_names + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_skip_upgrade_when_version_matches(self): + """No update when stored version matches current.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_error_is_logged_not_raised(self): + """Schema upgrade errors are logged, not propagated.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + 
bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin.client.update_table.side_effect = Exception("boom") + # Should not raise + plugin._ensure_schema_exists() + + def test_upgrade_preserves_existing_columns(self): + """Existing columns are never dropped or altered during upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + # Simulate a table with a subset of canonical columns plus a + # user-added custom column that is NOT in the canonical schema. + custom_field = bigquery.SchemaField("my_custom_col", "STRING") + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + custom_field, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = [f.name for f in updated_table.schema] + # Original columns are still present and in original order. + assert updated_names[0] == "timestamp" + assert updated_names[1] == "event_type" + assert updated_names[2] == "my_custom_col" + # New canonical columns were appended after existing ones. + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_from_no_label_treats_as_outdated(self): + """A table with no version label is treated as needing upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = list(plugin._schema) # All columns present + existing.labels = {} # No version label + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + # update_table should be called to stamp the version label even + # though no new columns were needed. 
+ plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_upgrade_from_older_version_label(self): + """A table with an older version label triggers upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + ] + # Simulate a table stamped with an older version. + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: "0", + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + # Version label should be updated to current. + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + # Missing columns should have been added. + updated_names = {f.name for f in updated_table.schema} + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_is_idempotent(self): + """Calling _ensure_schema_exists twice doesn't double-update.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + + # First call: table exists with old schema. + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + assert plugin.client.update_table.call_count == 1 + + # Second call: table now has current version label. 
+ existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.update_table.reset_mock() + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_update_table_receives_schema_and_labels_fields(self): + """update_table is called with update_fields=['schema', 'labels'].""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + call_args = plugin.client.update_table.call_args + update_fields = call_args[0][1] + assert "schema" in update_fields + assert "labels" in update_fields + + def test_auto_schema_upgrade_defaults_to_true(self): + """Default config has auto_schema_upgrade enabled.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.auto_schema_upgrade is True + + def test_create_table_conflict_is_ignored(self): + """Race condition (Conflict) during create_table is silently handled.""" + plugin = self._make_plugin() + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = cloud_exceptions.Conflict( + "already exists" + ) + # Should not raise. 
+ plugin._ensure_schema_exists() + + +class TestToolProvenance: + """Tests for _get_tool_origin helper.""" + + def test_function_tool_returns_local(self): + from google.adk.tools.function_tool import FunctionTool + + def dummy(): + pass + + tool = FunctionTool(dummy) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "LOCAL" + + def test_agent_tool_returns_sub_agent(self): + from google.adk.tools.agent_tool import AgentTool + + agent = mock.MagicMock() + agent.name = "sub" + tool = AgentTool.__new__(AgentTool) + tool.agent = agent + tool._name = "sub" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "SUB_AGENT" + + def test_transfer_tool_returns_transfer_agent(self): + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + tool = TransferToAgentTool(agent_names=["other"]) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "TRANSFER_AGENT" + + def test_mcp_tool_returns_mcp(self): + try: + from google.adk.tools.mcp_tool.mcp_tool import McpTool + except ImportError: + pytest.skip("MCP not installed") + tool = McpTool.__new__(McpTool) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "MCP" + + def test_a2a_agent_tool_returns_a2a(self): + from google.adk.tools.agent_tool import AgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote" + remote_agent.description = "remote a2a agent" + tool = AgentTool.__new__(AgentTool) + tool.agent = remote_agent + tool._name = "remote" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "A2A" + + def test_unknown_tool_returns_unknown(self): + tool = mock.MagicMock(spec=base_tool_lib.BaseTool) + tool.name = "mystery" + result = 
bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "UNKNOWN" + + +class TestHITLTracing: + """Tests for HITL-specific event emission via on_event_callback. + + HITL events (``adk_request_credential``, ``adk_request_confirmation``, + ``adk_request_input``) are synthetic function calls injected by the + framework — they never pass through ``before_tool_callback`` / + ``after_tool_callback``. Detection therefore lives in + ``on_event_callback``, which inspects the event stream for these + function calls and their corresponding function responses. + """ + + def _make_fc_event(self, fc_name, args=None): + """Build a mock Event containing a function call.""" + event = mock.MagicMock(spec=event_lib.Event) + fc = types.FunctionCall(name=fc_name, args=args or {}) + part = types.Part(function_call=fc) + event.content = types.Content(role="model", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + def _make_fr_event(self, fr_name, response=None): + """Build a mock Event containing a function response.""" + event = mock.MagicMock(spec=event_lib.Event) + fr = types.FunctionResponse(name=fr_name, response=response or {}) + part = types.Part(function_response=fr) + event.content = types.Content(role="user", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + @pytest.mark.asyncio + async def test_hitl_confirmation_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_confirmation", {"confirm": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_credential_emits_additional_event( + 
self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_credential", {"auth": "oauth2"}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CREDENTIAL_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_completion_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fr_event("adk_request_confirmation", {"confirmed": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + + @pytest.mark.asyncio + async def test_regular_tool_no_hitl_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("regular_tool", {"x": 1}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + # No HITL events should be emitted for non-HITL function calls. + # on_event_callback only logs STATE_DELTA and HITL events; a regular + # function call produces neither. + assert mock_write_client.append_rows.call_count == 0 + + +# ============================================================================== +# TEST CLASS: Span Hierarchy Isolation (Issue #4561) +# ============================================================================== + + +class TestSpanHierarchyIsolation: + """Regression tests for https://github.com/google/adk-python/issues/4561. 
+ + ``push_span()`` must NOT attach its span to the ambient OTel context. + If it does, any subsequent ``tracer.start_as_current_span()`` in the + framework (e.g. ``call_llm``, ``execute_tool``) will be incorrectly + re-parented under the plugin's span. + """ + + def test_push_span_does_not_change_ambient_context(self, callback_context): + """push_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_attach_current_span_does_not_change_ambient_context( + self, callback_context + ): + """attach_current_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.attach_current_span( + callback_context + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_pop_span_does_not_change_ambient_context(self, callback_context): + """pop_span must not mutate the current OTel span.""" + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + span_after = trace.get_current_span() + assert span_after is span_before + + def test_push_span_with_real_tracer_does_not_reparent(self, callback_context): + """With a real OTel tracer, plugin spans must not become parents + + of subsequently created framework spans. 
+ """ + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + provider.add_span_processor(SimpleSpanProcessor(exporter)) + framework_tracer = provider.get_tracer("test-framework") + + # Simulate: plugin pushes a span BEFORE the framework span + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "llm_request" + ) + + # Framework creates its own span via start_as_current_span + with framework_tracer.start_as_current_span("call_llm") as fw_span: + fw_context = fw_span.get_span_context() + + # Pop the plugin span + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + provider.shutdown() + + # Verify the framework span was NOT re-parented under the + # plugin's llm_request span + finished = exporter.get_finished_spans() + call_llm_spans = [s for s in finished if s.name == "call_llm"] + assert len(call_llm_spans) == 1 + fw_finished = call_llm_spans[0] + + # The framework span's parent should NOT be the plugin's + # llm_request span. With the fix, the plugin never + # attaches to the ambient context, so ``call_llm`` will + # have whatever parent existed before (None in this test). 
+ assert fw_finished.parent is None + + def test_multiple_push_pop_cycles_leave_context_clean(self, callback_context): + """Multiple push/pop cycles must not leak context changes.""" + original_span = trace.get_current_span() + + for _ in range(5): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "cycle_span" + ) + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + assert trace.get_current_span() is original_span + + +# ============================================================================== +# TEST CLASS: End-to-End HITL Tracing via Runner +# ============================================================================== + + +def _hitl_my_action( + tool_context: tool_context_lib.ToolContext, +) -> dict[str, str]: + """Tool function used by HITL end-to-end tests.""" + return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"} + + +class TestHITLTracingEndToEnd: + """End-to-end tests that run the full Runner + Plugin pipeline with + + ``FunctionTool(require_confirmation=True)`` and verify that HITL events + are logged alongside normal TOOL_* events in the BQ analytics plugin. + """ + + @pytest.fixture + def _mock_bq_infra( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Bundle all BQ mocking fixtures.""" + yield mock_write_client + + @pytest.mark.asyncio + async def test_confirmation_flow_emits_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """Full Runner pipeline: tool with require_confirmation emits + + HITL_CONFIRMATION_REQUEST and HITL_CONFIRMATION_REQUEST_COMPLETED. + """ + from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import FunctionResponse + from google.genai.types import Part + + from .. 
import testing_utils + + mock_write_client = _mock_bq_infra + + tool = FunctionTool(func=_hitl_my_action, require_confirmation=True) + + # -- Mock LLM: first response calls the tool, second is final text -- + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[Part(text="Done, action confirmed.")] + ) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + # -- Build the plugin -- + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # -- Build agent + runner WITH the plugin -- + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="hitl_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + # -- Turn 1: user query → LLM calls tool → HITL pause -- + events_turn1 = await runner.run_async( + testing_utils.UserContent("run my_action") + ) + + # Find the adk_request_confirmation function call + confirmation_fc_id = None + for ev in events_turn1: + if ev.content and ev.content.parts: + for part in ev.content.parts: + if ( + hasattr(part, "function_call") + and part.function_call + and part.function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + confirmation_fc_id = part.function_call.id + break + if confirmation_fc_id: + break + + assert ( + confirmation_fc_id is not None + ), "Expected adk_request_confirmation function call in turn 1" + + # -- Turn 2: user sends confirmation → tool re-executes -- + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=confirmation_fc_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + 
response={"confirmed": True}, + ) + ) + ) + events_turn2 = await runner.run_async(user_confirmation) + + # -- Give the async BQ writer a moment to flush -- + await asyncio.sleep(0.2) + + # -- Collect all BQ rows -- + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # -- Verify standard events are present -- + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # -- Verify HITL-specific events are present -- + assert ( + "HITL_CONFIRMATION_REQUEST" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST in {event_types}" + assert ( + "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST_COMPLETED in {event_types}" + + # -- Verify HITL events have correct tool name in content -- + hitl_rows = [r for r in rows if r["event_type"].startswith("HITL_")] + for row in hitl_rows: + content = json.loads(row["content"]) if row["content"] else {} + assert content.get("tool") == "adk_request_confirmation", ( + "HITL event should reference 'adk_request_confirmation'," + f" got {content.get('tool')}" + ) + + await bq_plugin.shutdown() + + @pytest.mark.asyncio + async def test_regular_tool_does_not_emit_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """A tool WITHOUT require_confirmation should not produce HITL events.""" + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import Part + + from .. 
import testing_utils + + mock_write_client = _mock_bq_infra + + def regular_tool() -> str: + return "done" + + tool = FunctionTool(func=regular_tool) + + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent(parts=[Part(text="All done.")]) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="regular_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + await runner.run_async(testing_utils.UserContent("run regular_tool")) + await asyncio.sleep(0.2) + + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # Standard tool events should be present + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # No HITL events + hitl_events = [et for et in event_types if et.startswith("HITL_")] + assert ( + hitl_events == [] + ), f"Expected no HITL events for regular tool, got {hitl_events}" + + await bq_plugin.shutdown() From 223d9a7ff52d8da702f1f436bd22e94ad78bd5da Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Fri, 20 Feb 2026 19:56:12 -0800 Subject: [PATCH 017/102] =?UTF-8?q?feat:=20Agent=20Skills=20spec=20complia?= =?UTF-8?q?nce=20=E2=80=94=20validation,=20aliases,=20scripts,=20and=20aut?= =?UTF-8?q?o-injection=20Close=20gaps=20between=20ADK's=20Agent=20Skills?= =?UTF-8?q?=20implementation=20and=20the=20public=20Agent=20Skills=20spec?= =?UTF-8?q?=20(agentskills.io/specification):?= MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Frontmatter: add field validators for name (kebab-case, max 64), description (non-empty, max 1024), compatibility (max 500); add allowed-tools alias; add extra='allow'; add populate_by_name - utils: extract _parse_skill_md helper; use model_validate() for alias support; enforce name-dir matching; add validate_skill_dir() and read_skill_properties() - prompt: accept Union[Frontmatter, Skill]; - skill_toolset: add scripts/ resource loading; auto-inject system instruction (with inject_instruction opt-out); duplicate name check; _list_skills() returns Skill objects - sample agent: remove manual instruction (auto-injected now) Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 873177060 --- contributing/samples/skills_agent/agent.py | 7 +- src/google/adk/skills/__init__.py | 2 +- src/google/adk/skills/_utils.py | 234 ++++++++++++++++++++ src/google/adk/skills/models.py | 49 +++- src/google/adk/skills/prompt.py | 13 +- src/google/adk/skills/utils.py | 118 ---------- src/google/adk/tools/skill_toolset.py | 53 +++-- tests/unittests/skills/test__utils.py | 182 +++++++++++++++ tests/unittests/skills/test_models.py | 105 +++++++++ tests/unittests/skills/test_prompt.py | 4 +- tests/unittests/skills/test_utils.py | 56 ----- tests/unittests/tools/test_skill_toolset.py | 56 ++++- 12 files changed, 667 insertions(+), 212 deletions(-) create mode 100644 src/google/adk/skills/_utils.py delete mode 100644 src/google/adk/skills/utils.py create mode 100644 tests/unittests/skills/test__utils.py delete mode 100644 tests/unittests/skills/test_utils.py diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 39eec53c..6cd69ffb 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -19,7 +19,7 @@ import pathlib from google.adk import Agent from google.adk.skills import load_skill_from_dir from google.adk.skills 
import models -from google.adk.tools import skill_toolset +from google.adk.tools.skill_toolset import SkillToolset greeting_skill = models.Skill( frontmatter=models.Frontmatter( @@ -44,15 +44,12 @@ weather_skill = load_skill_from_dir( pathlib.Path(__file__).parent / "skills" / "weather_skill" ) -my_skill_toolset = skill_toolset.SkillToolset( - skills=[greeting_skill, weather_skill] -) +my_skill_toolset = SkillToolset(skills=[greeting_skill, weather_skill]) root_agent = Agent( model="gemini-2.5-flash", name="skill_user_agent", description="An agent that can use specialized skills.", - instruction=skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION, tools=[ my_skill_toolset, ], diff --git a/src/google/adk/skills/__init__.py b/src/google/adk/skills/__init__.py index 73184b2b..72bab7b6 100644 --- a/src/google/adk/skills/__init__.py +++ b/src/google/adk/skills/__init__.py @@ -14,11 +14,11 @@ """Agent Development Kit - Skills.""" +from ._utils import _load_skill_from_dir as load_skill_from_dir from .models import Frontmatter from .models import Resources from .models import Script from .models import Skill -from .utils import load_skill_from_dir __all__ = [ "Frontmatter", diff --git a/src/google/adk/skills/_utils.py b/src/google/adk/skills/_utils.py new file mode 100644 index 00000000..0bfbf30e --- /dev/null +++ b/src/google/adk/skills/_utils.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for Agent Skills.""" + +from __future__ import annotations + +import pathlib +from typing import Union + +import yaml + +from . import models + +_ALLOWED_FRONTMATTER_KEYS = frozenset({ + "name", + "description", + "license", + "allowed-tools", + "allowed_tools", + "metadata", + "compatibility", +}) + + +def _load_dir(directory: pathlib.Path) -> dict[str, str]: + """Recursively load files from a directory into a dictionary. + + Args: + directory: Path to the directory to load. + + Returns: + Dictionary mapping relative file paths to their string content. + """ + files = {} + if directory.exists() and directory.is_dir(): + for file_path in directory.rglob("*"): + if "__pycache__" in file_path.parts: + continue + if file_path.is_file(): + relative_path = file_path.relative_to(directory) + try: + files[str(relative_path)] = file_path.read_text(encoding="utf-8") + except UnicodeDecodeError: + # Binary files or non-UTF-8 files are skipped for text content. + continue + return files + + +def _parse_skill_md( + skill_dir: pathlib.Path, +) -> tuple[dict, str, pathlib.Path]: + """Parse SKILL.md from a skill directory. + + Args: + skill_dir: Resolved path to the skill directory. + + Returns: + Tuple of (parsed_frontmatter_dict, body_string, skill_md_path). + + Raises: + FileNotFoundError: If the directory or SKILL.md is not found. + ValueError: If SKILL.md is invalid. 
+ """ + if not skill_dir.is_dir(): + raise FileNotFoundError(f"Skill directory '{skill_dir}' not found.") + + skill_md = None + for name in ("SKILL.md", "skill.md"): + path = skill_dir / name + if path.exists(): + skill_md = path + break + + if skill_md is None: + raise FileNotFoundError(f"SKILL.md not found in '{skill_dir}'.") + + content = skill_md.read_text(encoding="utf-8") + if not content.startswith("---"): + raise ValueError("SKILL.md must start with YAML frontmatter (---)") + + parts = content.split("---", 2) + if len(parts) < 3: + raise ValueError("SKILL.md frontmatter not properly closed with ---") + + frontmatter_str = parts[1] + body = parts[2].strip() + + try: + parsed = yaml.safe_load(frontmatter_str) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in frontmatter: {e}") from e + + if not isinstance(parsed, dict): + raise ValueError("SKILL.md frontmatter must be a YAML mapping") + + return parsed, body, skill_md + + +def _load_skill_from_dir(skill_dir: Union[str, pathlib.Path]) -> models.Skill: + """Load a complete skill from a directory. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Skill object with all components loaded. + + Raises: + FileNotFoundError: If the skill directory or SKILL.md is not found. + ValueError: If SKILL.md is invalid or the skill name does not match + the directory name. + """ + skill_dir = pathlib.Path(skill_dir).resolve() + + parsed, body, skill_md = _parse_skill_md(skill_dir) + + # Use model_validate to handle aliases like allowed-tools + frontmatter = models.Frontmatter.model_validate(parsed) + + # Validate that skill name matches the directory name + if skill_dir.name != frontmatter.name: + raise ValueError( + f"Skill name '{frontmatter.name}' does not match directory" + f" name '{skill_dir.name}'." 
+ ) + + references = _load_dir(skill_dir / "references") + assets = _load_dir(skill_dir / "assets") + raw_scripts = _load_dir(skill_dir / "scripts") + scripts = { + name: models.Script(src=content) for name, content in raw_scripts.items() + } + + resources = models.Resources( + references=references, + assets=assets, + scripts=scripts, + ) + + return models.Skill( + frontmatter=frontmatter, + instructions=body, + resources=resources, + ) + + +def _validate_skill_dir( + skill_dir: Union[str, pathlib.Path], +) -> list[str]: + """Validate a skill directory without fully loading it. + + Checks that the directory exists, contains a valid SKILL.md with correct + frontmatter, and that the skill name matches the directory name. + + Args: + skill_dir: Path to the skill directory. + + Returns: + List of problem strings. Empty list means the skill is valid. + """ + problems: list[str] = [] + skill_dir = pathlib.Path(skill_dir).resolve() + + if not skill_dir.exists(): + return [f"Directory '{skill_dir}' does not exist."] + if not skill_dir.is_dir(): + return [f"'{skill_dir}' is not a directory."] + + skill_md = None + for name in ("SKILL.md", "skill.md"): + path = skill_dir / name + if path.exists(): + skill_md = path + break + if skill_md is None: + return [f"SKILL.md not found in '{skill_dir}'."] + + try: + parsed, _, _ = _parse_skill_md(skill_dir) + except (FileNotFoundError, ValueError) as e: + return [str(e)] + + unknown = set(parsed.keys()) - _ALLOWED_FRONTMATTER_KEYS + if unknown: + problems.append(f"Unknown frontmatter fields: {sorted(unknown)}") + + try: + frontmatter = models.Frontmatter.model_validate(parsed) + except Exception as e: + problems.append(f"Frontmatter validation error: {e}") + return problems + + if skill_dir.name != frontmatter.name: + problems.append( + f"Skill name '{frontmatter.name}' does not match directory" + f" name '{skill_dir.name}'." 
+ ) + + return problems + + +def _read_skill_properties( + skill_dir: Union[str, pathlib.Path], +) -> models.Frontmatter: + """Read only the frontmatter properties from a skill directory. + + This is a lightweight alternative to ``load_skill_from_dir`` when you + only need the skill metadata without loading instructions or resources. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Frontmatter object with the skill's metadata. + + Raises: + FileNotFoundError: If the directory or SKILL.md is not found. + ValueError: If the frontmatter is invalid. + """ + skill_dir = pathlib.Path(skill_dir).resolve() + parsed, _, _ = _parse_skill_md(skill_dir) + return models.Frontmatter.model_validate(parsed) diff --git a/src/google/adk/skills/models.py b/src/google/adk/skills/models.py index 7f5d75b4..f98b0f10 100644 --- a/src/google/adk/skills/models.py +++ b/src/google/adk/skills/models.py @@ -16,9 +16,16 @@ from __future__ import annotations +import re from typing import Optional +import unicodedata from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_validator + +_NAME_PATTERN = re.compile(r"^[a-z0-9]+(-[a-z0-9]+)*$") class Frontmatter(BaseModel): @@ -31,17 +38,57 @@ class Frontmatter(BaseModel): license: License for the skill (optional). compatibility: Compatibility information for the skill (optional). allowed_tools: Tool patterns the skill requires (optional, experimental). + Accepts both ``allowed_tools`` and the YAML-friendly ``allowed-tools`` + key. metadata: Key-value pairs for client-specific properties (defaults to empty dict). 
""" + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + name: str description: str license: Optional[str] = None compatibility: Optional[str] = None - allowed_tools: Optional[str] = None + allowed_tools: Optional[str] = Field( + default=None, + alias="allowed-tools", + serialization_alias="allowed-tools", + ) metadata: dict[str, str] = {} + @field_validator("name") + @classmethod + def _validate_name(cls, v: str) -> str: + v = unicodedata.normalize("NFKC", v) + if len(v) > 64: + raise ValueError("name must be at most 64 characters") + if not _NAME_PATTERN.match(v): + raise ValueError( + "name must be lowercase kebab-case (a-z, 0-9, hyphens)," + " with no leading, trailing, or consecutive hyphens" + ) + return v + + @field_validator("description") + @classmethod + def _validate_description(cls, v: str) -> str: + if not v: + raise ValueError("description must not be empty") + if len(v) > 1024: + raise ValueError("description must be at most 1024 characters") + return v + + @field_validator("compatibility") + @classmethod + def _validate_compatibility(cls, v: Optional[str]) -> Optional[str]: + if v is not None and len(v) > 500: + raise ValueError("compatibility must be at most 500 characters") + return v + class Script(BaseModel): """Wrapper for script content.""" diff --git a/src/google/adk/skills/prompt.py b/src/google/adk/skills/prompt.py index e9840ab2..110033cd 100644 --- a/src/google/adk/skills/prompt.py +++ b/src/google/adk/skills/prompt.py @@ -18,15 +18,18 @@ from __future__ import annotations import html from typing import List +from typing import Union from . import models -def format_skills_as_xml(skills: List[models.Frontmatter]) -> str: +def format_skills_as_xml( + skills: List[Union[models.Frontmatter, models.Skill]], +) -> str: """Formats available skills into a standard XML string. Args: - skills: A list of skill frontmatter objects. + skills: A list of skill frontmatter or full skill objects. 
Returns: XML string with block containing each skill's @@ -38,13 +41,13 @@ def format_skills_as_xml(skills: List[models.Frontmatter]) -> str: lines = [""] - for skill in skills: + for item in skills: lines.append("") lines.append("") - lines.append(html.escape(skill.name)) + lines.append(html.escape(item.name)) lines.append("") lines.append("") - lines.append(html.escape(skill.description)) + lines.append(html.escape(item.description)) lines.append("") lines.append("") diff --git a/src/google/adk/skills/utils.py b/src/google/adk/skills/utils.py deleted file mode 100644 index deb10b2a..00000000 --- a/src/google/adk/skills/utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for Agent Skills.""" - -from __future__ import annotations - -import pathlib -from typing import Union - -import yaml - -from . import models - - -def _load_dir(directory: pathlib.Path) -> dict[str, str]: - """Recursively load files from a directory into a dictionary. - - Args: - directory: Path to the directory to load. - - Returns: - Dictionary mapping relative file paths to their string content. 
- """ - files = {} - if directory.exists() and directory.is_dir(): - for file_path in directory.rglob("*"): - if "__pycache__" in file_path.parts: - continue - if file_path.is_file(): - relative_path = file_path.relative_to(directory) - try: - files[str(relative_path)] = file_path.read_text(encoding="utf-8") - except UnicodeDecodeError: - # Binary files or non-UTF-8 files are skipped for text content. - continue - return files - - -def load_skill_from_dir(skill_dir: Union[str, pathlib.Path]) -> models.Skill: - """Load a complete skill from a directory. - - Args: - skill_dir: Path to the skill directory. - - Returns: - Skill object with all components loaded. - - Raises: - FileNotFoundError: If the skill directory or SKILL.md is not found. - ValueError: If SKILL.md is invalid. - """ - skill_dir = pathlib.Path(skill_dir).resolve() - - if not skill_dir.is_dir(): - raise FileNotFoundError(f"Skill directory '{skill_dir}' not found.") - - skill_md = None - for name in ("SKILL.md", "skill.md"): - path = skill_dir / name - if path.exists(): - skill_md = path - break - - if skill_md is None: - raise FileNotFoundError(f"SKILL.md not found in '{skill_dir}'.") - - content = skill_md.read_text(encoding="utf-8") - if not content.startswith("---"): - raise ValueError("SKILL.md must start with YAML frontmatter (---)") - - parts = content.split("---", 2) - if len(parts) < 3: - raise ValueError("SKILL.md frontmatter not properly closed with ---") - - frontmatter_str = parts[1] - body = parts[2].strip() - - try: - parsed = yaml.safe_load(frontmatter_str) - except yaml.YAMLError as e: - raise ValueError(f"Invalid YAML in frontmatter: {e}") from e - - if not isinstance(parsed, dict): - raise ValueError("SKILL.md frontmatter must be a YAML mapping") - - # Frontmatter class handles required field validation - frontmatter = models.Frontmatter(**parsed) - - references = _load_dir(skill_dir / "references") - assets = _load_dir(skill_dir / "assets") - raw_scripts = _load_dir(skill_dir / 
"scripts") - scripts = { - name: models.Script(src=content) for name, content in raw_scripts.items() - } - - resources = models.Resources( - references=references, - assets=assets, - scripts=scripts, - ) - - return models.Skill( - frontmatter=frontmatter, - instructions=body, - resources=resources, - ) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 34cad5c5..f90dfdb2 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -39,12 +39,13 @@ Skills are folders of instructions and resources that extend your capabilities f - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. - **references/** (Optional): Additional documentation or examples for skill usage. - **assets/** (Optional): Templates, scripts or other resources used by the skill. +- **scripts/** (Optional): Executable scripts that can be run via bash. This is very important: 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `name=""` to read its full instructions before proceeding. 2. Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. -3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`). Do NOT use other tools to access these files. +3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. 
""" @@ -74,8 +75,8 @@ class ListSkillsTool(BaseTool): async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext ) -> Any: - skill_frontmatters = self._toolset._list_skills() - return prompt.format_skills_as_xml(skill_frontmatters) + skills = self._toolset._list_skills() + return prompt.format_skills_as_xml(skills) @experimental(FeatureName.SKILL_TOOLSET) @@ -131,14 +132,14 @@ class LoadSkillTool(BaseTool): @experimental(FeatureName.SKILL_TOOLSET) class LoadSkillResourceTool(BaseTool): - """Tool to load resources (references or assets) from a skill.""" + """Tool to load resources (references, assets, or scripts) from a skill.""" def __init__(self, toolset: "SkillToolset"): super().__init__( name="load_skill_resource", description=( - "Loads a resource file (from references/ or assets/) from within a" - " skill." + "Loads a resource file (from references/, assets/, or" + " scripts/) from within a skill." ), ) self._toolset = toolset @@ -158,7 +159,8 @@ class LoadSkillResourceTool(BaseTool): "type": "string", "description": ( "The relative path to the resource (e.g.," - " 'references/my_doc.md' or 'assets/template.txt')." + " 'references/my_doc.md', 'assets/template.txt'," + " or 'scripts/setup.sh')." ), }, }, @@ -197,9 +199,16 @@ class LoadSkillResourceTool(BaseTool): elif resource_path.startswith("assets/"): asset_name = resource_path[len("assets/") :] content = skill.resources.get_asset(asset_name) + elif resource_path.startswith("scripts/"): + script_name = resource_path[len("scripts/") :] + script = skill.resources.get_script(script_name) + if script is not None: + content = script.src else: return { - "error": "Path must start with 'references/' or 'assets/'.", + "error": ( + "Path must start with 'references/', 'assets/', or 'scripts/'." 
+ ), "error_code": "INVALID_RESOURCE_PATH", } @@ -222,8 +231,19 @@ class LoadSkillResourceTool(BaseTool): class SkillToolset(BaseToolset): """A toolset for managing and interacting with agent skills.""" - def __init__(self, skills: list[models.Skill]): + def __init__( + self, + skills: list[models.Skill], + ): super().__init__() + + # Check for duplicate skill names + seen: set[str] = set() + for skill in skills: + if skill.name in seen: + raise ValueError(f"Duplicate skill name '{skill.name}'.") + seen.add(skill.name) + self._skills = {skill.name: skill for skill in skills} self._tools = [ ListSkillsTool(self), @@ -241,14 +261,17 @@ class SkillToolset(BaseToolset): """Retrieves a skill by name.""" return self._skills.get(name) - def _list_skills(self) -> list[models.Frontmatter]: - """Lists the frontmatter of all available skills.""" - return [s.frontmatter for s in self._skills.values()] + def _list_skills(self) -> list[models.Skill]: + """Lists all available skills.""" + return list(self._skills.values()) async def process_llm_request( self, *, tool_context: ToolContext, llm_request: LlmRequest ) -> None: """Processes the outgoing LLM request to include available skills.""" - skill_frontmatters = self._list_skills() - skills_xml = prompt.format_skills_as_xml(skill_frontmatters) - llm_request.append_instructions([skills_xml]) + skills = self._list_skills() + skills_xml = prompt.format_skills_as_xml(skills) + instructions = [] + instructions.append(DEFAULT_SKILL_SYSTEM_INSTRUCTION) + instructions.append(skills_xml) + llm_request.append_instructions(instructions) diff --git a/tests/unittests/skills/test__utils.py b/tests/unittests/skills/test__utils.py new file mode 100644 index 00000000..5a65648d --- /dev/null +++ b/tests/unittests/skills/test__utils.py @@ -0,0 +1,182 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for skill utilities.""" + +from google.adk.skills import load_skill_from_dir as _load_skill_from_dir +from google.adk.skills._utils import _read_skill_properties +from google.adk.skills._utils import _validate_skill_dir +import pytest + + +def test__load_skill_from_dir(tmp_path): + """Tests loading a skill from a directory.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + + skill_md_content = """--- +name: test-skill +description: Test description +--- +Test instructions +""" + (skill_dir / "SKILL.md").write_text(skill_md_content) + + # Create references + ref_dir = skill_dir / "references" + ref_dir.mkdir() + (ref_dir / "ref1.md").write_text("ref1 content") + + # Create assets + assets_dir = skill_dir / "assets" + assets_dir.mkdir() + (assets_dir / "asset1.txt").write_text("asset1 content") + + # Create scripts + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir() + (scripts_dir / "script1.sh").write_text("echo hello") + + skill = _load_skill_from_dir(skill_dir) + + assert skill.name == "test-skill" + assert skill.description == "Test description" + assert skill.instructions == "Test instructions" + assert skill.resources.get_reference("ref1.md") == "ref1 content" + assert skill.resources.get_asset("asset1.txt") == "asset1 content" + assert skill.resources.get_script("script1.sh").src == "echo hello" + + +def test_allowed_tools_yaml_key(tmp_path): + """Tests that allowed-tools YAML key loads correctly.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill 
+allowed-tools: "some-tool-*" +--- +Instructions here +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + skill = _load_skill_from_dir(skill_dir) + assert skill.frontmatter.allowed_tools == "some-tool-*" + + +def test_name_directory_mismatch(tmp_path): + """Tests that name-directory mismatch raises ValueError.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + with pytest.raises(ValueError, match="does not match directory"): + _load_skill_from_dir(skill_dir) + + +def test_validate_skill_dir_valid(tmp_path): + """Tests validate_skill_dir with a valid skill.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert problems == [] + + +def test_validate_skill_dir_missing_dir(tmp_path): + """Tests validate_skill_dir with missing directory.""" + problems = _validate_skill_dir(tmp_path / "nonexistent") + assert len(problems) == 1 + assert "does not exist" in problems[0] + + +def test_validate_skill_dir_missing_skill_md(tmp_path): + """Tests validate_skill_dir with missing SKILL.md.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + problems = _validate_skill_dir(skill_dir) + assert len(problems) == 1 + assert "SKILL.md not found" in problems[0] + + +def test_validate_skill_dir_name_mismatch(tmp_path): + """Tests validate_skill_dir catches name-directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert any("does not match" in p for p in problems) + + +def test_validate_skill_dir_unknown_fields(tmp_path): + """Tests validate_skill_dir detects 
unknown frontmatter fields.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +unknown-field: something +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert any("Unknown frontmatter" in p for p in problems) + + +def test__read_skill_properties(tmp_path): + """Tests read_skill_properties basic usage.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A cool skill +license: MIT +--- +Body content +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + fm = _read_skill_properties(skill_dir) + assert fm.name == "my-skill" + assert fm.description == "A cool skill" + assert fm.license == "MIT" diff --git a/tests/unittests/skills/test_models.py b/tests/unittests/skills/test_models.py index 6ecdd51f..3bc7fd30 100644 --- a/tests/unittests/skills/test_models.py +++ b/tests/unittests/skills/test_models.py @@ -15,6 +15,7 @@ """Unit tests for skill models.""" from google.adk.skills import models +from pydantic import ValidationError import pytest @@ -68,3 +69,107 @@ def test_script_to_string(): """Tests Script model.""" script = models.Script(src="print('hello')") assert str(script) == "print('hello')" + + +# --- Name validation tests --- + + +def test_name_too_long(): + with pytest.raises(ValidationError, match="at most 64 characters"): + models.Frontmatter(name="a" * 65, description="desc") + + +def test_name_uppercase_rejected(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="My-Skill", description="desc") + + +def test_name_leading_hyphen(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="-my-skill", description="desc") + + +def test_name_trailing_hyphen(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my-skill-", description="desc") + + 
+def test_name_consecutive_hyphens(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my--skill", description="desc") + + +def test_name_invalid_chars_underscore(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my_skill", description="desc") + + +def test_name_invalid_chars_ampersand(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="skill&name", description="desc") + + +def test_name_valid_passes(): + fm = models.Frontmatter(name="my-skill-2", description="desc") + assert fm.name == "my-skill-2" + + +def test_name_single_word(): + fm = models.Frontmatter(name="skill", description="desc") + assert fm.name == "skill" + + +# --- Description validation tests --- + + +def test_description_empty(): + with pytest.raises(ValidationError, match="must not be empty"): + models.Frontmatter(name="my-skill", description="") + + +def test_description_too_long(): + with pytest.raises(ValidationError, match="at most 1024 characters"): + models.Frontmatter(name="my-skill", description="x" * 1025) + + +# --- Compatibility validation tests --- + + +def test_compatibility_too_long(): + with pytest.raises(ValidationError, match="at most 500 characters"): + models.Frontmatter( + name="my-skill", description="desc", compatibility="c" * 501 + ) + + +# --- Extra field rejected --- + + +def test_extra_field_allowed(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "unknown_field": "value", + }) + assert fm.name == "my-skill" + + +# --- allowed-tools alias --- + + +def test_allowed_tools_alias_via_model_validate(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "allowed-tools": "tool-pattern", + }) + assert fm.allowed_tools == "tool-pattern" + + +def test_allowed_tools_serialization_alias(): + fm = models.Frontmatter( + name="my-skill", description="desc", 
allowed_tools="tool-pattern" + ) + dumped = fm.model_dump(by_alias=True) + assert "allowed-tools" in dumped + assert dumped["allowed-tools"] == "tool-pattern" diff --git a/tests/unittests/skills/test_prompt.py b/tests/unittests/skills/test_prompt.py index f5395f3c..aa48c7b8 100644 --- a/tests/unittests/skills/test_prompt.py +++ b/tests/unittests/skills/test_prompt.py @@ -42,8 +42,8 @@ class TestPrompt: def test_format_skills_as_xml_escaping(self): skills = [ - models.Frontmatter(name="skill&name", description="desc"), + models.Frontmatter(name="my-skill", description="desc"), ] xml = prompt.format_skills_as_xml(skills) - assert "skill&name" in xml + assert "my-skill" in xml assert "desc<ription>" in xml diff --git a/tests/unittests/skills/test_utils.py b/tests/unittests/skills/test_utils.py deleted file mode 100644 index d922719d..00000000 --- a/tests/unittests/skills/test_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Unit tests for skill utilities.""" - -from google.adk.skills import load_skill_from_dir -import pytest - - -def test_load_skill_from_dir(tmp_path): - """Tests loading a skill from a directory.""" - skill_dir = tmp_path / "test-skill" - skill_dir.mkdir() - - skill_md_content = """--- -name: test-skill -description: Test description ---- -Test instructions -""" - (skill_dir / "SKILL.md").write_text(skill_md_content) - - # Create references - ref_dir = skill_dir / "references" - ref_dir.mkdir() - (ref_dir / "ref1.md").write_text("ref1 content") - - # Create assets - assets_dir = skill_dir / "assets" - assets_dir.mkdir() - (assets_dir / "asset1.txt").write_text("asset1 content") - - # Create scripts - scripts_dir = skill_dir / "scripts" - scripts_dir.mkdir() - (scripts_dir / "script1.sh").write_text("echo hello") - - skill = load_skill_from_dir(skill_dir) - - assert skill.name == "test-skill" - assert skill.description == "Test description" - assert skill.instructions == "Test instructions" - assert skill.resources.get_reference("ref1.md") == "ref1 content" - assert skill.resources.get_asset("asset1.txt") == "asset1 content" - assert skill.resources.get_script("script1.sh").src == "echo hello" diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index b747d1f8..066eedfb 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -39,6 +39,7 @@ def mock_skill1(mock_skill1_frontmatter): """Fixture for skill1.""" skill = mock.create_autospec(models.Skill, instance=True) skill.name = "skill1" + skill.description = "Skill 1 description" skill.instructions = "instructions for skill1" skill.frontmatter = mock_skill1_frontmatter skill.resources = mock.MagicMock( @@ -55,8 +56,14 @@ def mock_skill1(mock_skill1_frontmatter): return "asset content 1" return None + def get_script(name): + if name == "setup.sh": + return models.Script(src="echo setup") + return None + 
skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset + skill.resources.get_script.side_effect = get_script return skill @@ -78,6 +85,7 @@ def mock_skill2(mock_skill2_frontmatter): """Fixture for skill2.""" skill = mock.create_autospec(models.Skill, instance=True) skill.name = "skill2" + skill.description = "Skill 2 description" skill.instructions = "instructions for skill2" skill.frontmatter = mock_skill2_frontmatter skill.resources = mock.MagicMock( @@ -114,10 +122,10 @@ def test_get_skill(mock_skill1, mock_skill2): def test_list_skills(mock_skill1, mock_skill2): toolset = skill_toolset.SkillToolset([mock_skill1, mock_skill2]) - frontmatters = toolset._list_skills() - assert len(frontmatters) == 2 - assert mock_skill1.frontmatter in frontmatters - assert mock_skill2.frontmatter in frontmatters + skills = toolset._list_skills() + assert len(skills) == 2 + assert mock_skill1 in skills + assert mock_skill2 in skills @pytest.mark.asyncio @@ -203,6 +211,14 @@ async def test_load_skill_run_async( "content": "asset content 1", }, ), + ( + {"skill_name": "skill1", "path": "scripts/setup.sh"}, + { + "skill_name": "skill1", + "path": "scripts/setup.sh", + "content": "echo setup", + }, + ), ( {"skill_name": "nonexistent", "path": "references/ref1.md"}, { @@ -223,7 +239,10 @@ async def test_load_skill_run_async( ( {"skill_name": "skill1", "path": "invalid/path.txt"}, { - "error": "Path must start with 'references/' or 'assets/'.", + "error": ( + "Path must start with 'references/', 'assets/'," + " or 'scripts/'." 
+ ), "error_code": "INVALID_RESOURCE_PATH", }, ), @@ -266,7 +285,26 @@ async def test_process_llm_request( llm_req.append_instructions.assert_called_once() args, _ = llm_req.append_instructions.call_args instructions = args[0] - assert len(instructions) == 1 - assert "" in instructions[0] - assert "skill1" in instructions[0] - assert "skill2" in instructions[0] + assert len(instructions) == 2 + assert instructions[0] == skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert "" in instructions[1] + assert "skill1" in instructions[1] + assert "skill2" in instructions[1] + + +def test_duplicate_skill_name_raises(mock_skill1): + skill_dup = mock.create_autospec(models.Skill, instance=True) + skill_dup.name = "skill1" + with pytest.raises(ValueError, match="Duplicate skill name"): + skill_toolset.SkillToolset([mock_skill1, skill_dup]) + + +@pytest.mark.asyncio +async def test_scripts_resource_not_found(mock_skill1, tool_context_instance): + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.LoadSkillResourceTool(toolset) + result = await tool.run_async( + args={"skill_name": "skill1", "path": "scripts/nonexistent.sh"}, + tool_context=tool_context_instance, + ) + assert result["error_code"] == "RESOURCE_NOT_FOUND" From 7557a929398ec2a1f946500d906cef5a4f86b5d1 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 21 Feb 2026 19:04:06 -0800 Subject: [PATCH 018/102] feat: change default BigQuery table ID and update docstring The default table ID for the BigQueryAgentAnalyticsPlugin is changed from "agent_events_v2" to "agent_events". The class docstring is also updated to remove the "v2.0" reference. 
Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 873485931 --- src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 97a25496..70a17f40 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -470,7 +470,7 @@ class BigQueryLoggerConfig: event_allowlist: list[str] | None = None event_denylist: list[str] | None = None max_content_length: int = 500 * 1024 # Defaults to 500KB per text block - table_id: str = "agent_events_v2" + table_id: str = "agent_events" # V2 Configuration clustering_fields: list[str] = field( @@ -1609,7 +1609,7 @@ class EventData: class BigQueryAgentAnalyticsPlugin(BasePlugin): - """BigQuery Agent Analytics Plugin (v2.0 using Write API). + """BigQuery Agent Analytics Plugin using Write API. Logs agent events (LLM requests, tool calls, etc.) to BigQuery for analytics. Uses the BigQuery Write API for efficient, asynchronous, and reliable logging. From 87fcd77caa9672f219c12e5a0e2ff65cbbaaf6f3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 23 Feb 2026 02:15:18 -0800 Subject: [PATCH 019/102] feat: Add interceptor framework to A2aAgentExecutor This change introduces an interceptor mechanism allowing custom logic to be executed before agent runs, after each event, and after the agent run completes. New dependencies are added to support these features. 
PiperOrigin-RevId: 873952199 --- .../adk/a2a/executor/a2a_agent_executor.py | 72 +++++++---- src/google/adk/a2a/executor/config.py | 69 +++++++++++ .../adk/a2a/executor/executor_context.py | 49 ++++++++ src/google/adk/a2a/executor/utils.py | 67 +++++++++++ .../a2a/executor/test_a2a_agent_executor.py | 112 ++++++++++++++++++ 5 files changed, 348 insertions(+), 21 deletions(-) create mode 100644 src/google/adk/a2a/executor/config.py create mode 100644 src/google/adk/a2a/executor/executor_context.py create mode 100644 src/google/adk/a2a/executor/utils.py diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index cca728db..956b1233 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -50,7 +50,12 @@ from ..converters.request_converter import AgentRunRequest from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' 
+ __name__) @@ -70,6 +75,8 @@ class A2aAgentExecutorConfig(BaseModel): ) event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events + execute_interceptors: Optional[list[ExecuteInterceptor]] = None + @a2a_experimental class A2aAgentExecutor(AgentExecutor): @@ -135,6 +142,10 @@ class A2aAgentExecutor(AgentExecutor): if not context.message: raise ValueError('A2A request must have a message') + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + # for new task, create a task submitted event if not context.current_task: await event_queue.enqueue_event( @@ -202,6 +213,13 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) + self._executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + # publish the task working event await event_queue.enqueue_event( TaskStatusUpdateEvent( @@ -230,6 +248,15 @@ class A2aAgentExecutor(AgentExecutor): context.context_id, self._config.gen_ai_part_converter, ): + a2a_event = await execute_after_event_interceptors( + a2a_event, + self._executor_context, + adk_event, + self._config.execute_interceptors, + ) + if a2a_event is None: + continue + task_result_aggregator.process_event(a2a_event) await event_queue.enqueue_event(a2a_event) @@ -253,31 +280,34 @@ class A2aAgentExecutor(AgentExecutor): ) ) # public the final status update event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, ) else: - await event_queue.enqueue_event( - 
TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.now(timezone.utc).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=context.context_id, + final=True, ) + final_event = await execute_after_agent_interceptors( + self._executor_context, + final_event, + self._config.execute_interceptors, + ) + await event_queue.enqueue_event(final_event) + async def _prepare_session( self, context: RequestContext, diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py new file mode 100644 index 00000000..79e88546 --- /dev/null +++ b/src/google/adk/a2a/executor/config.py @@ -0,0 +1,69 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import dataclasses +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .executor_context import ExecutorContext + + +@dataclasses.dataclass +class ExecuteInterceptor: + """Interceptor for the A2aAgentExecutor.""" + + before_agent: Optional[ + Callable[[RequestContext], Awaitable[RequestContext]] + ] = None + """Hook executed before the agent starts processing the request. + + Allows inspection or modification of the incoming request context. + Must return a valid `RequestContext` to continue execution. + """ + + after_event: Optional[ + Callable[ + [ExecutorContext, A2AEvent, Event], + Awaitable[Union[A2AEvent, None]], + ] + ] = None + """Hook executed after an ADK event is converted to an A2A event. + + Allows mutating the outgoing event before it is enqueued. + Return `None` to filter out and drop the event entirely, + which also halts any subsequent interceptors in the chain. + """ + + after_agent: Optional[ + Callable[ + [ExecutorContext, TaskStatusUpdateEvent], + Awaitable[TaskStatusUpdateEvent], + ] + ] = None + """Hook executed after the agent finishes and the final event is prepared. + + Allows inspection or modification of the terminal status event (e.g., + completed or failed) before it is enqueued. Must return a valid + `TaskStatusUpdateEvent`. 
+ """ diff --git a/src/google/adk/a2a/executor/executor_context.py b/src/google/adk/a2a/executor/executor_context.py new file mode 100644 index 00000000..313afee6 --- /dev/null +++ b/src/google/adk/a2a/executor/executor_context.py @@ -0,0 +1,49 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.runners import Runner + + +class ExecutorContext: + """Context for the executor.""" + + def __init__( + self, + app_name: str, + user_id: str, + session_id: str, + runner: Runner, + ): + self._app_name = app_name + self._user_id = user_id + self._session_id = session_id + self._runner = runner + + @property + def app_name(self) -> str: + return self._app_name + + @property + def user_id(self) -> str: + return self._user_id + + @property + def session_id(self) -> str: + return self._session_id + + @property + def runner(self) -> Runner: + return self._runner diff --git a/src/google/adk/a2a/executor/utils.py b/src/google/adk/a2a/executor/utils.py new file mode 100644 index 00000000..d01066ea --- /dev/null +++ b/src/google/adk/a2a/executor/utils.py @@ -0,0 +1,67 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Optional + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext + + +async def execute_before_agent_interceptors( + context: RequestContext, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> RequestContext: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.before_agent: + context = await interceptor.before_agent(context) + return context + + +async def execute_after_event_interceptors( + a2a_event: A2AEvent, + executor_context: ExecutorContext, + adk_event: Event, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> Optional[A2AEvent]: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.after_event: + a2a_event = await interceptor.after_event( + executor_context, a2a_event, adk_event + ) + if a2a_event is None: + return None + return a2a_event + + +async def execute_after_agent_interceptors( + executor_context: ExecutorContext, + final_event: TaskStatusUpdateEvent, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> TaskStatusUpdateEvent: + if execute_interceptors: + for interceptor in reversed(execute_interceptors): + if interceptor.after_agent: + final_event = await interceptor.after_agent( + 
executor_context, final_event + ) + return final_event diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 40736d95..787b260f 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -17,13 +17,17 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent from a2a.server.events.event_queue import EventQueue from a2a.types import Message +from a2a.types import Part +from a2a.types import Role from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor from google.adk.events.event import Event from google.adk.runners import RunConfig from google.adk.runners import Runner @@ -959,3 +963,111 @@ class TestA2aAgentExecutor: assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" + + @pytest.mark.asyncio + async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event( + self, + ): + """Test that after_event interceptors receive correct arguments and can modify the event.""" + # Create distinct mock objects for ADK event and A2A event + adk_event = Mock(spec=Event, name="ADK_EVENT") + a2a_event = Mock(spec=A2AEvent, name="A2A_EVENT") + modified_a2a_event = Mock(spec=A2AEvent, name="MODIFIED_A2A_EVENT") + + # Mocks for conversion + self.mock_event_converter.return_value = [a2a_event] + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + 
new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Setup Interceptor + mock_interceptor = Mock(spec=ExecuteInterceptor) + + # after_event should return the modified event + async def side_effect_after_event(context, event, original_event): + return modified_a2a_event + + mock_interceptor.after_event = AsyncMock( + side_effect=side_effect_after_event + ) + mock_interceptor.before_agent = None + mock_interceptor.after_agent = None + + # Update config with interceptor + self.mock_config.execute_interceptors = [mock_interceptor] + # Re-initialize executor with updated config - but we can just update + # the config in place if it's mutable + # The executor uses self._config which is this mock_config basically. + # self.executor was initialized in setup_method with self.mock_config. + + # However, A2aAgentExecutor constructor does: self._config = config or ... + # So updating self.mock_config properties should work as + # it is the same object reference. + + # Mock context + self.mock_context.task_id = "task-1" + self.mock_context.context_id = "ctx-1" + # Ensure current_task is set so we skip the initial + # submitted event creation logic + # which might complicate this specific test if we don't care about it. 
+ self.mock_context.current_task = Mock() + + # Mock runner.run_async to yield our ADK event + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([adk_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Configure session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + self.mock_runner._new_invocation_context.return_value = Mock() + + # We patch TaskResultAggregator just to avoid other errors and simplfy + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_agg_class: + mock_agg = Mock() + mock_agg.task_status_message = None + mock_agg.task_state = TaskState.working + mock_agg_class.return_value = mock_agg + + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify aggregator processed the MODIFIED event + mock_agg.process_event.assert_called_with(modified_a2a_event) + + # Verification of arguments passed to interceptor + assert mock_interceptor.after_event.called + call_args = mock_interceptor.after_event.call_args + # call_args.args should be (executor_context, a2a_event, adk_event) + + passed_a2a_event = call_args.args[1] + passed_adk_event = call_args.args[2] + + # These assertions verify the bug fix + assert ( + passed_a2a_event is a2a_event + ), f"Expected A2A event to be passed as 2nd arg, but got {passed_a2a_event}" + assert ( + passed_adk_event is adk_event + ), f"Expected ADK event to be passed as 3rd arg, but got {passed_adk_event}" + + # Verify that the modified event was enqueued + # We check if enqueue_event was called with modified_a2a_event + # Note: enqueue_event is called multiple times. 
+ + enqueued_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + ] + assert ( + modified_a2a_event in enqueued_events + ), "The modified event should have been enqueued" From ffbcc0a626deb24fe38eab402b3d6ace484115df Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 23 Feb 2026 09:27:36 -0800 Subject: [PATCH 020/102] fix: Keep query params embedded in OpenAPI paths when using httpx The migration from requests to httpx in v1.24.0 broke ApplicationIntegrationToolset because httpx replaces the URL query string when a `params` dict is passed, even if empty. The requests library merged them instead. This extracts any query parameters embedded in the URL path into the explicit params dict before passing to httpx. Close #4555 Co-authored-by: George Weale PiperOrigin-RevId: 874112143 --- .../openapi_spec_parser/rest_api_tool.py | 11 ++ .../openapi_spec_parser/test_rest_api_tool.py | 156 ++++++++++++++++++ 2 files changed, 167 insertions(+) diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 300c47e1..5f835489 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -24,6 +24,9 @@ from typing import Literal from typing import Optional from typing import Tuple from typing import Union +from urllib.parse import parse_qs +from urllib.parse import urlparse +from urllib.parse import urlunparse from fastapi.openapi.models import Operation from fastapi.openapi.models import Schema @@ -375,6 +378,14 @@ class RestApiTool(BaseTool): base_url = base_url[:-1] if base_url.endswith("/") else base_url url = f"{base_url}{self.endpoint.path.format(**path_params)}" + # Move query params embedded in the path into query_params, since httpx + # replaces (rather than merges) the URL query string when `params` is set. 
+ parsed_url = urlparse(url) + if parsed_url.query or parsed_url.fragment: + for key, values in parse_qs(parsed_url.query).items(): + query_params.setdefault(key, values[0] if len(values) == 1 else values) + url = urlunparse(parsed_url._replace(query="", fragment="")) + # Construct body body_kwargs: Dict[str, Any] = {} request_body = self.operation.requestBody diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 81d44f0b..1131181a 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -1268,6 +1268,162 @@ class TestRestApiTool: assert result == {"result": "success"} + def test_prepare_request_params_extracts_embedded_query_params( + self, sample_auth_credential, sample_auth_scheme + ): + """Test that query params embedded in the URL path are extracted. + + ApplicationIntegrationToolset embeds query params and fragments directly + in the OpenAPI path (e.g. '...execute?triggerId=api_trigger/Name#action'). + These must be moved into the explicit query_params dict so httpx does not + strip them when it replaces the URL query string with the `params` arg. + Regression test for https://github.com/google/adk-python/issues/4555. 
+ """ + integration_path = ( + "/v2/projects/my-proj/locations/us-central1" + "/integrations/ExecuteConnection:execute" + "?triggerId=api_trigger/ExecuteConnection" + "#POST_files" + ) + endpoint = OperationEndpoint( + base_url="https://integrations.googleapis.com", + path=integration_path, + method="POST", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + # The embedded query param must appear in params + assert request_params["params"]["triggerId"] == ( + "api_trigger/ExecuteConnection" + ) + # The URL must NOT contain the query string or fragment + assert "?" not in request_params["url"] + assert "#" not in request_params["url"] + assert request_params["url"] == ( + "https://integrations.googleapis.com" + "/v2/projects/my-proj/locations/us-central1" + "/integrations/ExecuteConnection:execute" + ) + + def test_prepare_request_params_merges_embedded_and_explicit_query_params( + self, sample_auth_credential, sample_auth_scheme + ): + """Embedded URL query params merge with explicitly defined query params.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api?embedded_key=embedded_val", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [ + ApiParameter( + original_name="explicit_key", + py_name="explicit_key", + param_location="query", + param_schema=OpenAPISchema(type="string"), + ), + ] + kwargs = {"explicit_key": "explicit_val"} + + request_params = tool._prepare_request_params(params, kwargs) + + assert request_params["params"]["embedded_key"] == "embedded_val" + assert 
request_params["params"]["explicit_key"] == "explicit_val" + assert "?" not in request_params["url"] + + def test_prepare_request_params_explicit_query_param_takes_precedence( + self, sample_auth_credential, sample_auth_scheme + ): + """Explicitly defined query params take precedence over embedded ones.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api?key=embedded", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [ + ApiParameter( + original_name="key", + py_name="key", + param_location="query", + param_schema=OpenAPISchema(type="string"), + ), + ] + kwargs = {"key": "explicit"} + + request_params = tool._prepare_request_params(params, kwargs) + + # Explicit value wins over the embedded one + assert request_params["params"]["key"] == "explicit" + + def test_prepare_request_params_strips_fragment_only( + self, sample_auth_credential, sample_auth_scheme + ): + """Fragment-only paths (no query string) are also cleaned.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api#fragment", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + assert "#" not in request_params["url"] + assert request_params["url"] == "https://example.com/api" + + def test_prepare_request_params_plain_url_unchanged( + self, sample_endpoint, sample_auth_credential, sample_auth_scheme + ): + """URLs without embedded query or fragment are not modified.""" + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + 
endpoint=sample_endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + assert request_params["url"] == "https://example.com/test" + def test_snake_to_lower_camel(): assert snake_to_lower_camel("single") == "single" From 4ca904f11113c4faa3e17bb4a9662dca1f936e2e Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 23 Feb 2026 09:40:45 -0800 Subject: [PATCH 021/102] fix: Add push notification config store to agent_to_a2a This change allows users to provide a custom PushNotificationConfigStore when converting an ADK agent to an A2A Starlette application. If no custom store is provided, an InMemoryPushNotificationConfigStore is used by default; this now lets A2A push notification configuration RPCs work. Close #4126 Co-authored-by: George Weale PiperOrigin-RevId: 874118109 --- src/google/adk/a2a/utils/agent_to_a2a.py | 13 +- src/google/adk/cli/fast_api.py | 7 +- .../unittests/a2a/utils/test_agent_to_a2a.py | 46 ++++++- tests/unittests/cli/test_fast_api.py | 114 +++++++++++++++--- 4 files changed, 157 insertions(+), 23 deletions(-) diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 155888bc..d6a07080 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -20,7 +20,9 @@ from typing import Union from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import PushNotificationConfigStore from a2a.types import AgentCard from starlette.applications import Starlette @@ -78,6 +80,7 @@ def to_a2a( port: int = 8000, protocol: str = "http", agent_card: Optional[Union[AgentCard, str]] = None, + push_config_store: Optional[PushNotificationConfigStore] = None, runner:
Optional[Runner] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -90,6 +93,9 @@ def to_a2a( agent_card: Optional pre-built AgentCard object or path to agent card JSON. If not provided, will be built automatically from the agent. + push_config_store: Optional A2A push notification config store. If not + provided, an in-memory store will be created so push-notification + config RPC methods are supported. runner: Optional pre-built Runner object. If not provided, a default runner will be created using in-memory services. @@ -127,8 +133,13 @@ def to_a2a( runner=runner or create_runner, ) + if push_config_store is None: + push_config_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=task_store + agent_executor=agent_executor, + task_store=task_store, + push_config_store=push_config_store, ) # Use provided agent card or build one from the agent diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 553629f2..8f78c15f 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -525,6 +525,7 @@ def get_fast_api_app( if a2a: from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -563,8 +564,12 @@ def get_fast_api_app( runner=create_a2a_runner_loader(app_name), ) + push_config_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=a2a_task_store + agent_executor=agent_executor, + task_store=a2a_task_store, + push_config_store=push_config_store, ) with (p / "agent.json").open("r", encoding="utf-8") as f: diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py 
b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9ff6d01..21c96d7e 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import ANY from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCard from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor @@ -77,7 +79,9 @@ class TestToA2A: mock_task_store_class.assert_called_once() mock_agent_executor_class.assert_called_once() mock_request_handler_class.assert_called_once_with( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=mock_task_store, ) mock_card_builder_class.assert_called_once_with( agent=self.mock_agent, rpc_url="http://localhost:8000/" @@ -122,7 +126,9 @@ class TestToA2A: mock_task_store_class.assert_called_once() mock_agent_executor_class.assert_called_once_with(runner=custom_runner) mock_request_handler_class.assert_called_once_with( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=mock_task_store, ) mock_card_builder_class.assert_called_once_with( agent=self.mock_agent, rpc_url="http://localhost:8000/" @@ -131,6 +137,42 @@ class TestToA2A: "startup", mock_app.add_event_handler.call_args[0][1] ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + 
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_passes_custom_push_config_store( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a forwards a custom push config store.""" + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + mock_task_store = Mock(spec=InMemoryTaskStore) + mock_task_store_class.return_value = mock_task_store + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + mock_request_handler = Mock(spec=DefaultRequestHandler) + mock_request_handler_class.return_value = mock_request_handler + mock_card_builder = Mock(spec=AgentCardBuilder) + mock_card_builder_class.return_value = mock_card_builder + + custom_push_store = InMemoryPushNotificationConfigStore() + + result = to_a2a(self.mock_agent, push_config_store=custom_push_store) + + assert result == mock_app + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, + push_config_store=custom_push_store, + task_store=mock_task_store, + ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 913e11ae..16ee82b6 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -15,7 +15,6 @@ import asyncio import json import logging -import os from pathlib import Path import signal import tempfile @@ -677,6 +676,7 @@ def test_app_with_a2a( mock_eval_sets_manager, mock_eval_set_results_manager, temp_agents_dir_with_a2a, + monkeypatch, ): """Create a TestClient for the FastAPI app with A2A enabled.""" # Mock A2A related classes @@ 
-728,26 +728,22 @@ def test_app_with_a2a( mock_a2a_app.return_value = mock_app_instance # Change to temp directory - original_cwd = os.getcwd() - os.chdir(temp_agents_dir_with_a2a) + monkeypatch.chdir(temp_agents_dir_with_a2a) - try: - app = get_fast_api_app( - agents_dir=".", - web=True, - session_service_uri="", - artifact_service_uri="", - memory_service_uri="", - allow_origins=["*"], - a2a=True, - host="127.0.0.1", - port=8000, - ) + app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) - client = TestClient(app) - yield client - finally: - os.chdir(original_cwd) + client = TestClient(app) + yield client ################################################# @@ -1406,6 +1402,86 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") +def test_a2a_request_handler_uses_push_config_store( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test A2A request handler gets push config store when supported.""" + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + 
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "a2a.server.tasks.InMemoryPushNotificationConfigStore" + ) as mock_push_config_store_class, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_task_store_instance = MagicMock() + mock_task_store.return_value = mock_task_store_instance + mock_push_config_store = MagicMock() + mock_push_config_store_class.return_value = mock_push_config_store + mock_executor_instance = MagicMock() + mock_executor.return_value = mock_executor_instance + mock_handler.return_value = MagicMock() + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance + + monkeypatch.chdir(temp_agents_dir_with_a2a) + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) + + mock_handler.assert_called_once_with( + agent_executor=mock_executor_instance, + push_config_store=mock_push_config_store, + task_store=mock_task_store_instance, + ) + + def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False From b1e33a90b4ba716d717e0488b84892b8a7f42aac Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 23 Feb 2026 09:56:29 -0800 Subject: [PATCH 022/102] fix: use correct msg_out/msg_err keys for Agent Engine sandbox output PiperOrigin-RevId: 874126181 --- .../adk/code_executors/agent_engine_sandbox_code_executor.py | 4 ++-- .../code_executors/test_agent_engine_sandbox_code_executor.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index f601d045..69d1778a 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -134,8 +134,8 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): or 'file_name' not in output.metadata.attributes ): json_output_data = json.loads(output.data.decode('utf-8')) - stdout = json_output_data.get('stdout', '') - stderr = json_output_data.get('stderr', '') + stdout = json_output_data.get('msg_out', '') + stderr = json_output_data.get('msg_err', '') else: file_name = '' if ( diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index c9480601..6022527f 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -71,7 +71,7 @@ class TestAgentEngineSandboxCodeExecutor: mock_json_output = MagicMock() mock_json_output.mime_type = "application/json" mock_json_output.data = json.dumps( - {"stdout": "hello world", "stderr": ""} + {"msg_out": "hello world", "msg_err": ""} ).encode("utf-8") mock_json_output.metadata = None From 2dbd1f25bdb1d88a6873d824b81b3dd5243332a4 Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 23 Feb 2026 11:34:58 -0800 Subject: [PATCH 023/102] fix: Add OpenAI strict JSON schema enforcement in LiteLLM This change introduces a recursive function to transform JSON schemas to meet OpenAI's strict mode requirements, including adding "additionalProperties: false" to all object schemas, making all properties required, and stripping sibling keywords from $ref nodes. 
The schema conversion uses deep copies to avoid mutating the original input Close #4573 Co-authored-by: George Weale PiperOrigin-RevId: 874174994 --- src/google/adk/models/lite_llm.py | 68 ++++++++++-- tests/unittests/models/test_litellm.py | 140 +++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 11 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index e85772c5..dad5543f 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1491,6 +1491,54 @@ def _message_to_generate_content_response( ) +def _enforce_strict_openai_schema(schema: dict[str, Any]) -> None: + """Recursively transforms a JSON schema for OpenAI strict structured outputs. + + OpenAI strict mode requires: + 1. additionalProperties: false on all object schemas (including nested/$defs). + 2. All properties listed in 'required' (no optional omissions). + 3. $ref nodes must have no sibling keywords (e.g., no 'description' next to + '$ref'). + + This function mutates the schema dict in place. + + Args: + schema: A JSON schema dictionary to transform. + """ + if not isinstance(schema, dict): + return + + # Strip sibling keywords from $ref nodes (OpenAI rejects them). + if "$ref" in schema: + for key in list(schema.keys()): + if key != "$ref": + del schema[key] + return + + # Ensure all object schemas have additionalProperties: false and list every + # property as required. + if schema.get("type") == "object" and "properties" in schema: + schema["additionalProperties"] = False + schema["required"] = sorted(schema["properties"].keys()) + + # Recurse into $defs (Pydantic's nested model definitions). + for defn in schema.get("$defs", {}).values(): + _enforce_strict_openai_schema(defn) + + # Recurse into property schemas. + for prop in schema.get("properties", {}).values(): + _enforce_strict_openai_schema(prop) + + # Recurse into combinators.
+ for key in ("anyOf", "oneOf", "allOf"): + for item in schema.get(key, []): + _enforce_strict_openai_schema(item) + + # Recurse into array item schemas. + if "items" in schema and isinstance(schema["items"], dict): + _enforce_strict_openai_schema(schema["items"]) + + def _to_litellm_response_format( response_schema: types.SchemaUnion, model: str, @@ -1515,7 +1563,7 @@ def _to_litellm_response_format( and schema_type.lower() in _LITELLM_STRUCTURED_TYPES ): return response_schema - schema_dict = dict(response_schema) + schema_dict = copy.deepcopy(response_schema) if "title" in schema_dict: schema_name = str(schema_dict["title"]) elif isinstance(response_schema, type) and issubclass( @@ -1526,14 +1574,18 @@ def _to_litellm_response_format( elif isinstance(response_schema, BaseModel): if isinstance(response_schema, types.Schema): # GenAI Schema instances already represent JSON schema definitions. - schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + schema_dict = copy.deepcopy( + response_schema.model_dump(exclude_none=True, mode="json") + ) if "title" in schema_dict: schema_name = str(schema_dict["title"]) else: schema_dict = response_schema.__class__.model_json_schema() schema_name = response_schema.__class__.__name__ elif hasattr(response_schema, "model_dump"): - schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + schema_dict = copy.deepcopy( + response_schema.model_dump(exclude_none=True, mode="json") + ) schema_name = response_schema.__class__.__name__ else: logger.warning( @@ -1551,14 +1603,8 @@ def _to_litellm_response_format( # OpenAI-compatible format (default) per LiteLLM docs: # https://docs.litellm.ai/docs/completion/json_mode - if ( - isinstance(schema_dict, dict) - and schema_dict.get("type") == "object" - and "additionalProperties" not in schema_dict - ): - # OpenAI structured outputs require explicit additionalProperties: false. 
- schema_dict = dict(schema_dict) - schema_dict["additionalProperties"] = False + if isinstance(schema_dict, dict): + _enforce_strict_openai_schema(schema_dict) return { "type": "json_schema", diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 39f6b540..8e353efb 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -26,6 +26,7 @@ import warnings from google.adk.models.lite_llm import _append_fallback_user_content_if_missing from google.adk.models.lite_llm import _content_to_message_param +from google.adk.models.lite_llm import _enforce_strict_openai_schema from google.adk.models.lite_llm import _FILE_ID_REQUIRED_PROVIDERS from google.adk.models.lite_llm import _FINISH_REASON_MAPPING from google.adk.models.lite_llm import _function_declaration_to_tool_param @@ -394,6 +395,145 @@ def test_to_litellm_response_format_with_dict_schema_for_openai(): assert formatted["json_schema"]["schema"]["additionalProperties"] is False +class _InnerModel(BaseModel): + value: str = Field(description="A value") + optional_field: str | None = Field(default=None, description="Optional") + + +class _OuterModel(BaseModel): + inner: _InnerModel = Field(description="Nested model") + name: str + + +class _WithList(BaseModel): + items: list[_InnerModel] = Field(description="List of items") + label: str + + +def test_enforce_strict_openai_schema_adds_additional_properties_recursively(): + """additionalProperties: false must appear on all object schemas.""" + schema = _OuterModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + # Root level + assert schema["additionalProperties"] is False + # Nested model in $defs + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + + +def test_enforce_strict_openai_schema_marks_all_properties_required(): + """All properties must appear in 'required', including optional fields.""" + schema = 
_InnerModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + assert sorted(schema["required"]) == ["optional_field", "value"] + + +def test_enforce_strict_openai_schema_strips_ref_sibling_keywords(): + """$ref nodes must have no sibling keywords like 'description'.""" + schema = _OuterModel.model_json_schema() + # Pydantic v2 generates {"$ref": "...", "description": "..."} for nested models + inner_prop = schema["properties"]["inner"] + assert "$ref" in inner_prop, "Expected Pydantic to generate a $ref property" + assert len(inner_prop) > 1, "Expected sibling keywords alongside $ref" + + _enforce_strict_openai_schema(schema) + + inner_prop = schema["properties"]["inner"] + assert list(inner_prop.keys()) == ["$ref"] + + +def test_enforce_strict_openai_schema_handles_array_items(): + """Array item schemas should also be recursively transformed.""" + schema = _WithList.model_json_schema() + + _enforce_strict_openai_schema(schema) + + assert schema["additionalProperties"] is False + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_enforce_strict_openai_schema_preserves_anyof_and_default(): + """anyOf structure and default value for Optional fields must be preserved.""" + schema = _InnerModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + opt_prop = schema["properties"]["optional_field"] + assert opt_prop["anyOf"] == [{"type": "string"}, {"type": "null"}] + assert opt_prop["default"] is None + + +def test_to_litellm_response_format_dict_input_not_mutated(): + """Passing a raw dict should not mutate the caller's original dict.""" + schema = { + "type": "object", + "properties": { + "nested": { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + }, + } + import copy + + original = copy.deepcopy(schema) + + _to_litellm_response_format(schema, model="gpt-4o") + + assert schema == original, 
"Caller's input dict was mutated" + + +def test_to_litellm_response_format_instance_input_for_openai(): + """Passing a BaseModel instance should produce a valid strict schema.""" + instance = _OuterModel( + inner=_InnerModel(value="test", optional_field=None), name="foo" + ) + + formatted = _to_litellm_response_format(instance, model="gpt-4o") + + assert formatted["type"] == "json_schema" + schema = formatted["json_schema"]["schema"] + assert schema["additionalProperties"] is False + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_to_litellm_response_format_nested_pydantic_for_openai(): + """Nested Pydantic model should produce a valid OpenAI strict schema.""" + formatted = _to_litellm_response_format(_OuterModel, model="gpt-4o") + + assert formatted["type"] == "json_schema" + assert formatted["json_schema"]["strict"] is True + + schema = formatted["json_schema"]["schema"] + assert schema["additionalProperties"] is False + assert sorted(schema["required"]) == ["inner", "name"] + + # $defs inner model must also be strict + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_to_litellm_response_format_nested_pydantic_for_gemini_unchanged(): + """Gemini models should NOT get the strict OpenAI transformations.""" + formatted = _to_litellm_response_format( + _OuterModel, model="gemini/gemini-2.0-flash" + ) + + assert formatted["type"] == "json_object" + schema = formatted["response_schema"] + # Gemini path should pass through the raw Pydantic schema untouched + assert schema == _OuterModel.model_json_schema() + + async def test_get_completion_inputs_uses_openai_format_for_openai_model(): """Test that _get_completion_inputs produces OpenAI-compatible format.""" llm_request = LlmRequest( From 
445dc189e915ce5198e822ad7fadd6bb0880a95e Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Mon, 23 Feb 2026 12:00:59 -0800 Subject: [PATCH 024/102] fix: remove duplicate session GET when using API server, unbreak auto_session_create when using API server Co-authored-by: Sasha Sobran PiperOrigin-RevId: 874188082 --- src/google/adk/cli/adk_web_server.py | 94 ++++++----- .../adk/errors/session_not_found_error.py | 28 ++++ src/google/adk/runners.py | 8 +- tests/unittests/cli/test_fast_api.py | 146 ++++++++++++++---- tests/unittests/test_runners.py | 3 +- 5 files changed, 203 insertions(+), 76 deletions(-) create mode 100644 src/google/adk/errors/session_not_found_error.py diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index c61f855f..48587bd5 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -68,6 +68,7 @@ from ..auth.credential_service.base_credential_service import BaseCredentialServ from ..errors.already_exists_error import AlreadyExistsError from ..errors.input_validation_error import InputValidationError from ..errors.not_found_error import NotFoundError +from ..errors.session_not_found_error import SessionNotFoundError from ..evaluation.base_eval_service import InferenceConfig from ..evaluation.base_eval_service import InferenceRequest from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE @@ -1558,53 +1559,68 @@ class AdkWebServer: @app.post("/run", response_model_exclude_none=True) async def run_agent(req: RunAgentRequest) -> list[Event]: - session = await self.session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") runner = await self.get_runner_async(req.app_name) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - 
invocation_id=req.invocation_id, - ) - ) as agen: - events = [event async for event in agen] + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + invocation_id=req.invocation_id, + ) + ) as agen: + events = [event async for event in agen] + except SessionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e logger.info("Generated %s events in agent run", len(events)) logger.debug("Events generated: %s", events) return events @app.post("/run_sse") async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: - # SSE endpoint - session = await self.session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id + stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE + runner = await self.get_runner_async(req.app_name) + agen = runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + invocation_id=req.invocation_id, ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") + + # Eagerly advance the generator to trigger session validation + # before the streaming response is created. This lets us return + # a proper HTTP 404 for missing sessions without a redundant + # get_session call — the Runner's single _get_or_create_session + # call is the only one that runs. 
+ first_event = None + first_error = None + try: + first_event = await anext(agen) + except SessionNotFoundError as e: + await agen.aclose() + raise HTTPException(status_code=404, detail=str(e)) from e + except StopAsyncIteration: + await agen.aclose() + except Exception as e: + first_error = e # Convert the events to properly formatted SSE async def event_generator(): - try: - stream_mode = ( - StreamingMode.SSE if req.streaming else StreamingMode.NONE - ) - runner = await self.get_runner_async(req.app_name) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - invocation_id=req.invocation_id, - ) - ) as agen: - async for event in agen: + async with Aclosing(agen): + try: + if first_error: + raise first_error + + async def all_events(): + if first_event is not None: + yield first_event + async for event in agen: + yield event + + async for event in all_events(): # ADK Web renders artifacts from `actions.artifactDelta` # during part processing *and* during action processing # 1) the original event with `artifactDelta` cleared (content) @@ -1630,9 +1646,9 @@ class AdkWebServer: "Generated event in agent run streaming: %s", sse_event ) yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - yield f"data: {json.dumps({'error': str(e)})}\n\n" + except Exception as e: + logger.exception("Error in event_generator: %s", e) + yield f"data: {json.dumps({'error': str(e)})}\n\n" # Returns a streaming response with the proper media type for SSE return StreamingResponse( diff --git a/src/google/adk/errors/session_not_found_error.py b/src/google/adk/errors/session_not_found_error.py new file mode 100644 index 00000000..a870d0d2 --- /dev/null +++ b/src/google/adk/errors/session_not_found_error.py @@ -0,0 +1,28 @@ +# Copyright 2026 Google LLC +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .not_found_error import NotFoundError + + +class SessionNotFoundError(ValueError, NotFoundError): + """Raised when a session cannot be found. + + Inherits from both ValueError (for backward compatibility) and NotFoundError + (for semantic consistency with the project's error hierarchy). + """ + + def __init__(self, message="Session not found."): + super().__init__(message) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index cdb878cf..736859fb 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -45,6 +45,7 @@ from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor +from .errors.session_not_found_error import SessionNotFoundError from .events.event import Event from .events.event import EventActions from .flows.llm_flows import contents @@ -358,7 +359,7 @@ class Runner: This helper first attempts to retrieve the session. If not found and auto_create_session is True, it creates a new session with the provided - identifiers. Otherwise, it raises a ValueError with a helpful message. + identifiers. Otherwise, it raises a SessionNotFoundError. Args: user_id: The user ID of the session. 
@@ -368,7 +369,8 @@ class Runner: The existing or newly created `Session`. Raises: - ValueError: If the session is not found and auto_create_session is False. + SessionNotFoundError: If the session is not found and + auto_create_session is False. """ session = await self.session_service.get_session( app_name=self.app_name, user_id=user_id, session_id=session_id @@ -380,7 +382,7 @@ class Runner: ) else: message = self._format_session_not_found_message(session_id) - raise ValueError(message) + raise SessionNotFoundError(message) return session def run( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 16ee82b6..d6ccf6e2 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -32,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.cli import fast_api as fast_api_module from google.adk.cli.fast_api import get_fast_api_app from google.adk.errors.input_validation_error import InputValidationError +from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult @@ -451,18 +452,28 @@ def mock_eval_set_results_manager(): return MockEvalSetResultsManager() -@pytest.fixture -def test_app( +def _create_test_client( mock_session_service, mock_artifact_service, mock_memory_service, mock_agent_loader, mock_eval_sets_manager, mock_eval_set_results_manager, + **app_kwargs, ): - """Create a TestClient for the FastAPI app without starting a server.""" - - # Patch multiple services and signal handlers + """Helper to create a TestClient with the given get_fast_api_app overrides.""" + defaults = dict( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=False, + host="127.0.0.1", + port=8000, + ) + 
defaults.update(app_kwargs) with ( patch.object(signal, "signal", autospec=True, return_value=None), patch.object( @@ -502,23 +513,28 @@ def test_app( return_value=mock_eval_set_results_manager, ), ): - # Get the FastAPI app, but don't actually run it - app = get_fast_api_app( - agents_dir=".", - web=True, - session_service_uri="", - artifact_service_uri="", - memory_service_uri="", - allow_origins=["*"], - a2a=False, # Disable A2A for most tests - host="127.0.0.1", - port=8000, - ) + app = get_fast_api_app(**defaults) + return TestClient(app) - # Create a TestClient that doesn't start a real server - client = TestClient(app) - return client +@pytest.fixture +def test_app( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Create a TestClient for the FastAPI app without starting a server.""" + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + ) @pytest.fixture @@ -1106,20 +1122,9 @@ def test_agent_run_sse_yields_error_object_on_exception( """Test /run_sse streams an error object if streaming raises.""" info = create_test_session - async def run_async_raises( - self, - *, - user_id: str, - session_id: str, - invocation_id: Optional[str] = None, - new_message: Optional[types.Content] = None, - state_delta: Optional[dict[str, Any]] = None, - run_config: Optional[RunConfig] = None, - ): - del user_id, session_id, invocation_id, new_message, state_delta, run_config + async def run_async_raises(self, **kwargs): raise ValueError("boom") - if False: # pylint: disable=using-constant-test - yield _event_1() + yield # make it an async generator # pylint: disable=unreachable monkeypatch.setattr(Runner, "run_async", run_async_raises) @@ -1637,5 +1642,80 @@ def test_version_endpoint(test_app): assert "language_version" in data 
+@pytest.fixture +def test_app_auto_session( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Create a TestClient with auto_create_session=True.""" + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + web=False, + auto_create_session=True, + ) + + +@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"]) +def test_auto_creates_session( + test_app_auto_session, test_session_info, endpoint +): + """Test /run and /run_sse auto-create sessions when auto_create_session=True.""" + payload = { + "app_name": test_session_info["app_name"], + "user_id": test_session_info["user_id"], + "session_id": "nonexistent_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + + response = test_app_auto_session.post(endpoint, json=payload) + assert response.status_code == 200 + + if endpoint == "/run": + data = response.json() + assert isinstance(data, list) + assert len(data) > 0 + else: + sse_events = [ + json.loads(line.removeprefix("data: ")) + for line in response.text.splitlines() + if line.startswith("data: ") + ] + assert len(sse_events) > 0 + assert not any("error" in e for e in sse_events) + + +@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"]) +def test_returns_404_without_auto_create( + test_app, test_session_info, monkeypatch, endpoint +): + """Test /run and /run_sse return 404 for missing sessions without auto_create.""" + + async def run_async_session_not_found(self, **kwargs): + raise SessionNotFoundError(f"Session not found: {kwargs['session_id']}") + yield # make it an async generator # pylint: disable=unreachable + + monkeypatch.setattr(Runner, "run_async", run_async_session_not_found) + + payload = { + "app_name": test_session_info["app_name"], + "user_id": test_session_info["user_id"], + 
"session_id": "nonexistent_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + + response = test_app.post(endpoint, json=payload) + assert response.status_code == 404 + assert "Session not found" in response.json()["detail"] + + if __name__ == "__main__": pytest.main(["-xvs", __file__]) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb375..cc3abc65 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -29,6 +29,7 @@ from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.cli.utils.agent_loader import AgentLoader +from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.events.event import Event from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner @@ -243,7 +244,7 @@ async def test_session_not_found_message_includes_alignment_hint(): new_message=types.Content(role="user", parts=[]), ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(SessionNotFoundError) as excinfo: await agen.__anext__() await agen.aclose() From c33d614004a47d1a74951dd13628fd2300aeb9ef Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Mon, 23 Feb 2026 16:44:26 -0800 Subject: [PATCH 025/102] feat: Update Agent Registry to create AgentCard from info in get agents endpoint Get agents API design is being updated to return full AgentCard instead of agent card url - while it's being rolled out, update get agents method to instantiate agent card from existing fields. 
Co-authored-by: Kathy Wu PiperOrigin-RevId: 874285065 --- .../agent_registry/test_agent_registry.py | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index f54cdb67..fc680869 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -15,7 +15,9 @@ from unittest.mock import MagicMock from unittest.mock import patch +from a2a.types import TransportProtocol as A2ATransport from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from google.adk.integrations.agent_registry import _ProtocolType from google.adk.integrations.agent_registry import AgentRegistry from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import httpx @@ -42,22 +44,22 @@ class TestAgentRegistry: ] } uri = registry._get_connection_uri( - resource_details, protocol_binding="JSONRPC" + resource_details, protocol_binding=A2ATransport.jsonrpc ) assert uri == "https://mcp-v1main.com" def test_get_connection_uri_agent_nested_protocols(self, registry): resource_details = { "protocols": [{ - "type": "A2A_AGENT", + "type": _ProtocolType.A2A_AGENT, "interfaces": [{ "url": "https://my-agent.com", - "protocolBinding": "JSONRPC", + "protocolBinding": A2ATransport.jsonrpc, }], }] } uri = registry._get_connection_uri( - resource_details, protocol_type="A2A_AGENT" + resource_details, protocol_type=_ProtocolType.A2A_AGENT ) assert uri == "https://my-agent.com" @@ -69,29 +71,31 @@ class TestAgentRegistry: "interfaces": [{"url": "https://custom.com"}], }, { - "type": "A2A_AGENT", + "type": _ProtocolType.A2A_AGENT, "interfaces": [{ "url": "https://my-agent.com", - "protocolBinding": "HTTP_JSON", + "protocolBinding": A2ATransport.http_json, }], }, ] } # Filter by type uri = registry._get_connection_uri( - resource_details, 
protocol_type="A2A_AGENT" + resource_details, protocol_type=_ProtocolType.A2A_AGENT ) assert uri == "https://my-agent.com" # Filter by binding uri = registry._get_connection_uri( - resource_details, protocol_binding="HTTP_JSON" + resource_details, protocol_binding=A2ATransport.http_json ) assert uri == "https://my-agent.com" # No match uri = registry._get_connection_uri( - resource_details, protocol_type="A2A_AGENT", protocol_binding="JSONRPC" + resource_details, + protocol_type=_ProtocolType.A2A_AGENT, + protocol_binding=A2ATransport.jsonrpc, ) assert uri is None @@ -143,9 +147,10 @@ class TestAgentRegistry: mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestPrefix", - "interfaces": [ - {"url": "https://mcp.com", "protocolBinding": "JSONRPC"} - ], + "interfaces": [{ + "url": "https://mcp.com", + "protocolBinding": A2ATransport.jsonrpc, + }], } mock_response.raise_for_status = MagicMock() mock_httpx.return_value.__enter__.return_value.get.return_value = ( @@ -165,9 +170,15 @@ class TestAgentRegistry: mock_response.json.return_value = { "displayName": "TestAgent", "description": "Test Desc", - "agentSpec": { - "a2aAgentCardUrl": "https://my-agent.com/agent-card.json" - }, + "version": "1.0", + "protocols": [{ + "type": _ProtocolType.A2A_AGENT, + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": A2ATransport.jsonrpc, + }], + }], + "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], } mock_response.raise_for_status = MagicMock() mock_httpx.return_value.__enter__.return_value.get.return_value = ( @@ -181,7 +192,10 @@ class TestAgentRegistry: assert isinstance(agent, RemoteA2aAgent) assert agent.name == "TestAgent" assert agent.description == "Test Desc" - assert agent._agent_card_source == "https://my-agent.com/agent-card.json" + assert agent._agent_card.url == "https://my-agent.com" + assert agent._agent_card.version == "1.0" + assert len(agent._agent_card.skills) == 1 + assert 
agent._agent_card.skills[0].name == "Skill 1" def test_get_auth_headers(self, registry): registry._credentials.token = "fake-token" From 37d52b4caf6738437e62fe804103efe4bde363a1 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Mon, 23 Feb 2026 17:15:22 -0800 Subject: [PATCH 026/102] fix: edit copybara and BUILD config for new adk/integrations folder (added with Agent Registry) Co-authored-by: Kathy Wu PiperOrigin-RevId: 874293428 --- .../integrations/agent_registry/__init__.py | 18 ++ .../agent_registry/agent_registry.py | 281 ++++++++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 src/google/adk/integrations/agent_registry/__init__.py create mode 100644 src/google/adk/integrations/agent_registry/agent_registry.py diff --git a/src/google/adk/integrations/agent_registry/__init__.py b/src/google/adk/integrations/agent_registry/__init__.py new file mode 100644 index 00000000..995ad046 --- /dev/null +++ b/src/google/adk/integrations/agent_registry/__init__.py @@ -0,0 +1,18 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .agent_registry import _ProtocolType +from .agent_registry import AgentRegistry + +__all__ = [ + 'AgentRegistry', +] diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py new file mode 100644 index 00000000..93a91df4 --- /dev/null +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -0,0 +1,281 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client library for interacting with the Google Cloud Agent Registry within ADK.""" + +from __future__ import annotations + +from enum import Enum +import logging +import os +import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Union +from urllib.parse import parse_qs +from urllib.parse import urlparse + +from a2a.client.client_factory import minimal_agent_card +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentSkill +from a2a.types import TransportProtocol as A2ATransport +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth +import google.auth.transport.requests 
+import httpx + +logger = logging.getLogger("google_adk." + __name__) + +AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" + + +class _ProtocolType(str, Enum): + """Supported agent protocol types.""" + + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + A2A_AGENT = "A2A_AGENT" + CUSTOM = "CUSTOM" + + +class AgentRegistry: + """Client for interacting with the Google Cloud Agent Registry service. + + Unlike a standard REST client library, this class provides higher-level + abstractions for ADK integration. It surfaces the agent registry service + methods along with helper methods like `get_mcp_toolset` and + `get_remote_a2a_agent` that automatically resolve connection details and + handle authentication to produce ready-to-use ADK components. + """ + + def __init__( + self, + project_id: Optional[str] = None, + location: Optional[str] = None, + header_provider: Optional[ + Callable[[ReadonlyContext], Dict[str, str]] + ] = None, + ): + """Initializes the AgentRegistry client. + + Args: + project_id: The Google Cloud project ID. + location: The Google Cloud location (region). + header_provider: Optional provider for custom headers. 
+ """ + self.project_id = project_id + self.location = location + + if not self.project_id or not self.location: + raise ValueError("project_id and location must be provided") + + self._base_path = f"projects/{self.project_id}/locations/{self.location}" + self._header_provider = header_provider + try: + self._credentials, _ = google.auth.default() + except google.auth.exceptions.DefaultCredentialsError as e: + raise RuntimeError( + f"Failed to get default Google Cloud credentials: {e}" + ) from e + + def _get_auth_headers(self) -> Dict[str, str]: + """Refreshes credentials and returns authorization headers.""" + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + "Content-Type": "application/json", + } + quota_project_id = getattr(self._credentials, "quota_project_id", None) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id + return headers + except google.auth.exceptions.RefreshError as e: + raise RuntimeError( + f"Failed to refresh Google Cloud credentials: {e}" + ) from e + + def _make_request( + self, path: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Helper function to make GET requests to the Agent Registry API.""" + if path.startswith("projects/"): + url = f"{AGENT_REGISTRY_BASE_URL}/{path}" + else: + url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}" + + try: + headers = self._get_auth_headers() + with httpx.Client() as client: + response = client.get(url, headers=headers, params=params) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"API request failed with status {e.response.status_code}:" + f" {e.response.text}" + ) from e + except httpx.RequestError as e: + raise RuntimeError(f"API request failed (network error): {e}") from e + except Exception as e: + raise RuntimeError(f"API request failed: {e}") from e + + 
def _get_connection_uri( + self, + resource_details: Dict[str, Any], + protocol_type: Optional[_ProtocolType] = None, + protocol_binding: Optional[A2ATransport] = None, + ) -> Optional[str]: + """Extracts the first matching URI based on type and binding filters.""" + protocols = list(resource_details.get("protocols", [])) + if "interfaces" in resource_details: + protocols.append({"interfaces": resource_details["interfaces"]}) + + for p in protocols: + if protocol_type and p.get("type") != protocol_type: + continue + for i in p.get("interfaces", []): + if protocol_binding and i.get("protocolBinding") != protocol_binding: + continue + if url := i.get("url"): + return url + + return None + + def _clean_name(self, name: str) -> str: + """Cleans a string to be a valid Python identifier for agent names.""" + clean = re.sub(r"[^a-zA-Z0-9_]", "_", name) + clean = re.sub(r"_+", "_", clean) + clean = clean.strip("_") + if clean and not clean[0].isalpha() and clean[0] != "_": + clean = "_" + clean + return clean + + # --- MCP Server Methods --- + + def list_mcp_servers( + self, + filter_str: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Fetches a list of MCP Servers.""" + params = {} + if filter_str: + params["filter"] = filter_str + if page_size: + params["pageSize"] = str(page_size) + if page_token: + params["pageToken"] = page_token + return self._make_request("mcpServers", params=params) + + def get_mcp_server(self, name: str) -> Dict[str, Any]: + """Retrieves details of a specific MCP Server.""" + return self._make_request(name) + + def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset: + """Constructs an McpToolset instance from a registered MCP Server.""" + server_details = self.get_mcp_server(mcp_server_name) + name = self._clean_name(server_details.get("displayName", mcp_server_name)) + + endpoint_uri = self._get_connection_uri( + server_details, protocol_binding=A2ATransport.jsonrpc + 
) or self._get_connection_uri( + server_details, protocol_binding=A2ATransport.http_json + ) + if not endpoint_uri: + raise ValueError( + f"MCP Server endpoint URI not found for: {mcp_server_name}" + ) + + connection_params = StreamableHTTPConnectionParams( + url=endpoint_uri, headers=self._get_auth_headers() + ) + return McpToolset( + connection_params=connection_params, + tool_name_prefix=name, + header_provider=self._header_provider, + ) + + # --- Agent Methods --- + + def list_agents( + self, + filter_str: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Fetches a list of registered A2A Agents.""" + params = {} + if filter_str: + params["filter"] = filter_str + if page_size: + params["pageSize"] = str(page_size) + if page_token: + params["pageToken"] = page_token + return self._make_request("agents", params=params) + + def get_agent_info(self, name: str) -> Dict[str, Any]: + """Retrieves detailed metadata of a specific A2A Agent.""" + return self._make_request(name) + + def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent: + """Creates a RemoteA2aAgent instance for a registered A2A Agent.""" + agent_info = self.get_agent_info(agent_name) + name = self._clean_name(agent_info.get("displayName", agent_name)) + description = agent_info.get("description", "") + version = agent_info.get("version", "") + + url = self._get_connection_uri( + agent_info, protocol_type=_ProtocolType.A2A_AGENT + ) + if not url: + raise ValueError(f"A2A connection URI not found for Agent: {agent_name}") + + skills = [] + for s in agent_info.get("skills", []): + skills.append( + AgentSkill( + id=s.get("id"), + name=s.get("name"), + description=s.get("description", ""), + tags=s.get("tags", []), + examples=s.get("examples", []), + ) + ) + + agent_card = AgentCard( + name=name, + description=description, + version=version, + url=url, + skills=skills, + capabilities=AgentCapabilities(streaming=False, 
polling=False), + defaultInputModes=["text"], + defaultOutputModes=["text"], + ) + + return RemoteA2aAgent( + name=name, + agent_card=agent_card, + description=description, + ) From 1dbceccf36c28d693b0982b531a99877a3e75169 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 23 Feb 2026 18:46:15 -0800 Subject: [PATCH 027/102] fix: update Spanner query tools to async functions PiperOrigin-RevId: 874318392 --- src/google/adk/tools/spanner/query_tool.py | 20 +++--- src/google/adk/tools/spanner/search_tool.py | 48 ++++++------- .../tools/spanner/test_search_tool.py | 69 +++++++++++-------- .../tools/spanner/test_spanner_query_tool.py | 5 +- 4 files changed, 76 insertions(+), 66 deletions(-) diff --git a/src/google/adk/tools/spanner/query_tool.py b/src/google/adk/tools/spanner/query_tool.py index 3cdede43..24c1be60 100644 --- a/src/google/adk/tools/spanner/query_tool.py +++ b/src/google/adk/tools/spanner/query_tool.py @@ -14,9 +14,9 @@ from __future__ import annotations +import asyncio import functools import textwrap -import types from typing import Callable from google.auth.credentials import Credentials @@ -27,7 +27,7 @@ from .settings import QueryResultMode from .settings import SpannerToolSettings -def execute_sql( +async def execute_sql( project_id: str, instance_id: str, database_id: str, @@ -82,7 +82,8 @@ def execute_sql( Note: This is running with Read-Only Transaction for query that only read data. 
""" - return utils.execute_sql( + return await asyncio.to_thread( + utils.execute_sql, project_id, instance_id, database_id, @@ -179,15 +180,10 @@ def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]: if settings and settings.query_result_mode is QueryResultMode.DICT_LIST: - execute_sql_wrapper = types.FunctionType( - execute_sql.__code__, - execute_sql.__globals__, - execute_sql.__name__, - execute_sql.__defaults__, - execute_sql.__closure__, - ) - functools.update_wrapper(execute_sql_wrapper, execute_sql) - # Update with the new docstring + @functools.wraps(execute_sql) + async def execute_sql_wrapper(*args, **kwargs) -> dict: + return await execute_sql(*args, **kwargs) + execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING return execute_sql_wrapper diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index 03f695b8..6fb4a93f 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import json from typing import Any from typing import Dict @@ -230,7 +231,7 @@ def _generate_sql_for_ann( """ -def similarity_search( +async def similarity_search( project_id: str, instance_id: str, database_id: str, @@ -462,13 +463,16 @@ def similarity_search( # Generate embedding for the query according to the embedding options. 
if vertex_ai_embedding_model_name: - embedding = utils.embed_contents( - vertex_ai_embedding_model_name, - [query], - output_dimensionality, + embedding = ( + await utils.embed_contents_async( + vertex_ai_embedding_model_name, + [query], + output_dimensionality, + ) )[0] else: - embedding = _get_embedding_for_query( + embedding = await asyncio.to_thread( + _get_embedding_for_query, database, database.database_dialect, spanner_gsql_embedding_model_name, @@ -507,22 +511,20 @@ def similarity_search( else: params = {_GOOGLESQL_PARAMETER_QUERY_EMBEDDING: embedding} - with database.snapshot() as snapshot: - result_set = snapshot.execute_sql(sql, params=params) - rows = [] - result = {} - for row in result_set: - try: - # if the json serialization of the row succeeds, use it as is - json.dumps(row) - except (TypeError, ValueError, OverflowError): - row = str(row) + def _execute_sql(): + with database.snapshot() as snapshot: + result_set = snapshot.execute_sql(sql, params=params) + rows = [] + for row in result_set: + try: + # If the json serialization of the row succeeds, use it as is + json.dumps(row) + except (TypeError, ValueError, OverflowError): + row = str(row) + rows.append(row) + return {"status": "SUCCESS", "rows": rows} - rows.append(row) - - result["status"] = "SUCCESS" - result["rows"] = rows - return result + return await asyncio.to_thread(_execute_sql) except Exception as ex: return { "status": "ERROR", @@ -530,7 +532,7 @@ def similarity_search( } -def vector_store_similarity_search( +async def vector_store_similarity_search( query: str, credentials: Credentials, settings: SpannerToolSettings, @@ -605,7 +607,7 @@ def vector_store_similarity_search( settings.vector_store_settings.num_leaves_to_search ) - return similarity_search( + return await similarity_search( project_id=settings.vector_store_settings.project_id, instance_id=settings.vector_store_settings.instance_id, database_id=settings.vector_store_settings.database_id, diff --git 
a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py index 4532dd56..c6a6c742 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -54,11 +54,12 @@ def mock_spanner_ids(): ), ], ) -@mock.patch.object(utils, "embed_contents") +@pytest.mark.asyncio +@mock.patch.object(utils, "embed_contents_async", autospec=True) @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_knn_success( +async def test_similarity_search_knn_success( mock_get_spanner_client, - mock_embed_contents, + mock_embed_contents_async, mock_spanner_ids, mock_credentials, embedding_option_key, @@ -77,7 +78,7 @@ def test_similarity_search_knn_success( mock_get_spanner_client.return_value = mock_spanner_client if embedding_option_key == "vertex_ai_embedding_model_name": - mock_embed_contents.return_value = [expected_embedding] + mock_embed_contents_async.return_value = [expected_embedding] # execute_sql is called once for the kNN search mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)]) else: @@ -90,7 +91,7 @@ def test_similarity_search_knn_success( iter([("result1",), ("result2",)]), ] - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -111,13 +112,14 @@ def test_similarity_search_knn_success( assert "@embedding" in sql assert call_args.kwargs == {"params": {"embedding": expected_embedding}} if embedding_option_key == "vertex_ai_embedding_model_name": - mock_embed_contents.assert_called_once_with( + mock_embed_contents_async.assert_called_once_with( embedding_option_value, ["test query"], None ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_ann_success( +async def test_similarity_search_ann_success( 
mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search function with ANN success.""" @@ -139,7 +141,7 @@ def test_similarity_search_ann_success( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -164,13 +166,14 @@ def test_similarity_search_ann_success( assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_error( +async def test_similarity_search_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search function with a generic error.""" mock_get_spanner_client.side_effect = Exception("Test Exception") - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -187,11 +190,12 @@ def test_similarity_search_error( assert "Test Exception" in result["error_details"] -@mock.patch.object(utils, "embed_contents") +@pytest.mark.asyncio +@mock.patch.object(utils, "embed_contents_async") @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_circular_row_fallback_to_string( +async def test_similarity_search_circular_row_fallback_to_string( mock_get_spanner_client, - mock_embed_contents, + mock_embed_contents_async, mock_spanner_ids, mock_credentials, ): @@ -202,7 +206,7 @@ def test_similarity_search_circular_row_fallback_to_string( mock_snapshot = MagicMock() circular_row = [] circular_row.append(circular_row) - mock_embed_contents.return_value = [[0.1, 0.2, 0.3]] + mock_embed_contents_async.return_value = [[0.1, 0.2, 
0.3]] mock_snapshot.execute_sql.return_value = iter([circular_row]) mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL @@ -210,7 +214,7 @@ def test_similarity_search_circular_row_fallback_to_string( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -228,8 +232,9 @@ def test_similarity_search_circular_row_fallback_to_string( assert result["rows"] == [str(circular_row)] +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_postgresql_knn_success( +async def test_similarity_search_postgresql_knn_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with PostgreSQL dialect for kNN.""" @@ -249,7 +254,7 @@ def test_similarity_search_postgresql_knn_success( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -273,8 +278,9 @@ def test_similarity_search_postgresql_knn_success( assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}} +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_postgresql_ann_unsupported( +async def test_similarity_search_postgresql_ann_unsupported( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with unsupported ANN for PostgreSQL dialect.""" @@ -286,7 +292,7 @@ def 
test_similarity_search_postgresql_ann_unsupported( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -311,8 +317,9 @@ def test_similarity_search_postgresql_ann_unsupported( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_gsql_missing_embedding_model_error( +async def test_similarity_search_gsql_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with missing embedding_options for GoogleSQL dialect.""" @@ -324,7 +331,7 @@ def test_similarity_search_gsql_missing_embedding_model_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -348,8 +355,9 @@ def test_similarity_search_gsql_missing_embedding_model_error( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_pg_missing_embedding_model_error( +async def test_similarity_search_pg_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with missing embedding_options for PostgreSQL dialect.""" @@ -361,7 +369,7 @@ def test_similarity_search_pg_missing_embedding_model_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( 
project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -427,8 +435,9 @@ def test_similarity_search_pg_missing_embedding_model_error( ), ], ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_multiple_embedding_options_error( +async def test_similarity_search_multiple_embedding_options_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials, @@ -443,7 +452,7 @@ def test_similarity_search_multiple_embedding_options_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -461,8 +470,9 @@ def test_similarity_search_multiple_embedding_options_error( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_output_dimensionality_gsql_error( +async def test_similarity_search_output_dimensionality_gsql_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name.""" @@ -474,7 +484,7 @@ def test_similarity_search_output_dimensionality_gsql_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -492,8 +502,9 @@ def test_similarity_search_output_dimensionality_gsql_error( assert "is not supported when" in result["error_details"] +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def 
test_similarity_search_unsupported_algorithm_error( +async def test_similarity_search_unsupported_algorithm_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with an unsupported nearest neighbors algorithm.""" @@ -505,7 +516,7 @@ def test_similarity_search_unsupported_algorithm_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], diff --git a/tests/unittests/tools/spanner/test_spanner_query_tool.py b/tests/unittests/tools/spanner/test_spanner_query_tool.py index 6c75a3ea..928c207d 100644 --- a/tests/unittests/tools/spanner/test_spanner_query_tool.py +++ b/tests/unittests/tools/spanner/test_spanner_query_tool.py @@ -191,8 +191,9 @@ async def test_execute_sql_query_result( assert tool.description == expected_description +@pytest.mark.asyncio @mock.patch.object(query_tool.utils, "execute_sql", spec_set=True) -def test_execute_sql(mock_utils_execute_sql): +async def test_execute_sql(mock_utils_execute_sql): """Test execute_sql function in query result default mode.""" mock_credentials = mock.create_autospec( Credentials, instance=True, spec_set=True @@ -202,7 +203,7 @@ def test_execute_sql(mock_utils_execute_sql): ) mock_utils_execute_sql.return_value = {"status": "SUCCESS", "rows": [[1]]} - result = query_tool.execute_sql( + result = await query_tool.execute_sql( project_id="test-project", instance_id="test-instance", database_id="test-database", From 6d53d800d5f6dc5d4a3a75300e34d5a9b0f006f5 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 23 Feb 2026 22:55:22 -0800 Subject: [PATCH 028/102] fix: fix typo in PlanReActPlanner instruction PiperOrigin-RevId: 874391350 --- src/google/adk/planners/plan_re_act_planner.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/planners/plan_re_act_planner.py b/src/google/adk/planners/plan_re_act_planner.py index f7930b14..dab3a1fe 100644 --- a/src/google/adk/planners/plan_re_act_planner.py +++ b/src/google/adk/planners/plan_re_act_planner.py @@ -168,7 +168,7 @@ Follow this format when answering the question: (1) The planning part should be planning_preamble = f""" Below are the requirements for the planning: The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take. -If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under {REPLANNING_TAG}. Then use tools to follow the new plan. +If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be under {REPLANNING_TAG}. Then use tools to follow the new plan. """ reasoning_preamble = """ From dab80e4a8f3c5476f731335724bff5df3e6f3650 Mon Sep 17 00:00:00 2001 From: Lusha Wang Date: Mon, 23 Feb 2026 23:51:45 -0800 Subject: [PATCH 029/102] fix: Update agent_engine_sandbox_code_executor in ADK 1. For prototyping and testing purposes, sandbox name can be provided, and it will be used for all requests across the lifecycle of an agent 2. If no sandbox name is provided, agent engine name will be provided, and we will automatically create one sandbox per session, and the sandbox has TTL set for a year. If the sandbox stored in the session hits the TTL, it will not be in "STATE_RUNNING" so a new sandbox will be created. 
Co-authored-by: Lusha Wang PiperOrigin-RevId: 874415933 --- .../agent_engine_code_execution/README | 4 +- .../agent_engine_code_execution/agent.py | 7 +- .../agent_engine_sandbox_code_executor.py | 54 +++++-- ...test_agent_engine_sandbox_code_executor.py | 133 ++++++++++++++++++ 4 files changed, 179 insertions(+), 19 deletions(-) diff --git a/contributing/samples/agent_engine_code_execution/README b/contributing/samples/agent_engine_code_execution/README index 8d5a4442..b0443ae2 100644 --- a/contributing/samples/agent_engine_code_execution/README +++ b/contributing/samples/agent_engine_code_execution/README @@ -7,9 +7,9 @@ This sample data science agent uses Agent Engine Code Execution Sandbox to execu ## How to use -* 1. Follow https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/code-execution/overview to create a code execution sandbox environment. +* 1. Follow https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create-an-agent-engine-instance to create an agent engine instance. Replace the AGENT_ENGINE_RESOURCE_NAME with the one you just created. A new sandbox environment under this agent engine instance will be created for each session with TTL of 1 year. But sandbox can only main its state for up to 14 days. This is the recommended usage for production environments. -* 2. Replace the SANDBOX_RESOURCE_NAME with the one you just created. If you dont want to create a new sandbox environment directly, the Agent Engine Code Execution Sandbox will create one for you by default using the AGENT_ENGINE_RESOURCE_NAME you specified, however, please ensure to clean up sandboxes after use; otherwise, it will consume quotas. +* 2. For testing or protyping purposes, create a sandbox environment by following this guide: https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create_a_sandbox. Replace the SANDBOX_RESOURCE_NAME with the one you just created. 
This will be used as the default sandbox environment for all the code executions throughout the lifetime of the agent. As the sandbox is re-used across sessions, all sessions will share the same Python environment and variable values." ## Sample prompt diff --git a/contributing/samples/agent_engine_code_execution/agent.py b/contributing/samples/agent_engine_code_execution/agent.py index d85989eb..a32e4ca4 100644 --- a/contributing/samples/agent_engine_code_execution/agent.py +++ b/contributing/samples/agent_engine_code_execution/agent.py @@ -85,11 +85,10 @@ When plotting trends, you should make sure to sort and order the data by the x-a """, code_executor=AgentEngineSandboxCodeExecutor( - # Replace with your sandbox resource name if you already have one. - sandbox_resource_name="SANDBOX_RESOURCE_NAME", + # Replace with your sandbox resource name if you already have one. Only use it for testing or prototyping purposes, because this will use the same sandbox for all requests. # "projects/vertex-agent-loadtest/locations/us-central1/reasoningEngines/6842889780301135872/sandboxEnvironments/6545148628569161728", - # Replace with agent engine resource name used for creating sandbox if - # sandbox_resource_name is not set. + sandbox_resource_name=None, + # Replace with agent engine resource name used for creating sandbox environment. 
agent_engine_resource_name="AGENT_ENGINE_RESOURCE_NAME", ), ) diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 69d1778a..9348dbc4 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -21,6 +21,7 @@ import re from typing import Optional from typing_extensions import override +from vertexai import types from ..agents.invocation_context import InvocationContext from .base_code_executor import BaseCodeExecutor @@ -38,10 +39,15 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): sandbox_resource_name: If set, load the existing resource name of the code interpreter extension instead of creating a new one. Format: projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789 + agent_engine_resource_name: The resource name of the agent engine to use + to create the code execution sandbox. 
Format: + projects/123/locations/us-central1/reasoningEngines/456 """ sandbox_resource_name: str = None + agent_engine_resource_name: str = None + def __init__( self, sandbox_resource_name: Optional[str] = None, @@ -67,30 +73,19 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): agent_engine_resource_name_pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' if sandbox_resource_name is not None: - self.sandbox_resource_name = sandbox_resource_name self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( sandbox_resource_name, sandbox_resource_name_pattern ) ) + self.sandbox_resource_name = sandbox_resource_name elif agent_engine_resource_name is not None: - from vertexai import types - self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( agent_engine_resource_name, agent_engine_resource_name_pattern ) ) - # @TODO - Add TTL for sandbox creation after it is available - # in SDK. - operation = self._get_api_client().agent_engines.sandboxes.create( - spec={'code_execution_environment': {}}, - name=agent_engine_resource_name, - config=types.CreateAgentEngineSandboxConfig( - display_name='default_sandbox' - ), - ) - self.sandbox_resource_name = operation.response.name + self.agent_engine_resource_name = agent_engine_resource_name else: raise ValueError( 'Either sandbox_resource_name or agent_engine_resource_name must be' @@ -103,6 +98,39 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: + if self.sandbox_resource_name is None: + sandbox_name = invocation_context.session.state.get('sandbox_name', None) + create_new_sandbox = False + if sandbox_name is None: + create_new_sandbox = True + else: + # Check if the sandbox is still running OR already expired due to ttl. 
+ sandbox = self._get_api_client().agent_engines.sandboxes.get( + name=sandbox_name + ) + if not sandbox or sandbox.state != 'STATE_RUNNING': + create_new_sandbox = True + + if create_new_sandbox: + operation = self._get_api_client().agent_engines.sandboxes.create( + spec={'code_execution_environment': {}}, + name=self.agent_engine_resource_name, + config=types.CreateAgentEngineSandboxConfig( + # VertexAiSessionService has a default TTL of 1 year, so we set + # the sandbox TTL to 1 year as well. For the current code + # execution sandbox, if it hasn't been used for 14 days, the + # state will be lost. + display_name='default_sandbox', + ttl='31536000s', + ), + ) + self.sandbox_resource_name = operation.response.name + invocation_context.session.state['sandbox_name'] = ( + self.sandbox_resource_name + ) + else: + self.sandbox_resource_name = sandbox_name + # Execute the code. input_data = { 'code': code_execution_input.code, diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index 6022527f..604685fe 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -19,6 +19,7 @@ from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.code_executors.agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionInput +from google.adk.sessions.session import Session import pytest @@ -27,6 +28,10 @@ def mock_invocation_context() -> InvocationContext: """Fixture for a mock InvocationContext.""" mock = MagicMock(spec=InvocationContext) mock.invocation_id = "test-invocation-123" + session = MagicMock(spec=Session) + mock.session = session + session.state = [] + return mock @@ -118,3 +123,131 @@ class 
TestAgentEngineSandboxCodeExecutor: name="projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789", input_data={"code": 'print("hello world")'}, ) + + @patch("vertexai.Client") + def test_execute_code_recreates_sandbox_when_get_returns_none( + self, + mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Existing sandbox name stored in session, but get() will return None + existing_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/old" + mock_invocation_context.session.state = { + "sandbox_name": existing_sandbox_name + } + + # Mock get to return None (simulating missing/expired sandbox) + mock_api_client.agent_engines.sandboxes.get.return_value = None + + # Mock create operation to return a new sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "recreated sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Execute using agent_engine_resource_name so a sandbox can be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + ) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert get was called for the existing 
sandbox + mock_api_client.agent_engines.sandboxes.get.assert_called_once_with( + name=existing_sandbox_name + ) + + # Assert create was called and session updated with new sandbox + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + assert executor.sandbox_resource_name == created_sandbox_name + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) + + @patch("vertexai.Client") + def test_execute_code_creates_sandbox_if_missing( + self, + mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Mock create operation to return a sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "created sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Ensure session.state behaves like a dict for storing sandbox_name + mock_invocation_context.session.state = {} + + # Execute using agent_engine_resource_name so a sandbox will be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ), + sandbox_resource_name=None, + 
) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert sandbox creation was called and session state updated + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + create_call_kwargs = ( + mock_api_client.agent_engines.sandboxes.create.call_args.kwargs + ) + assert create_call_kwargs["name"] == ( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + assert executor.sandbox_resource_name == created_sandbox_name + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) From 8c0bd2034ca5325ab704a71521b5b20c8ba2ca78 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Tue, 24 Feb 2026 04:42:06 -0800 Subject: [PATCH 030/102] chore: SessionNotFoundError only inherits form ValueError Co-authored-by: Sasha Sobran PiperOrigin-RevId: 874545504 --- src/google/adk/errors/session_not_found_error.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/google/adk/errors/session_not_found_error.py b/src/google/adk/errors/session_not_found_error.py index a870d0d2..4fc3258e 100644 --- a/src/google/adk/errors/session_not_found_error.py +++ b/src/google/adk/errors/session_not_found_error.py @@ -14,14 +14,11 @@ from __future__ import annotations -from .not_found_error import NotFoundError - -class SessionNotFoundError(ValueError, NotFoundError): +class SessionNotFoundError(ValueError): """Raised when a session cannot be found. - Inherits from both ValueError (for backward compatibility) and NotFoundError - (for semantic consistency with the project's error hierarchy). + Inherits from ValueError (for backward compatibility). 
""" def __init__(self, message="Session not found."): From c615757ba12093ba4a2ba19bee3f498fef91584c Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 24 Feb 2026 08:34:05 -0800 Subject: [PATCH 031/102] fix: Add support for injecting a custom google.genai.Client into Gemini models This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings Close #2560 Co-authored-by: George Weale PiperOrigin-RevId: 874628604 --- src/google/adk/models/google_llm.py | 56 ++++++++ tests/unittests/models/test_google_llm.py | 150 ++++++++++++++++++++++ 2 files changed, 206 insertions(+) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 23c9c278..b8c5117e 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -85,6 +85,23 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. + client: An optional preconfigured ``google.genai.Client`` instance. + When provided, ADK uses this client for all API calls instead of + creating one internally from environment variables or ADC. This + allows fine-grained control over authentication, project, location, + and other client-level settings — and enables running agents that + target different Vertex AI regions within the same process. + + Example:: + + from google import genai + from google.adk.models import Gemini + + client = genai.Client( + vertexai=True, project="my-project", location="us-central1" + ) + model = Gemini(model="gemini-2.5-flash", client=client) + use_interactions_api: Whether to use the interactions API for model invocation. """ @@ -131,6 +148,35 @@ class Gemini(BaseLlm): ``` """ + def __init__(self, *, client: Optional[Client] = None, **kwargs: Any): + """Initialises a Gemini model wrapper. 
+ + Args: + client: An optional preconfigured ``google.genai.Client``. When + provided, ADK uses this client for **all** Gemini API calls + (including the Live API) instead of creating one internally. + + .. note:: + When a custom client is supplied it is used as-is for Live API + connections. ADK will **not** override the client's + ``api_version``; you are responsible for setting the correct + version (``v1beta1`` for Vertex AI, ``v1alpha`` for the + Gemini developer API) on the client yourself. + + .. warning:: + ``google.genai.Client`` contains threading primitives that + cannot be pickled. If you are deploying to Agent Engine (or + any environment that serialises the model), do **not** pass a + custom client — let ADK create one from the environment + instead. + + **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor + (``model``, ``base_url``, ``retry_options``, etc.). + """ + super().__init__(**kwargs) + # Store after super().__init__ so Pydantic validation runs first. + object.__setattr__(self, '_client', client) + @classmethod @override def supported_models(cls) -> list[str]: @@ -299,9 +345,16 @@ class Gemini(BaseLlm): def api_client(self) -> Client: """Provides the api client. + If a preconfigured ``client`` was passed to the constructor it is + returned directly; otherwise a new ``Client`` is created using the + default environment/ADC configuration. + Returns: The api client. 
""" + if self._client is not None: + return self._client + from google.genai import Client return Client( @@ -334,6 +387,9 @@ class Gemini(BaseLlm): @cached_property def _live_api_client(self) -> Client: + if self._client is not None: + return self._client + from google.genai import Client return Client( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 70aa01b6..75d4c0fd 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -2140,3 +2140,153 @@ async def test_connect_speech_config_remains_none_when_both_are_none( # Verify the final speech_config is still None assert config_arg.speech_config is None assert isinstance(connection, GeminiLlmConnection) + + +# --------------------------------------------------------------------------- +# Tests for custom client injection (Issue #2560) +# --------------------------------------------------------------------------- + + +def test_custom_client_is_used_for_api_client(): + """When a custom client is provided, api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini.api_client is custom_client + + +def test_custom_client_is_used_for_live_api_client(): + """When a custom client is provided, _live_api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._live_api_client is custom_client + + +def test_default_api_client_when_no_custom_client(): + """Without a custom client, api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + # api_client should construct a real Client (not None) + client = gemini.api_client + assert client is not None + # Verify it is not a mock — it's a real google.genai.Client + from 
google.genai import Client + + assert isinstance(client, Client) + + +def test_default_live_api_client_when_no_custom_client(): + """Without a custom client, _live_api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + client = gemini._live_api_client + assert client is not None + from google.genai import Client + + assert isinstance(client, Client) + + +def test_custom_client_api_backend_vertexai(): + """_api_backend reflects the custom client's vertexai setting.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = True + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI + + +def test_custom_client_api_backend_gemini_api(): + """_api_backend reflects non-vertexai custom client.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.GEMINI_API + + +@pytest.mark.asyncio +async def test_custom_client_used_for_generate_content(): + """Custom client is used when generate_content_async is called.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", + parts=[Part.from_text(text="Hello from custom client")], + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + async def mock_coro(): + return generate_content_response + + custom_client.aio.models.generate_content.return_value = mock_coro() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + 
system_instruction="You are a helpful assistant", + ), + ) + + responses = [ + resp + async for resp in gemini.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + assert responses[0].content.parts[0].text == "Hello from custom client" + custom_client.aio.models.generate_content.assert_called_once() + + +@pytest.mark.asyncio +async def test_custom_client_used_for_live_connect(): + """Custom client is used for live API streaming connections.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + mock_live_session = mock.AsyncMock() + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + custom_client.aio.live.connect.return_value = MockLiveConnect() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + ), + ) + llm_request.live_connect_config = types.LiveConnectConfig() + + async with gemini.connect(llm_request) as connection: + custom_client.aio.live.connect.assert_called_once() + assert isinstance(connection, GeminiLlmConnection) From 7be90db24b41f1830e39ca3d7e15bf4dbfa5a304 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 24 Feb 2026 08:38:34 -0800 Subject: [PATCH 032/102] feat: Support ID token exchange in ServiceAccountCredentialExchanger Adds use_id_token and audience fields to ServiceAccount so that ServiceAccountCredentialExchanger can produce ID tokens instead of access tokens. This is required for authenticating to Cloud Run, Cloud Functions, and other Google Cloud services that verify caller identity. 
Close #4458 Co-authored-by: George Weale PiperOrigin-RevId: 874630210 --- src/google/adk/auth/auth_credential.py | 39 +- .../service_account_exchanger.py | 140 ++++++-- .../unittests/tools/mcp_tool/test_mcp_tool.py | 4 +- .../test_service_account_exchanger.py | 333 +++++++++++++----- 4 files changed, 406 insertions(+), 110 deletions(-) diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index e205d9be..6160edcc 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -25,6 +25,7 @@ from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator class BaseModelWithConfig(BaseModel): @@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig): class ServiceAccount(BaseModelWithConfig): - """Represents Google Service Account configuration.""" + """Represents Google Service Account configuration. + + Attributes: + service_account_credential: The service account credential (JSON key). + scopes: The OAuth2 scopes to request. Optional; when omitted with + ``use_default_credential=True``, defaults to the cloud-platform scope. + use_default_credential: Whether to use Application Default Credentials. + use_id_token: Whether to exchange for an ID token instead of an access + token. Required for service-to-service authentication with Cloud Run, + Cloud Functions, and other Google Cloud services that require identity + verification. When True, ``audience`` must also be set. + audience: The target audience for the ID token, typically the URL of the + receiving service (e.g. ``https://my-service-xyz.run.app``). Required + when ``use_id_token`` is True. 
+ """ service_account_credential: Optional[ServiceAccountCredential] = None - scopes: List[str] + scopes: Optional[List[str]] = None use_default_credential: Optional[bool] = False + use_id_token: Optional[bool] = False + audience: Optional[str] = None + + @model_validator(mode="after") + def _validate_config(self) -> ServiceAccount: + if ( + not self.use_default_credential + and self.service_account_credential is None + ): + raise ValueError( + "service_account_credential is required when" + " use_default_credential is False." + ) + if self.use_id_token and not self.audience: + raise ValueError( + "audience is required when use_id_token is True. Set it to the" + " URL of the target service" + " (e.g. 'https://my-service.run.app')." + ) + return self class AuthCredentialTypes(str, Enum): diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 1dbe0fe4..2b79edf9 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -19,6 +19,7 @@ from __future__ import annotations from typing import Optional import google.auth +from google.auth import exceptions as google_auth_exceptions from google.auth.transport.requests import Request from google.oauth2 import service_account import google.oauth2.credentials @@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredential from .....auth.auth_credential import AuthCredentialTypes from .....auth.auth_credential import HttpAuth from .....auth.auth_credential import HttpCredentials +from .....auth.auth_credential import ServiceAccount from .....auth.auth_schemes import AuthScheme from .base_credential_exchanger import AuthCredentialMissingError from .base_credential_exchanger import BaseAuthCredentialExchanger @@ -38,6 +40,11 @@ class 
ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): Uses the default service credential if `use_default_credential = True`. Otherwise, uses the service account credential provided in the auth credential. + + Supports exchanging for either an access token (default) or an ID token + when ``ServiceAccount.use_id_token`` is True. ID tokens are required for + service-to-service authentication with Cloud Run, Cloud Functions, and + other services that verify caller identity. """ def exchange_credential( @@ -45,52 +52,130 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): auth_scheme: AuthScheme, auth_credential: Optional[AuthCredential] = None, ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. + """Exchanges the service account auth credential for a token. If auth_credential contains a service account credential, it will be used - to fetch an access token. Otherwise, the default service credential will be - used for fetching an access token. + to fetch a token. Otherwise, the default service credential will be + used for fetching a token. + + When ``service_account.use_id_token`` is True, an ID token is fetched + using the configured ``audience``. This is required for authenticating + to Cloud Run, Cloud Functions, and similar services. Args: auth_scheme: The auth scheme. auth_credential: The auth credential. Returns: - An AuthCredential in HTTPBearer format, containing the access token. + An AuthCredential in HTTPBearer format, containing the token. """ - if ( - auth_credential is None - or auth_credential.service_account is None - or ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ) - ): + if auth_credential is None or auth_credential.service_account is None: raise AuthCredentialMissingError( - "Service account credentials are missing. 
Please provide them, or set" - " `use_default_credential = True` to use application default" + "Service account credentials are missing. Please provide them, or" + " set `use_default_credential = True` to use application default" " credential in a hosted service like Cloud Run." ) + sa_config = auth_credential.service_account + + if sa_config.use_id_token: + return self._exchange_for_id_token(sa_config) + + return self._exchange_for_access_token(sa_config) + + def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential: + """Exchanges the service account credential for an ID token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the ID token. + + Raises: + AuthCredentialMissingError: If token exchange fails. + """ + # audience and credential presence are validated by the ServiceAccount + # model_validator at construction time. try: - if auth_credential.service_account.use_default_credential: - credentials, project_id = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - quota_project_id = ( - getattr(credentials, "quota_project_id", None) or project_id - ) + if sa_config.use_default_credential: + from google.oauth2 import id_token as oauth2_id_token + + request = Request() + token = oauth2_id_token.fetch_id_token(request, sa_config.audience) else: - config = auth_credential.service_account + # Guaranteed non-None by ServiceAccount model_validator. 
+ assert sa_config.service_account_credential is not None + credentials = ( + service_account.IDTokenCredentials.from_service_account_info( + sa_config.service_account_credential.model_dump(), + target_audience=sa_config.audience, + ) + ) + credentials.refresh(Request()) + token = credentials.token + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=token), + ), + ) + + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key), or when + # fetch_id_token cannot determine credentials from the environment. + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: + raise AuthCredentialMissingError( + f"Failed to exchange service account for ID token: {e}" + ) from e + + def _exchange_for_access_token( + self, sa_config: ServiceAccount + ) -> AuthCredential: + """Exchanges the service account credential for an access token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the access token. + + Raises: + AuthCredentialMissingError: If scopes are missing for explicit + credentials or token exchange fails. + """ + if not sa_config.use_default_credential and not sa_config.scopes: + raise AuthCredentialMissingError( + "scopes are required when using explicit service account credentials" + " for access token exchange." + ) + + try: + if sa_config.use_default_credential: + scopes = ( + sa_config.scopes + if sa_config.scopes + else ["https://www.googleapis.com/auth/cloud-platform"] + ) + credentials, project_id = google.auth.default( + scopes=scopes, + ) + quota_project_id = credentials.quota_project_id or project_id + else: + # Guaranteed non-None by ServiceAccount model_validator. 
+ assert sa_config.service_account_credential is not None credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes + sa_config.service_account_credential.model_dump(), + scopes=sa_config.scopes, ) quota_project_id = None credentials.refresh(Request()) - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), @@ -101,9 +186,10 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): else None, ), ) - return updated_credential - except Exception as e: + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key). + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: raise AuthCredentialMissingError( f"Failed to exchange service account token: {e}" ) from e diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index c4c85e77..f38a8bbc 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -534,7 +534,9 @@ class TestMCPTool: ) # Create service account credential - service_account = ServiceAccount(scopes=["test"]) + service_account = ServiceAccount( + scopes=["test"], use_default_credential=True + ) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=service_account, diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index 0ca99444..fb35daf6 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ 
b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -25,8 +25,23 @@ from google.adk.auth.auth_schemes import AuthSchemeType from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import google.auth +from google.auth import exceptions as google_auth_exceptions import pytest +_ACCESS_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.Credentials." + "from_service_account_info" +) + +_ID_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.IDTokenCredentials." + "from_service_account_info" +) + +_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token" + @pytest.fixture def service_account_exchanger(): @@ -41,50 +56,45 @@ def auth_scheme(): return scheme -def test_exchange_credential_success( - service_account_exchanger, auth_scheme, monkeypatch +@pytest.fixture +def sa_credential(): + """A minimal valid ServiceAccountCredential for testing.""" + return ServiceAccountCredential( + type_="service_account", + project_id="test_project_id", + private_key_id="test_private_key_id", + private_key="-----BEGIN PRIVATE KEY-----...", + client_email="test@test.iam.gserviceaccount.com", + client_id="test_client_id", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs", + client_x509_cert_url=( + "https://www.googleapis.com/robot/v1/metadata/x509/test" + ), + universe_domain="googleapis.com", + ) + + +_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + + +# --- Access token exchange tests --- + + +def 
test_exchange_access_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test successful exchange of service account credentials.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_from_sa_info = MagicMock(return_value=mock_credentials) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) - # Mock the from_service_account_info method - mock_from_service_account_info = MagicMock(return_value=mock_credentials) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, - ) - - # Mock the refresh method - mock_credentials.refresh = MagicMock() - - # Create a valid AuthCredential with service account info auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." 
- ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) @@ -95,7 +105,7 @@ def test_exchange_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() mock_credentials.refresh.assert_called_once() @@ -107,7 +117,7 @@ def test_exchange_credential_success( (None, None, None), ], ) -def test_exchange_credential_use_default_credential_success( +def test_exchange_access_token_with_adc_sets_quota_project( service_account_exchanger, auth_scheme, monkeypatch, @@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success( adc_project_id, expected_quota_project_id, ): - """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" mock_credentials.quota_project_id = cred_quota_project_id @@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], + scopes=["https://www.googleapis.com/auth/bigquery"], ), ) @@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success( ) else: assert not result.http.additional_headers - # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/bigquery"] ) mock_credentials.refresh.assert_called_once() -def test_exchange_credential_missing_auth_credential( +def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope( 
+ service_account_exchanger, auth_scheme, monkeypatch +): + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = None + mock_google_auth_default = MagicMock(return_value=(mock_credentials, None)) + monkeypatch.setattr(google.auth, "default", mock_google_auth_default) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES) + + +def test_exchange_raises_when_auth_credential_is_none( service_account_exchanger, auth_scheme ): - """Test missing auth credential during exchange.""" with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, None) assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_missing_service_account_info( +def test_exchange_raises_when_service_account_is_none( service_account_exchanger, auth_scheme ): - """Test missing service account info during exchange.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, ) @@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info( assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_exchange_failure( - service_account_exchanger, auth_scheme, monkeypatch +def test_exchange_wraps_google_auth_error_as_missing_error( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test failure during service account token exchange.""" - mock_from_service_account_info = MagicMock( - side_effect=Exception("Failed 
to load credentials") - ) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to load credentials") ) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." 
- ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) + with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, auth_credential) assert "Failed to exchange service account token" in str(exc_info.value) - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() + + +def test_exchange_raises_when_explicit_credentials_have_no_scopes( + service_account_exchanger, auth_scheme, sa_credential +): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "scopes are required" in str(exc_info.value) + + +# --- ID token exchange tests --- + + +def test_exchange_id_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_id_credentials = MagicMock() + mock_id_credentials.token = "mock_id_token" + mock_from_sa_info = MagicMock(return_value=mock_id_credentials) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert result.http.additional_headers is None + mock_from_sa_info.assert_called_once() + assert ( + 
mock_from_sa_info.call_args[1]["target_audience"] + == "https://my-service.run.app" + ) + mock_id_credentials.refresh.assert_called_once() + + +def test_exchange_id_token_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token") + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_adc_id_token" + assert result.http.additional_headers is None + mock_fetch_id_token.assert_called_once() + assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app" + + +def test_id_token_requires_audience(): + with pytest.raises( + ValueError, match="audience is required when use_id_token is True" + ): + ServiceAccount( + use_default_credential=True, + use_id_token=True, + ) + + +def test_exchange_id_token_wraps_error_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to create ID token credentials") + ) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + 
assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +def test_exchange_id_token_wraps_error_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock( + side_effect=google_auth_exceptions.DefaultCredentialsError( + "Metadata service unavailable" + ) + ) + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +# --- Model validator tests --- + + +def test_model_validator_rejects_missing_credential_without_adc(): + with pytest.raises( + ValueError, + match="service_account_credential is required", + ): + ServiceAccount( + use_default_credential=False, + scopes=_DEFAULT_SCOPES, + ) + + +def test_model_validator_allows_adc_without_explicit_credential(): + sa = ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + ) + assert sa.service_account_credential is None + assert sa.use_default_credential is True From ee8d956413473d1bbbb025a470ad882c1487d8b8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 11:15:35 -0800 Subject: [PATCH 033/102] fix: Update agent_engine_sandbox_code_executor in ADK 1. For prototyping and testing purposes, sandbox name can be provided, and it will be used for all requests across the lifecycle of an agent 2. If no sandbox name is provided, agent engine name will be provided, and we will automatically create one sandbox per session, and the sandbox has TTL set for a year. 
If the sandbox stored in the session hits the TTL, it will not be in "STATE_RUNNING" so a new sandbox will be created. PiperOrigin-RevId: 874705260 --- .../agent_engine_code_execution/README | 4 +- .../agent_engine_code_execution/agent.py | 7 +- .../agent_engine_sandbox_code_executor.py | 54 ++----- ...test_agent_engine_sandbox_code_executor.py | 133 ------------------ 4 files changed, 19 insertions(+), 179 deletions(-) diff --git a/contributing/samples/agent_engine_code_execution/README b/contributing/samples/agent_engine_code_execution/README index b0443ae2..8d5a4442 100644 --- a/contributing/samples/agent_engine_code_execution/README +++ b/contributing/samples/agent_engine_code_execution/README @@ -7,9 +7,9 @@ This sample data science agent uses Agent Engine Code Execution Sandbox to execu ## How to use -* 1. Follow https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create-an-agent-engine-instance to create an agent engine instance. Replace the AGENT_ENGINE_RESOURCE_NAME with the one you just created. A new sandbox environment under this agent engine instance will be created for each session with TTL of 1 year. But sandbox can only main its state for up to 14 days. This is the recommended usage for production environments. +* 1. Follow https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/code-execution/overview to create a code execution sandbox environment. -* 2. For testing or protyping purposes, create a sandbox environment by following this guide: https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create_a_sandbox. Replace the SANDBOX_RESOURCE_NAME with the one you just created. This will be used as the default sandbox environment for all the code executions throughout the lifetime of the agent. As the sandbox is re-used across sessions, all sessions will share the same Python environment and variable values." +* 2. Replace the SANDBOX_RESOURCE_NAME with the one you just created. 
If you don't want to create a new sandbox environment directly, the Agent Engine Code Execution Sandbox will create one for you by default using the AGENT_ENGINE_RESOURCE_NAME you specified; however, please ensure to clean up sandboxes after use; otherwise, they will consume quota. ## Sample prompt diff --git a/contributing/samples/agent_engine_code_execution/agent.py b/contributing/samples/agent_engine_code_execution/agent.py index a32e4ca4..d85989eb 100644 --- a/contributing/samples/agent_engine_code_execution/agent.py +++ b/contributing/samples/agent_engine_code_execution/agent.py @@ -85,10 +85,11 @@ When plotting trends, you should make sure to sort and order the data by the x-a """, code_executor=AgentEngineSandboxCodeExecutor( - # Replace with your sandbox resource name if you already have one. Only use it for testing or prototyping purposes, because this will use the same sandbox for all requests. + # Replace with your sandbox resource name if you already have one. + sandbox_resource_name="SANDBOX_RESOURCE_NAME", # "projects/vertex-agent-loadtest/locations/us-central1/reasoningEngines/6842889780301135872/sandboxEnvironments/6545148628569161728", - sandbox_resource_name=None, - # Replace with agent engine resource name used for creating sandbox environment. + # Replace with agent engine resource name used for creating sandbox if + # sandbox_resource_name is not set.
agent_engine_resource_name="AGENT_ENGINE_RESOURCE_NAME", ), ) diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 9348dbc4..69d1778a 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -21,7 +21,6 @@ import re from typing import Optional from typing_extensions import override -from vertexai import types from ..agents.invocation_context import InvocationContext from .base_code_executor import BaseCodeExecutor @@ -39,15 +38,10 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): sandbox_resource_name: If set, load the existing resource name of the code interpreter extension instead of creating a new one. Format: projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789 - agent_engine_resource_name: The resource name of the agent engine to use - to create the code execution sandbox. 
Format: - projects/123/locations/us-central1/reasoningEngines/456 """ sandbox_resource_name: str = None - agent_engine_resource_name: str = None - def __init__( self, sandbox_resource_name: Optional[str] = None, @@ -73,19 +67,30 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): agent_engine_resource_name_pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' if sandbox_resource_name is not None: + self.sandbox_resource_name = sandbox_resource_name self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( sandbox_resource_name, sandbox_resource_name_pattern ) ) - self.sandbox_resource_name = sandbox_resource_name elif agent_engine_resource_name is not None: + from vertexai import types + self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( agent_engine_resource_name, agent_engine_resource_name_pattern ) ) - self.agent_engine_resource_name = agent_engine_resource_name + # @TODO - Add TTL for sandbox creation after it is available + # in SDK. + operation = self._get_api_client().agent_engines.sandboxes.create( + spec={'code_execution_environment': {}}, + name=agent_engine_resource_name, + config=types.CreateAgentEngineSandboxConfig( + display_name='default_sandbox' + ), + ) + self.sandbox_resource_name = operation.response.name else: raise ValueError( 'Either sandbox_resource_name or agent_engine_resource_name must be' @@ -98,39 +103,6 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: - if self.sandbox_resource_name is None: - sandbox_name = invocation_context.session.state.get('sandbox_name', None) - create_new_sandbox = False - if sandbox_name is None: - create_new_sandbox = True - else: - # Check if the sandbox is still running OR already expired due to ttl. 
- sandbox = self._get_api_client().agent_engines.sandboxes.get( - name=sandbox_name - ) - if not sandbox or sandbox.state != 'STATE_RUNNING': - create_new_sandbox = True - - if create_new_sandbox: - operation = self._get_api_client().agent_engines.sandboxes.create( - spec={'code_execution_environment': {}}, - name=self.agent_engine_resource_name, - config=types.CreateAgentEngineSandboxConfig( - # VertexAiSessionService has a default TTL of 1 year, so we set - # the sandbox TTL to 1 year as well. For the current code - # execution sandbox, if it hasn't been used for 14 days, the - # state will be lost. - display_name='default_sandbox', - ttl='31536000s', - ), - ) - self.sandbox_resource_name = operation.response.name - invocation_context.session.state['sandbox_name'] = ( - self.sandbox_resource_name - ) - else: - self.sandbox_resource_name = sandbox_name - # Execute the code. input_data = { 'code': code_execution_input.code, diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index 604685fe..6022527f 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -19,7 +19,6 @@ from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.code_executors.agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionInput -from google.adk.sessions.session import Session import pytest @@ -28,10 +27,6 @@ def mock_invocation_context() -> InvocationContext: """Fixture for a mock InvocationContext.""" mock = MagicMock(spec=InvocationContext) mock.invocation_id = "test-invocation-123" - session = MagicMock(spec=Session) - mock.session = session - session.state = [] - return mock @@ -123,131 +118,3 @@ class 
TestAgentEngineSandboxCodeExecutor: name="projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789", input_data={"code": 'print("hello world")'}, ) - - @patch("vertexai.Client") - def test_execute_code_recreates_sandbox_when_get_returns_none( - self, - mock_vertexai_client, - mock_invocation_context, - ): - # Setup Mocks - mock_api_client = MagicMock() - mock_vertexai_client.return_value = mock_api_client - - # Existing sandbox name stored in session, but get() will return None - existing_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/old" - mock_invocation_context.session.state = { - "sandbox_name": existing_sandbox_name - } - - # Mock get to return None (simulating missing/expired sandbox) - mock_api_client.agent_engines.sandboxes.get.return_value = None - - # Mock create operation to return a new sandbox resource name - operation_mock = MagicMock() - created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" - operation_mock.response.name = created_sandbox_name - mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock - - # Mock execute_code response - mock_response = MagicMock() - mock_json_output = MagicMock() - mock_json_output.mime_type = "application/json" - mock_json_output.data = json.dumps( - {"stdout": "recreated sandbox run", "stderr": ""} - ).encode("utf-8") - mock_json_output.metadata = None - mock_response.outputs = [mock_json_output] - mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( - mock_response - ) - - # Execute using agent_engine_resource_name so a sandbox can be created - executor = AgentEngineSandboxCodeExecutor( - agent_engine_resource_name=( - "projects/123/locations/us-central1/reasoningEngines/456" - ) - ) - code_input = CodeExecutionInput(code='print("hello world")') - result = executor.execute_code(mock_invocation_context, code_input) - - # Assert get was called for the existing 
sandbox - mock_api_client.agent_engines.sandboxes.get.assert_called_once_with( - name=existing_sandbox_name - ) - - # Assert create was called and session updated with new sandbox - mock_api_client.agent_engines.sandboxes.create.assert_called_once() - assert executor.sandbox_resource_name == created_sandbox_name - assert ( - mock_invocation_context.session.state["sandbox_name"] - == created_sandbox_name - ) - - # Assert execute_code used the created sandbox name - mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( - name=created_sandbox_name, - input_data={"code": 'print("hello world")'}, - ) - - @patch("vertexai.Client") - def test_execute_code_creates_sandbox_if_missing( - self, - mock_vertexai_client, - mock_invocation_context, - ): - # Setup Mocks - mock_api_client = MagicMock() - mock_vertexai_client.return_value = mock_api_client - - # Mock create operation to return a sandbox resource name - operation_mock = MagicMock() - created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" - operation_mock.response.name = created_sandbox_name - mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock - - # Mock execute_code response - mock_response = MagicMock() - mock_json_output = MagicMock() - mock_json_output.mime_type = "application/json" - mock_json_output.data = json.dumps( - {"stdout": "created sandbox run", "stderr": ""} - ).encode("utf-8") - mock_json_output.metadata = None - mock_response.outputs = [mock_json_output] - mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( - mock_response - ) - - # Ensure session.state behaves like a dict for storing sandbox_name - mock_invocation_context.session.state = {} - - # Execute using agent_engine_resource_name so a sandbox will be created - executor = AgentEngineSandboxCodeExecutor( - agent_engine_resource_name=( - "projects/123/locations/us-central1/reasoningEngines/456" - ), - sandbox_resource_name=None, - 
) - code_input = CodeExecutionInput(code='print("hello world")') - result = executor.execute_code(mock_invocation_context, code_input) - - # Assert sandbox creation was called and session state updated - mock_api_client.agent_engines.sandboxes.create.assert_called_once() - create_call_kwargs = ( - mock_api_client.agent_engines.sandboxes.create.call_args.kwargs - ) - assert create_call_kwargs["name"] == ( - "projects/123/locations/us-central1/reasoningEngines/456" - ) - assert executor.sandbox_resource_name == created_sandbox_name - assert ( - mock_invocation_context.session.state["sandbox_name"] - == created_sandbox_name - ) - - # Assert execute_code used the created sandbox name - mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( - name=created_sandbox_name, - input_data={"code": 'print("hello world")'}, - ) From 48105b49c5ab8e4719a66e7219f731b2cd293b00 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 12:58:01 -0800 Subject: [PATCH 034/102] fix: Add support for injecting a custom google.genai.Client into Gemini models This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings Close #2560 PiperOrigin-RevId: 874752355 --- src/google/adk/models/google_llm.py | 56 -------- tests/unittests/models/test_google_llm.py | 150 ---------------------- 2 files changed, 206 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index b8c5117e..23c9c278 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -85,23 +85,6 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. - client: An optional preconfigured ``google.genai.Client`` instance. 
- When provided, ADK uses this client for all API calls instead of - creating one internally from environment variables or ADC. This - allows fine-grained control over authentication, project, location, - and other client-level settings — and enables running agents that - target different Vertex AI regions within the same process. - - Example:: - - from google import genai - from google.adk.models import Gemini - - client = genai.Client( - vertexai=True, project="my-project", location="us-central1" - ) - model = Gemini(model="gemini-2.5-flash", client=client) - use_interactions_api: Whether to use the interactions API for model invocation. """ @@ -148,35 +131,6 @@ class Gemini(BaseLlm): ``` """ - def __init__(self, *, client: Optional[Client] = None, **kwargs: Any): - """Initialises a Gemini model wrapper. - - Args: - client: An optional preconfigured ``google.genai.Client``. When - provided, ADK uses this client for **all** Gemini API calls - (including the Live API) instead of creating one internally. - - .. note:: - When a custom client is supplied it is used as-is for Live API - connections. ADK will **not** override the client's - ``api_version``; you are responsible for setting the correct - version (``v1beta1`` for Vertex AI, ``v1alpha`` for the - Gemini developer API) on the client yourself. - - .. warning:: - ``google.genai.Client`` contains threading primitives that - cannot be pickled. If you are deploying to Agent Engine (or - any environment that serialises the model), do **not** pass a - custom client — let ADK create one from the environment - instead. - - **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor - (``model``, ``base_url``, ``retry_options``, etc.). - """ - super().__init__(**kwargs) - # Store after super().__init__ so Pydantic validation runs first. 
- object.__setattr__(self, '_client', client) - @classmethod @override def supported_models(cls) -> list[str]: @@ -345,16 +299,9 @@ class Gemini(BaseLlm): def api_client(self) -> Client: """Provides the api client. - If a preconfigured ``client`` was passed to the constructor it is - returned directly; otherwise a new ``Client`` is created using the - default environment/ADC configuration. - Returns: The api client. """ - if self._client is not None: - return self._client - from google.genai import Client return Client( @@ -387,9 +334,6 @@ class Gemini(BaseLlm): @cached_property def _live_api_client(self) -> Client: - if self._client is not None: - return self._client - from google.genai import Client return Client( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 75d4c0fd..70aa01b6 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -2140,153 +2140,3 @@ async def test_connect_speech_config_remains_none_when_both_are_none( # Verify the final speech_config is still None assert config_arg.speech_config is None assert isinstance(connection, GeminiLlmConnection) - - -# --------------------------------------------------------------------------- -# Tests for custom client injection (Issue #2560) -# --------------------------------------------------------------------------- - - -def test_custom_client_is_used_for_api_client(): - """When a custom client is provided, api_client returns it directly.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini.api_client is custom_client - - -def test_custom_client_is_used_for_live_api_client(): - """When a custom client is provided, _live_api_client returns it directly.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - gemini = Gemini(model="gemini-1.5-flash", 
client=custom_client) - - assert gemini._live_api_client is custom_client - - -def test_default_api_client_when_no_custom_client(): - """Without a custom client, api_client creates a default Client.""" - gemini = Gemini(model="gemini-1.5-flash") - - # api_client should construct a real Client (not None) - client = gemini.api_client - assert client is not None - # Verify it is not a mock — it's a real google.genai.Client - from google.genai import Client - - assert isinstance(client, Client) - - -def test_default_live_api_client_when_no_custom_client(): - """Without a custom client, _live_api_client creates a default Client.""" - gemini = Gemini(model="gemini-1.5-flash") - - client = gemini._live_api_client - assert client is not None - from google.genai import Client - - assert isinstance(client, Client) - - -def test_custom_client_api_backend_vertexai(): - """_api_backend reflects the custom client's vertexai setting.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = True - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI - - -def test_custom_client_api_backend_gemini_api(): - """_api_backend reflects non-vertexai custom client.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini._api_backend == GoogleLLMVariant.GEMINI_API - - -@pytest.mark.asyncio -async def test_custom_client_used_for_generate_content(): - """Custom client is used when generate_content_async is called.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - generate_content_response = types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", 
- parts=[Part.from_text(text="Hello from custom client")], - ), - finish_reason=types.FinishReason.STOP, - ) - ] - ) - - async def mock_coro(): - return generate_content_response - - custom_client.aio.models.generate_content.return_value = mock_coro() - - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - system_instruction="You are a helpful assistant", - ), - ) - - responses = [ - resp - async for resp in gemini.generate_content_async(llm_request, stream=False) - ] - - assert len(responses) == 1 - assert responses[0].content.parts[0].text == "Hello from custom client" - custom_client.aio.models.generate_content.assert_called_once() - - -@pytest.mark.asyncio -async def test_custom_client_used_for_live_connect(): - """Custom client is used for live API streaming connections.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - mock_live_session = mock.AsyncMock() - - class MockLiveConnect: - - async def __aenter__(self): - return mock_live_session - - async def __aexit__(self, *args): - pass - - custom_client.aio.live.connect.return_value = MockLiveConnect() - - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - system_instruction="You are a helpful assistant", - ), - ) - llm_request.live_connect_config = types.LiveConnectConfig() - - async with gemini.connect(llm_request) as connection: - custom_client.aio.live.connect.assert_called_once() - assert isinstance(connection, GeminiLlmConnection) From 121d27741684685c564e484704ae949c5f0807b1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 13:26:25 -0800 Subject: [PATCH 035/102] feat: Add /chat/completions streaming support to Apigee 
LLM PiperOrigin-RevId: 874764985 --- src/google/adk/models/apigee_llm.py | 485 ++++++++++++++---- .../models/test_completions_http_client.py | 341 +++++++++++- 2 files changed, 718 insertions(+), 108 deletions(-) diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 90a91f32..fc4928cb 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -25,6 +25,7 @@ import logging import os from typing import Any from typing import AsyncGenerator +from typing import Generator from typing import Optional from typing import TYPE_CHECKING @@ -51,6 +52,14 @@ _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_USE_VERTEXAI' _PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT' _LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION' +_CUSTOM_METADATA_FIELDS = ( + 'id', + 'created', + 'model', + 'service_tier', + 'object', +) + class ApigeeLlm(Gemini): """A BaseLlm implementation for calling Apigee proxy. @@ -290,6 +299,45 @@ def _get_model_id(model: str) -> str: return components[-1] +def _parse_logprobs( + logprobs_data: dict[str, Any] | None, +) -> types.LogprobsResult | None: + """Parses OpenAI logprobs data into LogprobsResult.""" + if not logprobs_data or 'content' not in logprobs_data: + return None + + chosen_candidates = [] + top_candidates = [] + + for item in logprobs_data['content']: + chosen_candidates.append( + types.LogprobsResultCandidate( + token=item.get('token'), + log_probability=item.get('logprob'), + # OpenAI text format usually doesn't expose ID easily here + token_id=None, + ) + ) + + if 'top_logprobs' in item: + current_top_candidates = [] + for top_item in item['top_logprobs']: + current_top_candidates.append( + types.LogprobsResultCandidate( + token=top_item.get('token'), + log_probability=top_item.get('logprob'), + token_id=None, + ) + ) + top_candidates.append( + types.LogprobsResultTopCandidates(candidates=current_top_candidates) + ) + + return types.LogprobsResult( + 
chosen_candidates=chosen_candidates, top_candidates=top_candidates + ) + + def _validate_model_string(model: str) -> bool: """Validates the model string for Apigee LLM. @@ -383,7 +431,7 @@ class CompletionsHTTPClient: loop.create_task(client.aclose()) except RuntimeError: try: - # This fails if aynscio.run is already called in main and is being closed. + # This fails if asyncio.run is already called in main and is closing. asyncio.run(client.aclose()) except RuntimeError: pass @@ -470,7 +518,8 @@ class CompletionsHTTPClient: url = f"{url.rstrip('/')}/chat/completions" if stream: - raise NotImplementedError('Streaming is not supported yet.') + async for stream_res in self._handle_streaming(url, payload, headers): + yield stream_res else: response = await self._httpx_post_with_retry(url, payload, headers) data = response.json() @@ -487,11 +536,33 @@ class CompletionsHTTPClient: response.raise_for_status() return response - async def _handle_streaming_response( - self, response: httpx.Response + async def _handle_streaming( + self, url: str, payload: dict[str, Any], headers: dict[str, str] ) -> AsyncGenerator[LlmResponse, None]: """Handles streaming response from OpenAI-compatible API.""" - raise NotImplementedError('Streaming is not supported yet.') + accumulator = ChatCompletionsResponseHandler() + async with self._client.stream( + 'POST', + url, + json=payload, + headers=headers, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + line = line.strip() + if line.startswith('data:'): + line = line.removeprefix('data:') + line = line.lstrip() + if line == '[DONE]': + break + try: + for res in self._parse_streaming_line(line, accumulator): + yield res + except json.JSONDecodeError: + logger.warning('Failed to parse JSON chunk: %s', line) + continue def _construct_payload( self, llm_request: LlmRequest, stream: bool @@ -731,78 +802,62 @@ class CompletionsHTTPClient: return ''.join(part.text for part in parts if 
part.text) return None - def _parse_logprobs( - self, logprobs_data: dict[str, Any] | None - ) -> types.LogprobsResult | None: - """Parses OpenAI logprobs data into LogprobsResult.""" - if not logprobs_data or 'content' not in logprobs_data: - return None - - chosen_candidates = [] - top_candidates = [] - - for item in logprobs_data['content']: - chosen_candidates.append( - types.LogprobsResultCandidate( - token=item.get('token'), - log_probability=item.get('logprob'), - # OpenAI text format usually doesn't expose ID easily here - token_id=None, - ) - ) - - if 'top_logprobs' in item: - current_top_candidates = [] - for top_item in item['top_logprobs']: - current_top_candidates.append( - types.LogprobsResultCandidate( - token=top_item.get('token'), - log_probability=top_item.get('logprob'), - token_id=None, - ) - ) - top_candidates.append( - types.LogprobsResultTopCandidates(candidates=current_top_candidates) - ) - - return types.LogprobsResult( - chosen_candidates=chosen_candidates, top_candidates=top_candidates - ) - def _parse_response(self, response: dict[str, Any]) -> LlmResponse: """Parses an OpenAI response dictionary into an LlmResponse.""" + handler = ChatCompletionsResponseHandler() + return handler.process_response(response) + + def _parse_streaming_line( + self, + line: str, + accumulator: ChatCompletionsResponseHandler, + ) -> Generator[LlmResponse]: + """Parses a single line from the streaming response. + + Args: + line: A single line from the streaming response, expected to be a JSON + string. + accumulator: An accumulator to manage partial chat completion choices + across multiple chunks. + + Yields: + An LlmResponse object parsed from the streaming line. + """ + chunk = json.loads(line) + for response in accumulator.process_chunk(chunk): + yield response + + +class ChatCompletionsResponseHandler: + """Accumulates responses from the /chat/completions endpoint. + + Useful for both streaming and non-streaming responses. 
+ """ + + def __init__(self): + self.content_parts = '' + self.tool_call_parts = {} + self.role = '' + self.streaming_complete = False + self.model = '' + self.usage = {} + self.logprobs = {} + self.custom_metadata = {} + + def process_response(self, response: dict[str, Any]) -> LlmResponse: + """Processes a complete non-streaming response.""" choices = response.get('choices', []) if not choices: - return LlmResponse() - + raise ValueError('No choices found in response.') + if len(choices) > 1: + logging.error( + 'Multiple choices found in response but only the first one will be' + ' used.' + ) choice = choices[0] message = choice.get('message', {}) - role = message.get('role', 'model') - if role == 'assistant': - role = 'model' - - parts = [] - content_str = message.get('content') - if content_str: - parts.append(types.Part.from_text(text=content_str)) - - tool_calls = message.get('tool_calls') - if tool_calls: - for tool_call in tool_calls: - call_type = tool_call.get('type', 'unknown') - # TODO: Add support for 'custom' type. 
- if call_type != 'function': - raise ValueError( - f'Unsupported tool_call type: {call_type} in call {tool_call}' - ) - func = tool_call.get('function', {}) - part = self._parse_function_call(func) - parts.append(part) - - function_call = message.get('function_call') - if function_call: - part = self._parse_function_call(function_call) - parts.append(part) + _, role = self._add_chat_completion_message(message) + parts = self._get_content_parts() usage = response.get('usage', {}) usage_metadata = types.GenerateContentResponseUsageMetadata( @@ -810,19 +865,13 @@ class CompletionsHTTPClient: candidates_token_count=usage.get('completion_tokens', 0), total_token_count=usage.get('total_tokens', 0), ) + logprobs_result = _parse_logprobs(choice.get('logprobs')) - logprobs_result = self._parse_logprobs(choice.get('logprobs')) - - custom_metadata = { - 'id': response.get('id'), - 'created': response.get('created'), - 'model': response.get('model'), - 'system_fingerprint': response.get('system_fingerprint'), - 'service_tier': response.get('service_tier'), - } - custom_metadata = { - k: v for k, v in custom_metadata.items() if v is not None - } + custom_metadata = {} + for k in _CUSTOM_METADATA_FIELDS: + v = response.get(k) + if v is not None: + custom_metadata[k] = v return LlmResponse( content=types.Content(role=role, parts=parts), @@ -833,6 +882,83 @@ class CompletionsHTTPClient: custom_metadata=custom_metadata, ) + def process_chunk( + self, chunk: dict[str, Any] + ) -> Generator[LlmResponse, None, None]: + """Processes a chunk and yields responses.""" + if 'model' in chunk: + self.model = chunk['model'] + if 'usage' in chunk and chunk['usage']: + self.usage.update(chunk['usage']) + + for k in _CUSTOM_METADATA_FIELDS: + v = chunk.get(k) + if v is not None: + self.custom_metadata[k] = v + + usage_metadata = None + if self.usage: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=self.usage.get('prompt_tokens', 0), + 
candidates_token_count=self.usage.get('completion_tokens', 0), + total_token_count=self.usage.get('total_tokens', 0), + ) + + choices = chunk.get('choices') + if not choices: + # If no choices, but we have usage or other metadata updates, yield them. + if usage_metadata or self.custom_metadata: + yield LlmResponse( + partial=True, + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + ) + return + + if len(choices) > 1: + logging.error( + 'Multiple choices found in streaming response but only the first one' + ' will be used.' + ) + choice = choices[0] + + # Accumulate logprobs if present + if 'logprobs' in choice and choice['logprobs']: + self._accumulate_logprobs(choice['logprobs']) + + logprobs_result = None + if self.logprobs: + logprobs_result = _parse_logprobs(self.logprobs) + + delta = choice.get('delta', {}) + partial_parts, role = self._add_chat_completion_chunk_delta(delta) + + yield LlmResponse( + partial=True, + content=types.Content(role=role, parts=partial_parts), + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + logprobs_result=logprobs_result, + ) + + finish_reason = choice.get('finish_reason') + if finish_reason: + yield LlmResponse( + content=types.Content( + role=role, + parts=self._get_content_parts(), + ), + finish_reason=self._map_finish_reason(finish_reason), + custom_metadata=self.custom_metadata, + model_version=self.model, + usage_metadata=usage_metadata, + logprobs_result=logprobs_result, + ) + # Exit because the 'finish_reason' chunk is the final chunk. 
+ return + def _map_finish_reason(self, reason: str | None) -> types.FinishReason: if reason == 'stop': return types.FinishReason.STOP @@ -844,25 +970,176 @@ class CompletionsHTTPClient: return types.FinishReason.SAFETY return types.FinishReason.FINISH_REASON_UNSPECIFIED - def _parse_function_call(self, func: dict[str, Any]) -> types.Part: - """Parses a function call dictionary into a Part.""" - name = func.get('name') - args_str = func.get('arguments', '{}') - try: - args = json.loads(args_str) - except json.JSONDecodeError: - args = {} - tool_part = types.Part.from_function_call(name=name, args=args) - if tool_part.function_call: - tool_part.function_call.id = func.get('id', None) - # Add support for gemini's thought_signature. - thought_signature = ( - func.get('extra_content', {}) - .get('google', {}) - .get('thought_signature', '') + def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None: + """Accumulates logprobs from a chunk.""" + if not self.logprobs: + self.logprobs = {'content': [], 'refusal': []} + + if 'content' in logprobs_chunk and logprobs_chunk['content']: + if 'content' not in self.logprobs: + self.logprobs['content'] = [] + self.logprobs['content'].extend(logprobs_chunk['content']) + + if 'refusal' in logprobs_chunk and logprobs_chunk['refusal']: + if 'refusal' not in self.logprobs: + self.logprobs['refusal'] = [] + self.logprobs['refusal'].extend(logprobs_chunk['refusal']) + + def _append_content(self, content: str, refusal: str) -> str: + if content and refusal: + content += '\n' + content += refusal + elif refusal: + content = refusal + if content: + self.content_parts += content + return content + + def _add_chat_completion_chunk_delta( + self, delta: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a chunk delta from a streaming chat completions response. + + This method processes a single delta chunk from a streaming chat completions + response, accumulating partial content and tool calls. 
+ + Args: + delta: A dictionary representing a single delta from the streaming chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this chunk. + - The role associated with the message. + """ + parts = [] + for tool_call in delta.get('tool_calls', []): + chunk_part = self._upsert_tool_call(tool_call) + parts.append(chunk_part) + content = delta.get('content') + refusal = delta.get('refusal') + merged_content = self._append_content(content, refusal) + if merged_content: + parts.append(types.Part.from_text(text=merged_content)) + + self._get_or_create_role(delta.get('role', 'model')) + return parts, self.role + + def _add_chat_completion_message( + self, message: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a complete chat completion message to the accumulator. + + This method processes a single message from a non-streaming chat completions + response, extracting and accumulating content and tool calls. + + Args: + message: A dictionary representing a single message from the chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this message. + - The role associated with the message. + """ + for tool_call in message.get('tool_calls', []): + self._upsert_tool_call(tool_call) + function_call = message.get('function_call') + if function_call: + # function_call is a single tool call and does not have an id. 
+ self._upsert_tool_call({ + 'type': 'function', + 'function': function_call, + }) + content = message.get('content') + refusal = message.get('refusal') + self._append_content(content, refusal) + + self._get_or_create_role(message.get('role', 'model')) + return self._get_content_parts(), self.role + + def _get_content_parts(self) -> list[types.Part]: + """Returns the content parts from the accumulated response.""" + parts = [] + if self.content_parts: + parts.append(types.Part.from_text(text=self.content_parts)) + sorted_indices = sorted(self.tool_call_parts.keys()) + for index in sorted_indices: + parts.append(self.tool_call_parts[index]) + return parts + + def _upsert_tool_call(self, tool_call: dict[str, Any]) -> types.Part: + """Upserts a tool call into the accumulated tool call parts. + + This method handles partial tool call chunks in streaming responses by + updating existing tool call parts or creating new ones. + + Args: + tool_call: A dictionary representing a tool call or a delta of a tool call + from the chat completions API. + + Returns: + A `types.Part` object representing the updated or newly created tool call. + """ + index = tool_call.get('index') + if index is None: + # If index is not provided, we might be in a non-streaming response. + # We just append it as a new tool call. + index = len(self.tool_call_parts) + + if index not in self.tool_call_parts: + self.tool_call_parts[index] = types.Part( + function_call=types.FunctionCall() ) - if thought_signature: - if isinstance(thought_signature, str): - thought_signature = base64.b64decode(thought_signature) - tool_part.thought_signature = thought_signature - return tool_part + part = self.tool_call_parts[index] + chunk_part = types.Part(function_call=types.FunctionCall()) + call_type = tool_call.get('type') + # TODO: Add support for 'custom' type. 
+ if call_type is not None and call_type != 'function': + raise ValueError( + f'Unsupported tool_call type: {call_type} in call {tool_call}' + ) + func = tool_call.get('function', {}) + args_delta = func.get('arguments', '') + if args_delta: + try: + args = json.loads(args_delta) + chunk_part.function_call.args = args + if not part.function_call.args: + part.function_call.args = dict(args) + else: + part.function_call.args.update(args) + except json.JSONDecodeError as e: + raise ValueError(f'Failed to parse arguments: {args_delta}') from e + + func_name = func.get('name') + if func_name: + part.function_call.name = func_name + chunk_part.function_call.name = func_name + tool_call_id = tool_call.get('id') + if tool_call_id: + part.function_call.id = tool_call_id + chunk_part.function_call.id = tool_call_id + + # Add support for gemini's thought_signature. + thought_signature = ( + tool_call.get('extra_content', {}) + .get('google', {}) + .get('thought_signature', '') + ) + if thought_signature: + if isinstance(thought_signature, str): + thought_signature = base64.b64decode(thought_signature) + part.thought_signature = thought_signature + chunk_part.thought_signature = thought_signature + return chunk_part + + def _get_or_create_role(self, role: str = '') -> str: + if self.role: + return self.role + if role == 'assistant': + role = 'model' + self.role = role + return self.role diff --git a/tests/unittests/models/test_completions_http_client.py b/tests/unittests/models/test_completions_http_client.py index f16376d7..615871eb 100644 --- a/tests/unittests/models/test_completions_http_client.py +++ b/tests/unittests/models/test_completions_http_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json from unittest import mock from unittest.mock import AsyncMock @@ -24,7 +25,7 @@ import pytest @pytest.fixture def client(): - return CompletionsHTTPClient(base_url='https://example.com') + return CompletionsHTTPClient(base_url='https://localhost') @pytest.fixture(name='llm_request') @@ -58,7 +59,7 @@ async def test_construct_payload_basic_payload(client, llm_request): url = call_args[0][0] kwargs = call_args[1] - assert url == 'https://example.com/chat/completions' + assert url == 'https://localhost/chat/completions' payload = kwargs['json'] assert payload['model'] == 'open_llama' assert payload['stream'] is False @@ -231,7 +232,7 @@ async def test_construct_payload_image_file_uri(client): role='user', parts=[ types.Part.from_uri( - file_uri='https://example.com/image.jpg', + file_uri='https://localhost/image.jpg', mime_type='image/jpeg', ) ], @@ -263,7 +264,7 @@ async def test_construct_payload_image_file_uri(client): assert isinstance(message['content'], list) assert message['content'][0] == { 'type': 'image_url', - 'image_url': {'url': 'https://example.com/image.jpg'}, + 'image_url': {'url': 'https://localhost/image.jpg'}, } @@ -368,6 +369,7 @@ async def test_construct_payload_response_format( mock_post.assert_called_once() payload = mock_post.call_args[1]['json'] + assert payload['response_format'] == expected_response_format @@ -438,3 +440,334 @@ async def test_generate_content_async_function_call_response( assert part.function_call.name == 'get_weather' assert part.function_call.args == {'location': 'London'} assert part.function_call.id is None + + +@pytest.mark.asyncio +async def test_generate_content_async_streaming_function_call(): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + # Mock chunks simulating split arguments + chunk_data_0 = { + 'id': 'chatcmpl-123', + 'object': 
'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'id': 'call_123', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': ''}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_1 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"location": "London"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_2 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"country": "UK"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_3 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}], + 'usage': { + 'prompt_tokens': 10, + 'completion_tokens': 20, + 'total_tokens': 30, + }, + } + + chunks = [ + f'{json.dumps(chunk_data_0)}\n', + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + 
+    ]
+  # Check that we get 5 responses (one per chunk + extra final accumulated)
+  assert len(responses) == 5
+
+  # Check 1st response: partial tool call, empty args
+  assert responses[0].partial is True
+  assert responses[0].content.parts[0].function_call.name == 'get_weather'
+  assert responses[0].content.parts[0].function_call.id == 'call_123'
+
+  # Check 2nd response: full args for first update
+  assert responses[1].partial is True
+  assert responses[1].content.parts[0].function_call.args == {
+      'location': 'London'
+  }
+
+  # Check 3rd response: full args for second update (merged)
+  assert responses[2].partial is True
+  assert responses[2].content.parts[0].function_call.args == {'country': 'UK'}
+
+  # Check 4th response: last delta (empty)
+  assert responses[3].partial is True
+  assert responses[3].content.parts == []
+
+  # Check 5th response: final accumulated
+  assert responses[4].finish_reason == types.FinishReason.STOP
+  # Full accumulated args
+  assert responses[4].content.parts[0].function_call.args == {
+      'location': 'London',
+      'country': 'UK',
+  }
+
+  # Check metadata and usage
+  assert responses[4].model_version == 'gpt-3.5-turbo'
+  assert responses[4].custom_metadata['id'] == 'chatcmpl-123'
+  assert responses[4].custom_metadata['created'] == 1234567890
+  assert responses[4].custom_metadata['object'] == 'chat.completion.chunk'
+  assert responses[4].custom_metadata['service_tier'] == 'default'
+  assert responses[4].usage_metadata is not None
+  assert responses[4].usage_metadata.prompt_token_count == 10
+  assert responses[4].usage_metadata.candidates_token_count == 20
+  assert responses[4].usage_metadata.total_token_count == 30
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_streaming_multiple_function_calls():
+  # Mock streaming response with multiple tool calls
+  local_client = CompletionsHTTPClient(base_url='https://localhost')
+  llm_request = LlmRequest(
+      model='apigee/test',
+      contents=[
+          types.Content(role='user',
parts=[types.Part.from_text(text='hi')]) + ], + ) + chunk_data_1 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + { + 'index': 0, + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'func_1', 'arguments': ''}, + }, + { + 'index': 1, + 'id': 'call_2', + 'type': 'function', + 'function': {'name': 'func_2', 'arguments': ''}, + }, + ] + }, + 'finish_reason': None, + }] + } + # the tool_call type is optional in chunk responses. + chunk_data_2 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + {'index': 0, 'function': {'arguments': '{"arg": 1}'}}, + {'index': 1, 'function': {'arguments': '{"arg": 2}'}}, + ] + }, + 'finish_reason': None, + }] + } + chunk_data_3 = { + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}] + } + + chunks = [ + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + + assert len(responses) == 4 + parts = responses[-1].content.parts + assert len(parts) == 2 + + assert parts[0].function_call.name == 'func_1' + assert parts[0].function_call.args == {'arg': 1} + assert parts[0].function_call.id == 'call_1' + + assert parts[1].function_call.name == 'func_2' + assert parts[1].function_call.args == {'arg': 2} + + assert parts[1].function_call.id == 'call_2' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('chunks', 'expected_response_count'), + [ + ( + [ + '\n', + ' \n', + ( + 'data: {"choices": [{"index": 0, "delta": 
{"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ], + 1, + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + '[DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ' [DONE] \n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + 'data: [DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ], +) +async def test_generate_content_async_streaming_parse_lines( + chunks, expected_response_count +): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + assert len(responses) == expected_response_count + assert responses[0].content.parts[0].text == 'Hello' From 8f5428150d18ed732b66379c0acb806a9121c3cb Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Tue, 24 Feb 2026 14:34:52 -0800 
Subject: [PATCH 036/102] fix: Update sample skills agent to use weather-skill instead of weather_skill Co-authored-by: Kathy Wu PiperOrigin-RevId: 874796345 --- contributing/samples/skills_agent/agent.py | 2 +- .../skills/{weather_skill => weather-skill}/SKILL.md | 0 .../{weather_skill => weather-skill}/references/weather_info.md | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename contributing/samples/skills_agent/skills/{weather_skill => weather-skill}/SKILL.md (100%) rename contributing/samples/skills_agent/skills/{weather_skill => weather-skill}/references/weather_info.md (100%) diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 6cd69ffb..9caf0ad7 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -41,7 +41,7 @@ greeting_skill = models.Skill( ) weather_skill = load_skill_from_dir( - pathlib.Path(__file__).parent / "skills" / "weather_skill" + pathlib.Path(__file__).parent / "skills" / "weather-skill" ) my_skill_toolset = SkillToolset(skills=[greeting_skill, weather_skill]) diff --git a/contributing/samples/skills_agent/skills/weather_skill/SKILL.md b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md similarity index 100% rename from contributing/samples/skills_agent/skills/weather_skill/SKILL.md rename to contributing/samples/skills_agent/skills/weather-skill/SKILL.md diff --git a/contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md b/contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md similarity index 100% rename from contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md rename to contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md From e4d9540ce3552ffd3335e1776293eafee4ea28cd Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Tue, 24 Feb 2026 23:29:06 -0800 Subject: [PATCH 037/102] chore: Make `Release: Please` 
workflow only run via workflow_dispatch Co-authored-by: Xuan Yang PiperOrigin-RevId: 874980878 --- .github/workflows/release-please.yml | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 791d84a5..41d8d864 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -1,11 +1,10 @@ # Runs release-please to create/update a PR with version bump and changelog. -# Triggered automatically by step 1 (cut) or step 3 (cherry-pick). +# Triggered only by workflow_dispatch (from release-cut.yml). +# Does NOT auto-run on push to preserve manual changelog edits after cherry-picks. name: "Release: Please" on: - push: - branches: - - release/candidate + # Only run via workflow_dispatch (triggered by release-cut.yml) workflow_dispatch: permissions: @@ -14,8 +13,6 @@ permissions: jobs: release-please: - # Skip if this is a release-please PR merge (handled by Release: Finalize) - if: "!startsWith(github.event.head_commit.message, 'chore(release')" runs-on: ubuntu-latest steps: - name: Check if release/candidate still exists From 636f68fbee700aa47f01e2cfd746859353b3333d Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Wed, 25 Feb 2026 00:58:38 -0800 Subject: [PATCH 038/102] feat: Add RunSkillScriptTool to SkillToolset Introduces RunSkillScriptTool to execute scripts located in a skill's scripts/ directory. The execution logic is isolated within a dedicated SkillScriptCodeExecutor wrapper instantiated by RunSkillScriptTool. This wrapper manages script materialization in a temporary directory and executes Python (via runpy) or Shell scripts (returning standard output or JSON-encoded envelopes). This isolation eliminates the need to modify the underlying `BaseCodeExecutor` interface or implementations (`unsafe_local_code_executor`, etc.) to support working directories or file paths. 
Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 875012237 --- src/google/adk/tools/skill_toolset.py | 366 ++++++++ tests/unittests/tools/test_skill_toolset.py | 894 +++++++++++++++++++- 2 files changed, 1256 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index f90dfdb2..d13481eb 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=g-import-not-at-top,protected-access + """Toolset for discovering, viewing, and executing agent skills.""" from __future__ import annotations +import asyncio +import json +import logging from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types from ..agents.readonly_context import ReadonlyContext +from ..code_executors.base_code_executor import BaseCodeExecutor +from ..code_executors.code_execution_utils import CodeExecutionInput from ..features import experimental from ..features import FeatureName from ..skills import models @@ -33,6 +41,11 @@ from .tool_context import ToolContext if TYPE_CHECKING: from ..models.llm_request import LlmRequest +logger = logging.getLogger("google_adk." + __name__) + +_DEFAULT_SCRIPT_TIMEOUT = 300 +_MAX_SKILL_PAYLOAD_BYTES = 16 * 1024 * 1024 # 16 MB + DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: @@ -46,6 +59,7 @@ This is very important: 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `name=""` to read its full instructions before proceeding. 2. 
Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. 3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. +4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. """ @@ -227,6 +241,340 @@ class LoadSkillResourceTool(BaseTool): } +class _SkillScriptCodeExecutor: + """A helper that materializes skill files and executes scripts.""" + + _base_executor: BaseCodeExecutor + _script_timeout: int + + def __init__(self, base_executor: BaseCodeExecutor, script_timeout: int): + self._base_executor = base_executor + self._script_timeout = script_timeout + + async def execute_script_async( + self, + invocation_context: Any, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> dict[str, Any]: + """Prepares and executes the script using the base executor.""" + code = self._build_wrapper_code(skill, script_path, script_args) + if code is None: + if "." in script_path: + ext_msg = f"'.{script_path.rsplit('.', 1)[-1]}'" + else: + ext_msg = "(no extension)" + return { + "error": ( + f"Unsupported script type {ext_msg}." + " Supported types: .py, .sh, .bash" + ), + "error_code": "UNSUPPORTED_SCRIPT_TYPE", + } + + try: + # Execute the self-contained script using the underlying executor + result = await asyncio.to_thread( + self._base_executor.execute_code, + invocation_context, + CodeExecutionInput(code=code), + ) + + stdout = result.stdout + stderr = result.stderr + + # Shell scripts serialize both streams as JSON + # through stdout; parse the envelope if present. + rc = 0 + is_shell = "." 
in script_path and script_path.rsplit(".", 1)[ + -1 + ].lower() in ("sh", "bash") + if is_shell and stdout: + try: + parsed = json.loads(stdout) + if isinstance(parsed, dict) and parsed.get("__shell_result__"): + stdout = parsed.get("stdout", "") + stderr = parsed.get("stderr", "") + rc = parsed.get("returncode", 0) + if rc != 0 and not stderr: + stderr = f"Exit code {rc}" + except (json.JSONDecodeError, ValueError): + pass + + status = "success" + if rc != 0: + status = "error" + elif stderr and not stdout: + status = "error" + elif stderr: + status = "warning" + + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": stdout, + "stderr": stderr, + "status": status, + } + except SystemExit as e: + if e.code in (None, 0): + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": "", + "stderr": "", + "status": "success", + } + return { + "error": ( + f"Failed to execute script '{script_path}':" + f" exited with code {e.code}" + ), + "error_code": "EXECUTION_ERROR", + } + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception( + "Error executing script '%s' from skill '%s'", + script_path, + skill.name, + ) + short_msg = str(e) + if len(short_msg) > 200: + short_msg = short_msg[:200] + "..." + return { + "error": ( + "Failed to execute script" + f" '{script_path}':\n{type(e).__name__}:" + f" {short_msg}" + ), + "error_code": "EXECUTION_ERROR", + } + + def _build_wrapper_code( + self, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> str | None: + """Builds a self-extracting Python script.""" + ext = "" + if "." 
in script_path: + ext = script_path.rsplit(".", 1)[-1].lower() + + if not script_path.startswith("scripts/"): + script_path = f"scripts/{script_path}" + + files_dict = {} + for ref_name in skill.resources.list_references(): + content = skill.resources.get_reference(ref_name) + if content is not None: + files_dict[f"references/{ref_name}"] = content + + for asset_name in skill.resources.list_assets(): + content = skill.resources.get_asset(asset_name) + if content is not None: + files_dict[f"assets/{asset_name}"] = content + + for scr_name in skill.resources.list_scripts(): + scr = skill.resources.get_script(scr_name) + if scr is not None and scr.src is not None: + files_dict[f"scripts/{scr_name}"] = scr.src + + total_size = sum( + len(v) if isinstance(v, (str, bytes)) else 0 + for v in files_dict.values() + ) + if total_size > _MAX_SKILL_PAYLOAD_BYTES: + logger.warning( + "Skill '%s' resources total %d bytes, exceeding" + " the recommended limit of %d bytes.", + skill.name, + total_size, + _MAX_SKILL_PAYLOAD_BYTES, + ) + + # Build the boilerplate extract string + code_lines = [ + "import os", + "import tempfile", + "import sys", + "import json as _json", + "import subprocess", + "import runpy", + f"_files = {files_dict!r}", + "def _materialize_and_run():", + " _orig_cwd = os.getcwd()", + " with tempfile.TemporaryDirectory() as td:", + " for rel_path, content in _files.items():", + " full_path = os.path.join(td, rel_path)", + " os.makedirs(os.path.dirname(full_path), exist_ok=True)", + " mode = 'wb' if isinstance(content, bytes) else 'w'", + " with open(full_path, mode) as f:", + " f.write(content)", + " os.chdir(td)", + " try:", + ] + + if ext == "py": + argv_list = [script_path] + for k, v in script_args.items(): + argv_list.extend([f"--{k}", str(v)]) + code_lines.extend([ + f" sys.argv = {argv_list!r}", + " try:", + f" runpy.run_path({script_path!r}, run_name='__main__')", + " except SystemExit as e:", + " if e.code is not None and e.code != 0:", + " raise e", + 
]) + elif ext in ("sh", "bash"): + arr = ["bash", script_path] + for k, v in script_args.items(): + arr.extend([f"--{k}", str(v)]) + timeout = self._script_timeout + code_lines.extend([ + " try:", + " _r = subprocess.run(", + f" {arr!r},", + " capture_output=True, text=True,", + f" timeout={timeout!r}, cwd=td,", + " )", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _r.stdout,", + " 'stderr': _r.stderr,", + " 'returncode': _r.returncode,", + " }))", + " except subprocess.TimeoutExpired as _e:", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _e.stdout or '',", + f" 'stderr': 'Timed out after {timeout}s',", + " 'returncode': -1,", + " }))", + ]) + else: + return None + + code_lines.extend([ + " finally:", + " os.chdir(_orig_cwd)", + ]) + + code_lines.append("_materialize_and_run()") + return "\n".join(code_lines) + + +@experimental(FeatureName.SKILL_TOOLSET) +class RunSkillScriptTool(BaseTool): + """Tool to execute scripts from a skill's scripts/ directory.""" + + def __init__(self, toolset: "SkillToolset"): + super().__init__( + name="run_skill_script", + description="Executes a script from a skill's scripts/ directory.", + ) + self._toolset = toolset + + def _get_declaration(self) -> types.FunctionDeclaration | None: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "skill_name": { + "type": "string", + "description": "The name of the skill.", + }, + "script_path": { + "type": "string", + "description": ( + "The relative path to the script (e.g.," + " 'scripts/setup.py')." + ), + }, + "args": { + "type": "object", + "description": ( + "Optional arguments to pass to the script as key-value" + " pairs." 
+ ), + }, + }, + "required": ["skill_name", "script_path"], + }, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + skill_name = args.get("skill_name") + script_path = args.get("script_path") + script_args = args.get("args", {}) + if not isinstance(script_args, dict): + return { + "error": ( + "'args' must be a JSON object (key-value pairs)," + f" got {type(script_args).__name__}." + ), + "error_code": "INVALID_ARGS_TYPE", + } + + if not skill_name: + return { + "error": "Skill name is required.", + "error_code": "MISSING_SKILL_NAME", + } + if not script_path: + return { + "error": "Script path is required.", + "error_code": "MISSING_SCRIPT_PATH", + } + + skill = self._toolset._get_skill(skill_name) + if not skill: + return { + "error": f"Skill '{skill_name}' not found.", + "error_code": "SKILL_NOT_FOUND", + } + + script = None + if script_path.startswith("scripts/"): + script = skill.resources.get_script(script_path[len("scripts/") :]) + else: + script = skill.resources.get_script(script_path) + + if script is None: + return { + "error": f"Script '{script_path}' not found in skill '{skill_name}'.", + "error_code": "SCRIPT_NOT_FOUND", + } + + # Resolve code executor: toolset-level first, then agent fallback + code_executor = self._toolset._code_executor + if code_executor is None: + agent = tool_context._invocation_context.agent + if hasattr(agent, "code_executor"): + code_executor = agent.code_executor + if code_executor is None: + return { + "error": ( + "No code executor configured. A code executor is" + " required to run scripts." 
+ ), + "error_code": "NO_CODE_EXECUTOR", + } + + script_executor = _SkillScriptCodeExecutor( + code_executor, self._toolset._script_timeout # pylint: disable=protected-access + ) + return await script_executor.execute_script_async( + tool_context._invocation_context, skill, script_path, script_args # pylint: disable=protected-access + ) + + @experimental(FeatureName.SKILL_TOOLSET) class SkillToolset(BaseToolset): """A toolset for managing and interacting with agent skills.""" @@ -234,7 +582,19 @@ class SkillToolset(BaseToolset): def __init__( self, skills: list[models.Skill], + *, + code_executor: Optional[BaseCodeExecutor] = None, + script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT, ): + """Initializes the SkillToolset. + + Args: + skills: List of skills to register. + code_executor: Optional code executor for script execution. + script_timeout: Timeout in seconds for shell script execution via + subprocess.run. Defaults to 300 seconds. Does not apply to Python + scripts executed via exec(). + """ super().__init__() # Check for duplicate skill names @@ -245,11 +605,17 @@ class SkillToolset(BaseToolset): seen.add(skill.name) self._skills = {skill.name: skill for skill in skills} + self._code_executor = code_executor + self._script_timeout = script_timeout + + # Initialize core skill tools self._tools = [ ListSkillsTool(self), LoadSkillTool(self), LoadSkillResourceTool(self), ] + # Always add RunSkillScriptTool, relies on invocation_context fallback if _code_executor is None + self._tools.append(RunSkillScriptTool(self)) async def get_tools( self, readonly_context: ReadonlyContext | None = None diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index 066eedfb..65323324 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# pylint: disable=redefined-outer-name,g-import-not-at-top,protected-access + + from unittest import mock +from google.adk.code_executors.base_code_executor import BaseCodeExecutor +from google.adk.code_executors.code_execution_utils import CodeExecutionResult from google.adk.models import llm_request as llm_request_model from google.adk.skills import models from google.adk.tools import skill_toolset @@ -27,6 +32,7 @@ def mock_skill1_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill1" frontmatter.description = "Skill 1 description" + frontmatter.allowed_tools = ["test_tool"] frontmatter.model_dump.return_value = { "name": "skill1", "description": "Skill 1 description", @@ -43,7 +49,14 @@ def mock_skill1(mock_skill1_frontmatter): skill.instructions = "instructions for skill1" skill.frontmatter = mock_skill1_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -59,11 +72,22 @@ def mock_skill1(mock_skill1_frontmatter): def get_script(name): if name == "setup.sh": return models.Script(src="echo setup") + if name == "run.py": + return models.Script(src="print('hello')") + if name == "build.rb": + return models.Script(src="puts 'hello'") return None skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset skill.resources.get_script.side_effect = get_script + skill.resources.list_references.return_value = ["ref1.md"] + skill.resources.list_assets.return_value = ["asset1.txt"] + skill.resources.list_scripts.return_value = [ + "setup.sh", + "run.py", + "build.rb", + ] return skill @@ -73,6 +97,7 @@ def mock_skill2_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill2" frontmatter.description = "Skill 2 description" + 
frontmatter.allowed_tools = [] frontmatter.model_dump.return_value = { "name": "skill2", "description": "Skill 2 description", @@ -89,7 +114,14 @@ def mock_skill2(mock_skill2_frontmatter): skill.instructions = "instructions for skill2" skill.frontmatter = mock_skill2_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -104,6 +136,9 @@ def mock_skill2(mock_skill2_frontmatter): skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset + skill.resources.list_references.return_value = ["ref2.md"] + skill.resources.list_assets.return_value = ["asset2.txt"] + skill.resources.list_scripts.return_value = [] return skill @@ -132,13 +167,13 @@ def test_list_skills(mock_skill1, mock_skill2): async def test_get_tools(mock_skill1, mock_skill2): toolset = skill_toolset.SkillToolset([mock_skill1, mock_skill2]) tools = await toolset.get_tools() - assert len(tools) == 3 + assert len(tools) == 4 assert isinstance(tools[0], skill_toolset.ListSkillsTool) assert isinstance(tools[1], skill_toolset.LoadSkillTool) assert isinstance(tools[2], skill_toolset.LoadSkillResourceTool) + assert isinstance(tools[3], skill_toolset.RunSkillScriptTool) -@pytest.mark.asyncio @pytest.mark.asyncio async def test_list_skills_tool( mock_skill1, mock_skill2, tool_context_instance @@ -308,3 +343,854 @@ async def test_scripts_resource_not_found(mock_skill1, tool_context_instance): tool_context=tool_context_instance, ) assert result["error_code"] == "RESOURCE_NOT_FOUND" + + +# RunSkillScriptTool tests + + +def _make_tool_context_with_agent(agent=None): + """Creates a mock ToolContext with _invocation_context.agent.""" + ctx = mock.MagicMock(spec=tool_context.ToolContext) + ctx._invocation_context = mock.MagicMock() + ctx._invocation_context.agent = agent or mock.MagicMock() + 
return ctx + + +def _make_mock_executor(stdout="", stderr=""): + """Creates a mock code executor that returns the given output.""" + executor = mock.create_autospec(BaseCodeExecutor, instance=True) + executor.execute_code.return_value = CodeExecutionResult( + stdout=stdout, stderr=stderr + ) + return executor + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "args, expected_error_code", + [ + ( + {"script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1"}, + "MISSING_SCRIPT_PATH", + ), + ( + {"skill_name": "", "script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1", "script_path": ""}, + "MISSING_SCRIPT_PATH", + ), + ], +) +async def test_execute_script_missing_params( + mock_skill1, args, expected_error_code +): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async(args=args, tool_context=ctx) + assert result["error_code"] == expected_error_code + + +@pytest.mark.asyncio +async def test_execute_script_skill_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "nonexistent", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "SKILL_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_execute_script_script_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "nonexistent.py"}, + tool_context=ctx, + ) + assert 
result["error_code"] == "SCRIPT_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_execute_script_no_code_executor(mock_skill1): + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + # Agent without code_executor attribute + agent = mock.MagicMock(spec=[]) + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_agent_code_executor_none(mock_skill1): + """Agent has code_executor attr but it's None.""" + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = None + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_unsupported_type(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "build.rb"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +@pytest.mark.asyncio +async def test_execute_script_python_success(mock_skill1): + executor = _make_mock_executor(stdout="hello\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert 
result["stdout"] == "hello\n" + assert result["stderr"] == "" + assert result["skill_name"] == "skill1" + assert result["script_path"] == "run.py" + + # Verify the code passed to executor runs the python scripts + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "_materialize_and_run()" in code_input.code + assert "import runpy" in code_input.code + assert "sys.argv = ['scripts/run.py']" in code_input.code + assert ( + "runpy.run_path('scripts/run.py', run_name='__main__')" in code_input.code + ) + + +@pytest.mark.asyncio +async def test_execute_script_shell_success(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "setup\n" + + # Verify the code wraps in subprocess.run with JSON envelope + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "subprocess.run" in code_input.code + assert "bash" in code_input.code + assert "__shell_result__" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_with_input_args_python(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": {"verbose": True, "count": "3"}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert ( + "['scripts/run.py', '--verbose', 'True', '--count', '3']" + in code_input.code + ) 
+ + +@pytest.mark.asyncio +async def test_execute_script_with_input_args_shell(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "setup.sh", + "args": {"force": True}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "['bash', 'scripts/setup.sh', '--force', 'True']" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_scripts_prefix_stripping(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "scripts/setup.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["script_path"] == "scripts/setup.sh" + + +@pytest.mark.asyncio +async def test_execute_script_toolset_executor_priority(mock_skill1): + """Toolset-level executor takes priority over agent's.""" + toolset_executor = _make_mock_executor(stdout="from toolset\n") + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=toolset_executor + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from toolset\n" + toolset_executor.execute_code.assert_called_once() + 
agent_executor.execute_code.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_script_agent_executor_fallback(mock_skill1): + """Falls back to agent's code executor when toolset has none.""" + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from agent\n" + agent_executor.execute_code.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_script_execution_error(mock_skill1): + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("boom") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "boom" in result["error"] + assert result["error"].startswith("Failed to execute script 'run.py':") + + +@pytest.mark.asyncio +async def test_execute_script_stderr_only_sets_error_status(mock_skill1): + """stderr with no stdout should report error status.""" + executor = _make_mock_executor(stdout="", stderr="fatal error\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stderr"] == "fatal error\n" + + +@pytest.mark.asyncio +async def 
test_execute_script_stderr_with_stdout_sets_warning(mock_skill1): + """stderr alongside stdout should report warning status.""" + executor = _make_mock_executor(stdout="output\n", stderr="deprecation\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "output\n" + assert result["stderr"] == "deprecation\n" + + +@pytest.mark.asyncio +async def test_execute_script_execution_error_truncated(mock_skill1): + """Long exception messages are truncated to avoid wasting LLM tokens.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("x" * 300) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + # 200 chars of the message + "..." 
suffix + the prefix + assert result["error"].endswith("...") + assert len(result["error"]) < 300 + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_caught(mock_skill1): + """sys.exit() in a script should not terminate the process.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(1) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "exited with code 1" in result["error"] + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_zero_is_success(mock_skill1): + """sys.exit(0) is a normal termination and should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(0) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_none_is_success(mock_skill1): + """sys.exit() with no arg (None) should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(None) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_shell_includes_timeout(mock_skill1): + """Shell wrapper includes 
timeout in subprocess.run.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=executor, script_timeout=60 + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "timeout=60" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_extensionless_unsupported(mock_skill1): + """Files without extensions should return UNSUPPORTED_SCRIPT_TYPE.""" + # Add a script with no extension to the mock + original_side_effect = mock_skill1.resources.get_script.side_effect + + def get_script_extended(name): + if name == "noext": + return models.Script(src="print('hi')") + return original_side_effect(name) + + mock_skill1.resources.get_script.side_effect = get_script_extended + + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "noext"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +# ── Integration tests using real UnsafeLocalCodeExecutor ── + + +def _make_skill_with_script(skill_name, script_name, script): + """Creates a minimal mock Skill with a single script.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = skill_name + skill.description = f"Test skill {skill_name}" + skill.instructions = "test instructions" + fm = mock.create_autospec(models.Frontmatter, instance=True) + fm.name = skill_name + fm.description = f"Test skill {skill_name}" + skill.frontmatter = fm + skill.resources = mock.MagicMock( + spec=[ + 
"get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + + def get_script(name): + if name == script_name: + return script + return None + + skill.resources.get_script.side_effect = get_script + skill.resources.get_reference.return_value = None + skill.resources.get_asset.return_value = None + skill.resources.list_references.return_value = [] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = [script_name] + return skill + + +def _make_real_executor_toolset(skills, **kwargs): + """Creates a SkillToolset with a real UnsafeLocalCodeExecutor.""" + from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor + + executor = UnsafeLocalCodeExecutor() + return skill_toolset.SkillToolset(skills, code_executor=executor, **kwargs) + + +@pytest.mark.asyncio +async def test_integration_python_stdout(): + """Real executor: Python script stdout is captured.""" + script = models.Script(src="print('hello world')") + skill = _make_skill_with_script("test_skill", "hello.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "hello.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello world\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_integration_python_sys_exit_zero(): + """Real executor: sys.exit(0) is treated as success.""" + script = models.Script(src="import sys; sys.exit(0)") + skill = _make_skill_with_script("test_skill", "exit_zero.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": 
"exit_zero.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_integration_shell_stdout_and_stderr(): + """Real executor: shell script preserves both stdout and stderr.""" + script = models.Script(src="echo output; echo warning >&2") + skill = _make_skill_with_script("test_skill", "both.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "both.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert "output" in result["stdout"] + assert "warning" in result["stderr"] + + +@pytest.mark.asyncio +async def test_integration_shell_stderr_only(): + """Real executor: shell script with only stderr reports error.""" + script = models.Script(src="echo failure >&2") + skill = _make_skill_with_script("test_skill", "err.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "err.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "failure" in result["stderr"] + + +# ── Shell JSON envelope parsing (unit tests with mock executor) ── + + +@pytest.mark.asyncio +async def test_shell_json_envelope_parsed(mock_skill1): + """Shell JSON envelope is correctly unpacked by run_async.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "hello from shell\n", + "stderr": "", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", 
"script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello from shell\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_nonzero_returncode(mock_skill1): + """Non-zero returncode in shell envelope sets stderr.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "", + "stderr": "", + "returncode": 2, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "Exit code 2" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_json_envelope_with_stderr(mock_skill1): + """Shell envelope with both stdout and stderr reports warning.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "data\n", + "stderr": "deprecation warning\n", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "data\n" + assert result["stderr"] == "deprecation warning\n" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_timeout(mock_skill1): + """Shell envelope from TimeoutExpired reports error status.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "partial output\n", + "stderr": "Timed out after 300s", + "returncode": -1, + }) + executor = 
_make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stdout"] == "partial output\n" + assert "Timed out" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_non_json_stdout_passthrough(mock_skill1): + """Non-JSON shell stdout is passed through without parsing.""" + executor = _make_mock_executor(stdout="plain text output\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "plain text output\n" + + +# ── input_files packaging ── + + +@pytest.mark.asyncio +async def test_execute_script_input_files_packaged(mock_skill1): + """Verify references, assets, and scripts are packaged inside the wrapper code.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + + # input_files is no longer populated; it's serialized inside the script + assert code_input.input_files is None or len(code_input.input_files) == 0 + + # Ensure the extracted literal contains our fake files + assert "references/ref1.md" in code_input.code + assert "assets/asset1.txt" in code_input.code + assert "scripts/setup.sh" in 
code_input.code + assert "scripts/run.py" in code_input.code + assert "scripts/build.rb" in code_input.code + + # Verify content mappings exist in the string + assert "'references/ref1.md': 'ref content 1'" in code_input.code + assert "'assets/asset1.txt': 'asset content 1'" in code_input.code + + +# ── Integration: shell non-zero exit ── + + +@pytest.mark.asyncio +async def test_integration_shell_nonzero_exit(): + """Real executor: shell script with non-zero exit via JSON envelope.""" + script = models.Script(src="exit 42") + skill = _make_skill_with_script("test_skill", "fail.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "fail.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "42" in result["stderr"] + + +# ── Finding 1: system instruction references correct tool name ── + + +def test_system_instruction_references_run_skill_script(): + """System instruction must reference the actual tool name.""" + assert "run_skill_script" in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert ( + "execute_skill_script" + not in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + ) + + +# ── Finding 2: empty files are mounted (not silently dropped) ── + + +@pytest.mark.asyncio +async def test_execute_script_empty_files_mounted(): + """Verify empty files are included in wrapper code, not dropped.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_empty" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: "" if n == "empty.md" else None + ) + skill.resources.get_asset.side_effect = ( + lambda n: "" if n == "empty.cfg" else None + ) + 
skill.resources.get_script.side_effect = ( + lambda n: models.Script(src="") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["empty.md"] + skill.resources.list_assets.return_value = ["empty.cfg"] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_empty", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "'references/empty.md': ''" in code_input.code + assert "'assets/empty.cfg': ''" in code_input.code + assert "'scripts/run.py': ''" in code_input.code + + +# ── Finding 3: invalid args type returns clear error ── + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "bad_args", + [ + "not a dict", + ["a", "list"], + 42, + True, + ], +) +async def test_execute_script_invalid_args_type(mock_skill1, bad_args): + """Non-dict args should return INVALID_ARGS_TYPE, not crash.""" + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": bad_args, + }, + tool_context=ctx, + ) + assert result["error_code"] == "INVALID_ARGS_TYPE" + executor.execute_code.assert_not_called() + + +# ── Finding 4: binary file content is handled in wrapper ── + + +@pytest.mark.asyncio +async def test_execute_script_binary_content_packaged(): + """Verify binary asset content uses 'wb' mode in wrapper code.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_bin" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + 
"get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: b"\x00\x01\x02" if n == "data.bin" else None + ) + skill.resources.get_asset.return_value = None + skill.resources.get_script.side_effect = lambda n: ( + models.Script(src="print('ok')") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["data.bin"] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_bin", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + # Binary content should appear as bytes literal + assert "b'\\x00\\x01\\x02'" in code_input.code + # Wrapper code handles binary with 'wb' mode + assert "'wb' if isinstance(content, bytes)" in code_input.code From e59929e11a56aaee7bb0c45cd4c9d9fef689548c Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Feb 2026 06:22:56 -0800 Subject: [PATCH 039/102] fix: Propagate thought from A2A TextPart metadata to GenAI Part When converting an A2A TextPart to a GenAI Part, extract the 'thought' field from the TextPart's metadata and include it in the GenAI Part. 
PiperOrigin-RevId: 875129430 --- .../adk/a2a/converters/part_converter.py | 5 ++++- .../a2a/converters/test_part_converter.py | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 72dbcb21..7b501f75 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -61,7 +61,10 @@ def convert_a2a_part_to_genai_part( """Convert an A2A Part to a Google GenAI Part.""" part = a2a_part.root if isinstance(part, a2a_types.TextPart): - return genai_types.Part(text=part.text) + thought = None + if part.metadata: + thought = part.metadata.get(_get_adk_metadata_key('thought')) + return genai_types.Part(text=part.text, thought=thought) if isinstance(part, a2a_types.FilePart): if isinstance(part.file, a2a_types.FileWithUri): diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 647caa5b..ec611fba 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -289,7 +289,7 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root, a2a_types.TextPart) assert result.root.text == "Hello, world!" assert result.root.metadata is not None - assert result.root.metadata[_get_adk_metadata_key("thought")] == True + assert result.root.metadata[_get_adk_metadata_key("thought")] def test_convert_file_data_part(self): """Test conversion of GenAI file_data Part to A2A Part.""" @@ -516,6 +516,22 @@ class TestRoundTripConversions: assert isinstance(result_a2a_part.root, a2a_types.TextPart) assert result_a2a_part.root.text == original_text + def test_text_part_with_thought_round_trip(self): + """Test round-trip conversion for text parts with thought.""" + # Arrange + original_text = "Thinking..." 
+ genai_part = genai_types.Part(text=original_text, thought=True) + + # Act + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.text == original_text + assert result_genai_part.thought + def test_file_uri_round_trip(self): """Test round-trip conversion for file parts with URI.""" # Arrange From 3256a679da3e0fb6f18b26057e87f5284680cb58 Mon Sep 17 00:00:00 2001 From: Brian Fox <878612+onematchfox@users.noreply.github.com> Date: Wed, 25 Feb 2026 10:10:09 -0800 Subject: [PATCH 040/102] fix(tools): Handle JSON Schema boolean schemas in Gemini schema conversion Merge https://github.com/google/adk-python/pull/4531 **Problem:** JSON Schema allows `true` and `false` as valid boolean schemas, where `true` accepts any value and `false` rejects all values. Some MCP servers use this pattern for unconstrained fields. E.g. [mcp-grafana](https://github.com/grafana/mcp-grafana) - see [grafana-mcp-list-tools.json](https://github.com/user-attachments/files/25392430/grafana-mcp-list-tools.json) which was obtained from `tools/list` The schema sanitizer previously passed booleans through unchanged, causing a Pydantic ValidationError when `_ExtendedJSONSchema` tried to validate them as schema objects. ``` 1 validation error for _ExtendedJSONSchema properties.data.items.properties.model Input should be a valid dictionary or object to extract fields from [type=model_attributes_type, input_value=True, input_type=bool] For further information visit https://errors.pydantic.dev/2.12/v/model_attributes_type Traceback (most recent call last): ... 
File "/.foo/.venv/lib/python3.13/site-packages/google/adk/runners.py", line 561, in run_async async for event in agen: yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/runners.py", line 549, in _run_with_trace async for event in agen: yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/runners.py", line 778, in _exec_with_plugin async for event in agen: ...<64 lines>... yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/runners.py", line 538, in execute async for event in agen: yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/agents/base_agent.py", line 294, in run_async async for event in agen: yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/agents/llm_agent.py", line 468, in _run_async_impl async for event in agen: ...<5 lines>... should_pause = True File "/.foo/.venv/lib/python3.13/site-packages/google/adk/flows/llm_flows/base_llm_flow.py", line 427, in run_async async for event in agen: last_event = event yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/flows/llm_flows/base_llm_flow.py", line 446, in _run_one_step_async async for event in agen: yield event File "/.foo/.venv/lib/python3.13/site-packages/google/adk/flows/llm_flows/base_llm_flow.py", line 578, in _preprocess_async await tool.process_llm_request( tool_context=tool_context, llm_request=llm_request ) File "/.foo/.venv/lib/python3.13/site-packages/google/adk/tools/base_tool.py", line 129, in process_llm_request llm_request.append_tools([self]) ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^ File "/.foo/.venv/lib/python3.13/site-packages/google/adk/models/llm_request.py", line 255, in append_tools declaration = tool._get_declaration() File "/.foo/.venv/lib/python3.13/site-packages/google/adk/tools/mcp_tool/mcp_tool.py", line 200, in _get_declaration parameters = _to_gemini_schema(input_schema) File "/.foo/.venv/lib/python3.13/site-packages/google/adk/tools/_gemini_schema_util.py", line 218, in 
_to_gemini_schema json_schema=_ExtendedJSONSchema.model_validate(sanitized_schema), ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^ File "/.foo/.venv/lib/python3.13/site-packages/pydantic/main.py", line 716, in model_validate return cls.__pydantic_validator__.validate_python( ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^ obj, ^^^^ ...<5 lines>... by_name=by_name, ^^^^^^^^^^^^^^^^ ) ^ pydantic_core._pydantic_core.ValidationError: 1 validation error for _ExtendedJSONSchema properties.data.items.properties.model Input should be a valid dictionary or object to extract fields from [type=model_attributes_type, input_value=True, input_type=bool] For further information visit https://errors.pydantic.dev/2.12/v/model_attributes_type ``` **Solution:** Convert boolean schemas to `{"type": "object"}` as the closest approximation available in Gemini's schema model. Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4531 from onematchfox:fix-gemini-schema-bool 383ac0c0c3ab78d77be4503f5d6b9ad26c41b0db PiperOrigin-RevId: 875219362 --- src/google/adk/tools/_gemini_schema_util.py | 7 ++ .../tools/test_gemini_schema_util.py | 82 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/google/adk/tools/_gemini_schema_util.py b/src/google/adk/tools/_gemini_schema_util.py index 6a05f6c6..595b41a0 100644 --- a/src/google/adk/tools/_gemini_schema_util.py +++ b/src/google/adk/tools/_gemini_schema_util.py @@ -152,6 +152,13 @@ def _sanitize_schema_formats_for_gemini( ) for item in schema ] + # JSON Schema allows boolean schemas: `true` (accept any value) and `false` + # (reject all values). Gemini has no equivalent for either. `true` is + # approximated as an unconstrained object schema; `false` has no meaningful + # Gemini representation and is also mapped to an object schema as a safe + # fallback so that schema conversion does not crash. 
+ if isinstance(schema, bool): + return {"type": "object"} if not isinstance(schema, dict): return schema diff --git a/tests/unittests/tools/test_gemini_schema_util.py b/tests/unittests/tools/test_gemini_schema_util.py index d8445ab8..b7091903 100644 --- a/tests/unittests/tools/test_gemini_schema_util.py +++ b/tests/unittests/tools/test_gemini_schema_util.py @@ -648,6 +648,88 @@ class TestToGeminiSchema: assert gemini_schema.type == Type.OBJECT assert gemini_schema.properties is None + def test_to_gemini_schema_boolean_true_property(self): + """Tests that a JSON Schema boolean `true` property is handled. + + JSON Schema allows `true` as a schema meaning "accept any value". + Some MCP servers use this pattern for fields whose content is not + further constrained. + """ + openapi_schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "refId": {"type": "string"}, + "model": True, # JSON Schema boolean schema + }, + }, + } + }, + } + gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + items_schema = gemini_schema.properties["items"] + assert items_schema.type == Type.ARRAY + # `model: true` should be converted to an object schema + model_schema = items_schema.items.properties["model"] + assert model_schema.type == Type.OBJECT + + def test_to_gemini_schema_boolean_false_property(self): + """Tests that a JSON Schema boolean `false` property does not raise. + + `false` means "no value is valid" in JSON Schema, which has no Gemini + equivalent. Conversion falls back to an object schema to avoid crashing; + the result is semantically imprecise but safe. + """ + openapi_schema = { + "type": "object", + "properties": { + "anything": False, # JSON Schema boolean schema (reject all) + }, + } + # Should not raise even though `false` has no Gemini equivalent. 
+ gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + assert gemini_schema.properties["anything"] is not None + + def test_to_gemini_schema_boolean_true_in_array_items_properties(self): + """Regression test: boolean `true` schema inside array item properties. + + Some MCP servers use `"field": true` in an array item's properties to + indicate an unconstrained field, which is valid JSON Schema. + """ + openapi_schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "datasourceUid": {"type": "string"}, + "model": True, + "queryType": {"type": "string"}, + "refId": {"type": "string"}, + }, + }, + }, + }, + "required": ["title", "data"], + } + # Should not raise a ValidationError + gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + assert gemini_schema.type == Type.OBJECT + data_schema = gemini_schema.properties["data"] + assert data_schema.type == Type.ARRAY + model_schema = data_schema.items.properties["model"] + assert model_schema.type == Type.OBJECT + class TestToSnakeCase: From 35366f4e2a0575090fe12cd85f51e8116a1cd0d3 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Wed, 25 Feb 2026 12:37:01 -0800 Subject: [PATCH 041/102] feat: Warn when accessing DEFAULT_SKILL_SYSTEM_INSTRUCTION This change makes DEFAULT_SKILL_SYSTEM_INSTRUCTION raise a UserWarning when accessed, indicating that its content is experimental and subject to change in minor/patch releases. The constant is also made "private" internally. 
Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 875288217 --- src/google/adk/tools/skill_toolset.py | 18 ++++++++++++++++-- tests/unittests/tools/test_skill_toolset.py | 8 ++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index d13481eb..12411b41 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -24,6 +24,7 @@ import logging from typing import Any from typing import Optional from typing import TYPE_CHECKING +import warnings from google.genai import types @@ -46,7 +47,7 @@ logger = logging.getLogger("google_adk." + __name__) _DEFAULT_SCRIPT_TIMEOUT = 300 _MAX_SKILL_PAYLOAD_BYTES = 16 * 1024 * 1024 # 16 MB -DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. +_DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. @@ -638,6 +639,19 @@ class SkillToolset(BaseToolset): skills = self._list_skills() skills_xml = prompt.format_skills_as_xml(skills) instructions = [] - instructions.append(DEFAULT_SKILL_SYSTEM_INSTRUCTION) + instructions.append(_DEFAULT_SKILL_SYSTEM_INSTRUCTION) instructions.append(skills_xml) llm_request.append_instructions(instructions) + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + warnings.warn( + "DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental. 
Its content " + "is internal implementation and will change in minor/patch releases " + "to tune agent performance.", + UserWarning, + stacklevel=2, + ) + return _DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index 65323324..cbccecdb 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -327,6 +327,14 @@ async def test_process_llm_request( assert "skill2" in instructions[1] +def test_default_skill_system_instruction_warning(): + with pytest.warns( + UserWarning, match="DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental" + ): + instruction = skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert "specialized 'skills'" in instruction + + def test_duplicate_skill_name_raises(mock_skill1): skill_dup = mock.create_autospec(models.Skill, instance=True) skill_dup.name = "skill1" From d7cfd8fe4def2198c113ff1993ef39cd519908a1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Feb 2026 12:45:52 -0800 Subject: [PATCH 042/102] fix: Decode image data from ComputerUse tool response into image blobs PiperOrigin-RevId: 875292001 --- src/google/adk/flows/llm_flows/functions.py | 57 ++++++++++++++- .../flows/llm_flows/test_functions_simple.py | 70 ++++++++++++++++--- 2 files changed, 118 insertions(+), 9 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 4d045fac..c228eafb 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -17,6 +17,8 @@ from __future__ import annotations import asyncio +import base64 +import binascii from concurrent.futures import ThreadPoolExecutor import copy import functools @@ -31,6 +33,7 @@ from typing import Optional from typing import TYPE_CHECKING import uuid +from 
google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.genai import types from ...agents.active_streaming_tool import ActiveStreamingTool @@ -991,6 +994,50 @@ def _get_tool_and_context( return (tool, tool_context) +def _try_decode_computer_use_image( + tool: BaseTool, + function_result: dict[str, object], +) -> Optional[list[types.FunctionResponsePart]]: + """Decodes the image from the function result for a computer use tool. + + Args: + tool: The tool that produced the function result. + function_result: The dictionary containing the function's result. This + dictionary may be modified in-place to remove the 'image' key if an image + is successfully decoded. + + Returns: + A list containing a `types.FunctionResponsePart` with the decoded image + data, or None if no image was found or decoding failed. + """ + + if not isinstance(tool, ComputerUseTool) or not isinstance( + function_result, dict + ): + return None + + if ( + 'image' not in function_result + or 'data' not in function_result['image'] + or 'mimetype' not in function_result['image'] + ): + return None + + try: + image_data = base64.b64decode(function_result['image']['data']) + mime_type = function_result['image']['mimetype'] + + part = types.FunctionResponsePart.from_bytes( + data=image_data, mime_type=mime_type + ) + + del function_result['image'] + return [part] + except (binascii.Error, ValueError): + logger.exception('Failed to decode image from computer use tool') + return None + + async def __call_tool_live( tool: BaseTool, args: dict[str, object], @@ -1028,8 +1075,16 @@ def __build_response_event( if not isinstance(function_result, dict): function_result = {'result': function_result} + function_response_parts = None + if isinstance(tool, ComputerUseTool): + function_response_parts = _try_decode_computer_use_image( + tool, function_result + ) + part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result + name=tool.name, + 
response=function_result, + parts=function_response_parts, ) part_function_response.function_response.id = tool_context.function_call_id diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 93f8c151..7aacb237 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -19,7 +19,10 @@ from typing import Callable from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types @@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses(): @pytest.mark.asyncio async def test_function_call_args_not_modified(): """Test that function_call.args is not modified when making a copy.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -455,8 +456,6 @@ async def test_function_call_args_not_modified(): @pytest.mark.asyncio async def test_function_call_args_none_handling(): """Test that function_call.args=None is handled correctly.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -504,8 +503,6 @@ async def 
test_function_call_args_none_handling(): @pytest.mark.asyncio async def test_function_call_args_copy_behavior(): """Test that modifying the copied args doesn't affect the original.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(test_param: str, other_param: int) -> dict: # Modify the args to test that the copy prevents affecting the original @@ -565,8 +562,6 @@ async def test_function_call_args_copy_behavior(): @pytest.mark.asyncio async def test_function_call_args_deep_copy_behavior(): """Test that deep copy behavior works correctly with nested structures.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(nested_dict: dict, list_param: list) -> dict: # Modify the nested structures to test deep copy @@ -1141,3 +1136,62 @@ async def test_mixed_function_types_execution_order(): 'yield_E', 'yield_F', ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_computer_use_tool_decoding_behavior(handle_function_calls): + """Tests that computer use tools automatically decode base64 images.""" + valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7' + + # make the tool return a dictionary with the image + async def mock_run(*args, **kwargs): + return { + 'image': {'data': valid_b64, 'mimetype': 'image/png'}, + 'url': 'https://example.com', + } + + # create a ComputerUseTool + tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768)) + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create function call + 
function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result = await handle_function_calls( + invocation_context, + event, + tools_dict, + ) + + assert result is not None + response_part = result.content.parts[0].function_response + + # Verify original image data is removed from the dict response + assert 'image' not in response_part.response + assert 'url' in response_part.response + # Verify the image was converted to a blob + assert len(response_part.parts) == 1 + assert response_part.parts[0].inline_data is not None From 4460f4fadaffa667e796c0ec8299c65c68836203 Mon Sep 17 00:00:00 2001 From: Yifan Wang Date: Wed, 25 Feb 2026 13:47:11 -0800 Subject: [PATCH 043/102] chore: add /dev/build_graph/{app_name} to build the graph serialization for apps, and make it dev only with `with_ui` flag Co-authored-by: Yifan Wang PiperOrigin-RevId: 875319398 --- src/google/adk/cli/adk_web_server.py | 101 +++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 7 deletions(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 48587bd5..e032178e 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -588,7 +588,8 @@ class AdkWebServer: """Import a plugin object (class or instance) from a fully qualified name. Args: - qualified_name: Fully qualified name (e.g., 'my_package.my_plugin.MyPlugin') + qualified_name: Fully qualified name (e.g., + 'my_package.my_plugin.MyPlugin') Returns: The imported object, which can be either a class or an instance. 
@@ -688,6 +689,7 @@ class AdkWebServer: ] = lambda o, s: None, register_processors: Callable[[TracerProvider], None] = lambda o: None, otel_to_cloud: bool = False, + with_ui: bool = False, ): """Creates a FastAPI app for the ADK web server. @@ -700,7 +702,8 @@ class AdkWebServer: lifespan: The lifespan of the FastAPI app. allow_origins: The origins that are allowed to make cross-origin requests. Entries can be literal origins (e.g., 'https://example.com') or regex - patterns prefixed with 'regex:' (e.g., 'regex:https://.*\\.example\\.com'). + patterns prefixed with 'regex:' (e.g., + 'regex:https://.*\\.example\\.com'). web_assets_dir: The directory containing the web assets to serve. setup_observer: Callback for setting up the file system observer. tear_down_observer: Callback for cleaning up the file system observer. @@ -795,10 +798,93 @@ class AdkWebServer: raise HTTPException(status_code=404, detail="Trace not found") return event_dict - @app.get("/apps/{app_name}") - async def get_app_info(app_name: str) -> Any: - runner = await self.get_runner_async(app_name) - return runner.app + if web_assets_dir: + + @app.get("/dev/build_graph/{app_name}") + async def get_app_info(app_name: str) -> Any: + runner = await self.get_runner_async(app_name) + + if not runner.app: + raise HTTPException( + status_code=404, detail=f"App not found: {app_name}" + ) + + def serialize_agent(agent: BaseAgent) -> dict[str, Any]: + """Recursively serialize an agent, excluding non-serializable fields.""" + agent_dict = {} + + for field_name, field_info in agent.__class__.model_fields.items(): + # Skip non-serializable fields + if field_name in [ + "parent_agent", + "before_agent_callback", + "after_agent_callback", + "before_model_callback", + "after_model_callback", + "on_model_error_callback", + "before_tool_callback", + "after_tool_callback", + "on_tool_error_callback", + ]: + continue + + value = getattr(agent, field_name, None) + + # Handle sub_agents recursively + if field_name == 
"sub_agents" and value: + agent_dict[field_name] = [ + serialize_agent(sub_agent) for sub_agent in value + ] + elif value is None or field_name == "tools": + continue + else: + try: + if isinstance(value, (str, int, float, bool, list, dict)): + agent_dict[field_name] = value + elif hasattr(value, "model_dump"): + agent_dict[field_name] = value.model_dump( + mode="python", exclude_none=True + ) + else: + agent_dict[field_name] = str(value) + except Exception: + pass + + return agent_dict + + app_info = { + "name": runner.app.name, + "root_agent": serialize_agent(runner.app.root_agent), + } + + # Add optional fields if present + if runner.app.plugins: + app_info["plugins"] = [ + {"name": getattr(plugin, "name", type(plugin).__name__)} + for plugin in runner.app.plugins + ] + + if runner.app.context_cache_config: + try: + app_info["context_cache_config"] = ( + runner.app.context_cache_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + if runner.app.resumability_config: + try: + app_info["resumability_config"] = ( + runner.app.resumability_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + return app_info @app.get("/debug/trace/session/{session_id}", tags=[TAG_DEBUG]) async def get_session_trace(session_id: str) -> Any: @@ -1534,7 +1620,8 @@ class AdkWebServer: update_memory_request: The memory request for the update Raises: - HTTPException: If the memory service is not configured or the request is invalid. + HTTPException: If the memory service is not configured or the request + is invalid. 
""" if not self.memory_service: raise HTTPException( From 9730bc34d7d19bd1554896231566c316c8d61d72 Mon Sep 17 00:00:00 2001 From: Google Admin Date: Wed, 25 Feb 2026 14:09:43 -0800 Subject: [PATCH 044/102] Refactor Github Action per b/485167538 (#4535) Co-authored-by: Ben Knutson --- .github/workflows/check-file-contents.yml | 14 +++++++------- .github/workflows/isort.yml | 4 ++-- .github/workflows/pyink.yml | 4 ++-- .github/workflows/release-cherry-pick.yml | 6 ++++-- .github/workflows/release-finalize.yml | 4 +++- .github/workflows/release-publish.yml | 11 ++++++----- 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 7670733e..8e506d92 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -30,8 +30,8 @@ jobs: - name: Check for logger pattern in all changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -61,8 +61,8 @@ jobs: - name: Check for import pattern in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo 
"$CHANGED_FILES" @@ -88,8 +88,8 @@ jobs: - name: Check for import from cli package in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -110,4 +110,4 @@ jobs: fi else echo "✅ No relevant Python files found." - fi \ No newline at end of file + fi diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 49536911..840d4ea8 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -42,8 +42,8 @@ jobs: - name: Run isort on changed files id: run_isort run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index d2eac1da..a2d9e6d7 100644 --- a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -42,8 +42,8 @@ jobs: - name: Run pyink on changed files id: run_pyink run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff 
--diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/release-cherry-pick.yml b/.github/workflows/release-cherry-pick.yml index bf2247c4..ac5e5c08 100644 --- a/.github/workflows/release-cherry-pick.yml +++ b/.github/workflows/release-cherry-pick.yml @@ -30,8 +30,10 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${{ inputs.commit_sha }} to release/candidate" - git cherry-pick ${{ inputs.commit_sha }} + echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/candidate" + git cherry-pick ${INPUTS_COMMIT_SHA} + env: + INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} - name: Push changes run: | diff --git a/.github/workflows/release-finalize.yml b/.github/workflows/release-finalize.yml index ade58ec2..b9d6203f 100644 --- a/.github/workflows/release-finalize.yml +++ b/.github/workflows/release-finalize.yml @@ -68,9 +68,11 @@ jobs: - name: Rename release/candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | - VERSION="v${{ steps.version.outputs.version }}" + VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" git push origin "release/candidate:refs/heads/release/$VERSION" ":release/candidate" echo "Renamed release/candidate to release/$VERSION" + env: + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} - name: Update PR label to tagged if: steps.check.outputs.is_release_pr == 'true' diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml index 5979cd9c..95ee326a 100644 --- a/.github/workflows/release-publish.yml +++ b/.github/workflows/release-publish.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Validate branch run: | - if [[ ! "${{ github.ref_name }}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + if [[ ! 
"${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "Error: Must run from a release/v* branch (e.g., release/v0.3.0)" exit 1 fi @@ -23,7 +23,7 @@ jobs: - name: Extract version id: version run: | - VERSION="${{ github.ref_name }}" + VERSION="${GITHUB_REF_NAME}" VERSION="${VERSION#release/v}" echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Publishing version: $VERSION" @@ -51,9 +51,10 @@ jobs: - name: Create merge-back PR env: GH_TOKEN: ${{ secrets.RELEASE_PAT }} + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} run: | gh pr create \ --base main \ - --head "${{ github.ref_name }}" \ - --title "chore: merge release v${{ steps.version.outputs.version }} to main" \ - --body "Syncs version bump and CHANGELOG from release v${{ steps.version.outputs.version }} to main." + --head "${GITHUB_REF_NAME}" \ + --title "chore: merge release v${STEPS_VERSION_OUTPUTS_VERSION} to main" \ + --body "Syncs version bump and CHANGELOG from release v${STEPS_VERSION_OUTPUTS_VERSION} to main." 
From 5702a4b1f59b17fd8b290fc125c349240b0953d7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Feb 2026 16:11:25 -0800 Subject: [PATCH 045/102] feat: Add param support to Bigtable execute_sql PiperOrigin-RevId: 875382675 --- .github/workflows/check-file-contents.yml | 14 ++-- .github/workflows/isort.yml | 4 +- .github/workflows/pyink.yml | 4 +- .github/workflows/release-cherry-pick.yml | 6 +- .github/workflows/release-finalize.yml | 4 +- .github/workflows/release-publish.yml | 11 ++-- contributing/samples/bigtable/agent.py | 66 +++++++++++++++++-- src/google/adk/tools/bigtable/query_tool.py | 34 ++++++---- .../bigtable/test_bigtable_query_tool.py | 54 ++++++++++++++- 9 files changed, 154 insertions(+), 43 deletions(-) diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 8e506d92..7670733e 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -30,8 +30,8 @@ jobs: - name: Check for logger pattern in all changed Python files run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) + git fetch origin ${{ github.base_ref }} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -61,8 +61,8 @@ jobs: - name: Check for import pattern in certain changed Python files run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) + git fetch origin ${{ github.base_ref }} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E 
'__init__.py$|version.py$|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -88,8 +88,8 @@ jobs: - name: Check for import from cli package in certain changed Python files run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) + git fetch origin ${{ github.base_ref }} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -110,4 +110,4 @@ jobs: fi else echo "✅ No relevant Python files found." - fi + fi \ No newline at end of file diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 840d4ea8..49536911 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -42,8 +42,8 @@ jobs: - name: Run isort on changed files id: run_isort run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) + git fetch origin ${{ github.base_ref }} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index a2d9e6d7..d2eac1da 100644 --- a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -42,8 +42,8 @@ jobs: - name: Run pyink on changed files id: run_pyink run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR 
--name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) + git fetch origin ${{ github.base_ref }} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/release-cherry-pick.yml b/.github/workflows/release-cherry-pick.yml index ac5e5c08..bf2247c4 100644 --- a/.github/workflows/release-cherry-pick.yml +++ b/.github/workflows/release-cherry-pick.yml @@ -30,10 +30,8 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/candidate" - git cherry-pick ${INPUTS_COMMIT_SHA} - env: - INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} + echo "Cherry-picking ${{ inputs.commit_sha }} to release/candidate" + git cherry-pick ${{ inputs.commit_sha }} - name: Push changes run: | diff --git a/.github/workflows/release-finalize.yml b/.github/workflows/release-finalize.yml index b9d6203f..ade58ec2 100644 --- a/.github/workflows/release-finalize.yml +++ b/.github/workflows/release-finalize.yml @@ -68,11 +68,9 @@ jobs: - name: Rename release/candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | - VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" + VERSION="v${{ steps.version.outputs.version }}" git push origin "release/candidate:refs/heads/release/$VERSION" ":release/candidate" echo "Renamed release/candidate to release/$VERSION" - env: - STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} - name: Update PR label to tagged if: steps.check.outputs.is_release_pr == 'true' diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml index 95ee326a..5979cd9c 100644 --- a/.github/workflows/release-publish.yml +++ b/.github/workflows/release-publish.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Validate branch run: | - if [[ ! 
"${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + if [[ ! "${{ github.ref_name }}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "Error: Must run from a release/v* branch (e.g., release/v0.3.0)" exit 1 fi @@ -23,7 +23,7 @@ jobs: - name: Extract version id: version run: | - VERSION="${GITHUB_REF_NAME}" + VERSION="${{ github.ref_name }}" VERSION="${VERSION#release/v}" echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Publishing version: $VERSION" @@ -51,10 +51,9 @@ jobs: - name: Create merge-back PR env: GH_TOKEN: ${{ secrets.RELEASE_PAT }} - STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} run: | gh pr create \ --base main \ - --head "${GITHUB_REF_NAME}" \ - --title "chore: merge release v${STEPS_VERSION_OUTPUTS_VERSION} to main" \ - --body "Syncs version bump and CHANGELOG from release v${STEPS_VERSION_OUTPUTS_VERSION} to main." + --head "${{ github.ref_name }}" \ + --title "chore: merge release v${{ steps.version.outputs.version }} to main" \ + --body "Syncs version bump and CHANGELOG from release v${{ steps.version.outputs.version }} to main." 
diff --git a/contributing/samples/bigtable/agent.py b/contributing/samples/bigtable/agent.py index d35f51c1..6b4a50fc 100644 --- a/contributing/samples/bigtable/agent.py +++ b/contributing/samples/bigtable/agent.py @@ -16,14 +16,17 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.tools.bigtable import query_tool as bigtable_query_tool from google.adk.tools.bigtable.bigtable_credentials import BigtableCredentialsConfig from google.adk.tools.bigtable.bigtable_toolset import BigtableToolset from google.adk.tools.bigtable.settings import BigtableToolSettings +from google.adk.tools.google_tool import GoogleTool import google.auth +from google.cloud.bigtable.data.execute_query.metadata import SqlType -# Define an appropriate credential type -CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2 - +# Define an appropriate credential type. +# None for Application Default Credentials +CREDENTIALS_TYPE = None # Define Bigtable tool config with read capability set to allowed. tool_settings = BigtableToolSettings() @@ -59,6 +62,53 @@ bigtable_toolset = BigtableToolset( credentials_config=credentials_config, bigtable_tool_settings=tool_settings ) +_BIGTABLE_PROJECT_ID = "google.com:cloud-bigtable-dev" +_BIGTABLE_INSTANCE_ID = "annenguyen-bus-instance" + + +def search_hotels_by_location( + location_name: str, + credentials: google.auth.credentials.Credentials, + settings: BigtableToolSettings, + tool_context: google.adk.tools.tool_context.ToolContext, +): + """Search hotels by location name. + + This function takes a location name and returns a list of hotels + in that area. + + Args: + location_name (str): The geographical location (e.g., city or town) for the + hotel search. + Example: { "location_name": "Basel" } + + Returns: + The hotels name, price tier. 
+ """ + + sql_template = """ + SELECT + TO_INT64(cf['id']) as id, + CAST(cf['name'] AS STRING) AS name, + CAST(cf['location'] AS STRING) AS location, + CAST(cf['price_tier'] AS STRING) AS price_tier, + CAST(cf['checkin_date'] AS STRING) AS checkin_date, + CAST(cf['checkout_date'] AS STRING) AS checkout_date + FROM hotels + WHERE LOWER(CAST(cf['location'] AS STRING)) LIKE LOWER(CONCAT('%', @location_name, '%')) + """ + return bigtable_query_tool.execute_sql( + project_id=_BIGTABLE_PROJECT_ID, + instance_id=_BIGTABLE_INSTANCE_ID, + query=sql_template, + credentials=credentials, + settings=settings, + tool_context=tool_context, + parameters={"location": location_name}, + parameter_types={"location": SqlType.String()}, + ) + + # The variable name `root_agent` determines what your root agent is for the # debug CLI root_agent = LlmAgent( @@ -72,5 +122,13 @@ root_agent = LlmAgent( You are a data agent with access to several Bigtable tools. Make use of those tools to answer the user's questions. """, - tools=[bigtable_toolset], + tools=[ + bigtable_toolset, + # Or, uncomment to use customized Bigtable tools. + # GoogleTool( + # func=search_hotels_by_location, + # credentials_config=credentials_config, + # tool_settings=tool_settings, + # ), + ], ) diff --git a/src/google/adk/tools/bigtable/query_tool.py b/src/google/adk/tools/bigtable/query_tool.py index a7a785a2..63267f01 100644 --- a/src/google/adk/tools/bigtable/query_tool.py +++ b/src/google/adk/tools/bigtable/query_tool.py @@ -22,7 +22,6 @@ from typing import Dict from typing import List from google.auth.credentials import Credentials -from google.cloud import bigtable from . import client from ..tool_context import ToolContext @@ -40,6 +39,8 @@ def execute_sql( credentials: Credentials, settings: BigtableToolSettings, tool_context: ToolContext, + parameters: Dict[str, Any] | None = None, + parameter_types: Dict[str, Any] | None = None, ) -> dict: """Execute a GoogleSQL query from a Bigtable table. 
@@ -51,6 +52,10 @@ def execute_sql( credentials (Credentials): The credentials to use for the request. settings (BigtableToolSettings): The configuration for the tool. tool_context (ToolContext): The context for the tool. + parameters (dict): properties for parameter replacement. Keys must match + the names used in ``query``. + parameter_types (dict): maps explicit types for one or more param values. + Returns: dict: Dictionary containing the status and the rows read. If the result contains the key "result_is_likely_truncated" with @@ -59,18 +64,19 @@ def execute_sql( Examples: Fetch data or insights from a table: - - >>> execute_sql("my_project", "my_instance", - ... "SELECT * from mytable", credentials, config, tool_context) - { - "status": "SUCCESS", - "rows": [ - { - "user_id": 1, - "user_name": "Alice" - } - ] - } + + >>> execute_sql("my_project", "my_instance", + ... "SELECT * from mytable", credentials, config, tool_context) + { + "status": "SUCCESS", + "rows": [ + { + "user_id": 1, + "user_name": "Alice" + } + ] + } + """ del tool_context # Unused for now @@ -81,6 +87,8 @@ def execute_sql( eqi = bt_client.execute_query( query=query, instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, ) rows: List[Dict[str, Any]] = [] diff --git a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py index 0bd0fedc..abcef88e 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py @@ -62,7 +62,10 @@ def test_execute_sql_basic(): expected_rows = [{"col1": "val1", "col2": 123}] assert result == {"status": "SUCCESS", "rows": expected_rows} mock_client.execute_query.assert_called_once_with( - query=query, instance_id=instance_id + query=query, + instance_id=instance_id, + parameters=None, + parameter_types=None, ) mock_iterator.close.assert_called_once() @@ -106,7 +109,10 @@ def 
test_execute_sql_truncated(): "result_is_likely_truncated": True, } mock_client.execute_query.assert_called_once_with( - query=query, instance_id=instance_id + query=query, + instance_id=instance_id, + parameters=None, + parameter_types=None, ) mock_iterator.close.assert_called_once() @@ -169,3 +175,47 @@ def test_execute_sql_row_value_circular_reference_fallback(): assert result["status"] == "SUCCESS" assert result["rows"][0]["col1"] == str(circular_value) + + +def test_execute_sql_with_parameters(): + """Test execute_sql tool with parameters and parameter_types.""" + project = "my_project" + instance_id = "my_instance" + query = "SELECT * FROM my_table WHERE col1 = @param1" + credentials = mock.create_autospec(Credentials, instance=True) + tool_context = mock.create_autospec(ToolContext, instance=True) + parameters = {"param1": "val1"} + parameter_types = {"param1": "string"} + + with mock.patch( + "google.adk.tools.bigtable.client.get_bigtable_data_client" + ) as mock_get_client: + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) + mock_client.execute_query.return_value = mock_iterator + + # Mock row data + mock_row = mock.MagicMock() + mock_row.fields = {"col1": "val1"} + mock_iterator.__iter__.return_value = [mock_row] + + result = execute_sql( + project_id=project, + instance_id=instance_id, + credentials=credentials, + query=query, + settings=BigtableToolSettings(), + tool_context=tool_context, + parameters=parameters, + parameter_types=parameter_types, + ) + + assert result["status"] == "SUCCESS" + mock_client.execute_query.assert_called_once_with( + query=query, + instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, + ) + mock_iterator.close.assert_called_once() From de4dee899cd777a01ba15906f8496a72e717ea98 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 25 Feb 2026 17:17:58 -0800 Subject: [PATCH 046/102] fix: Re-export 
DEFAULT_SKILL_SYSTEM_INSTRUCTION to skills and skill/prompt.py to avoid breaking current users Co-authored-by: Kathy Wu PiperOrigin-RevId: 875407169 --- src/google/adk/skills/__init__.py | 22 ++++++++++++++++++++++ src/google/adk/skills/prompt.py | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/google/adk/skills/__init__.py b/src/google/adk/skills/__init__.py index 72bab7b6..86724bd0 100644 --- a/src/google/adk/skills/__init__.py +++ b/src/google/adk/skills/__init__.py @@ -14,6 +14,9 @@ """Agent Development Kit - Skills.""" +from typing import Any +import warnings + from ._utils import _load_skill_from_dir as load_skill_from_dir from .models import Frontmatter from .models import Resources @@ -21,9 +24,28 @@ from .models import Script from .models import Skill __all__ = [ + "DEFAULT_SKILL_SYSTEM_INSTRUCTION", "Frontmatter", "Resources", "Script", "Skill", "load_skill_from_dir", ] + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + + from ..tools import skill_toolset + + warnings.warn( + ( + "Importing DEFAULT_SKILL_SYSTEM_INSTRUCTION from" + " google.adk.skills is deprecated." + " Please import it from google.adk.tools.skill_toolset instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/google/adk/skills/prompt.py b/src/google/adk/skills/prompt.py index 110033cd..3c352036 100644 --- a/src/google/adk/skills/prompt.py +++ b/src/google/adk/skills/prompt.py @@ -17,8 +17,10 @@ from __future__ import annotations import html +from typing import Any from typing import List from typing import Union +import warnings from . 
import models @@ -54,3 +56,21 @@ def format_skills_as_xml( lines.append("") return "\n".join(lines) + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + + from ..tools import skill_toolset + + warnings.warn( + ( + "Importing DEFAULT_SKILL_SYSTEM_INSTRUCTION from" + " google.adk.skills.prompt is deprecated." + " Please import it from google.adk.tools.skill_toolset instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") From 5f806ed73a9631bad836278a39d37f7d62a31295 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 25 Feb 2026 18:26:49 -0800 Subject: [PATCH 047/102] chore: Refactor runner to infer invocation_id from FunctionResponse Event for HITL resuming invocation_id is no longer required in resuming case, unless no new_message is provided. Co-authored-by: Shangjie Chen PiperOrigin-RevId: 875432024 --- src/google/adk/agents/invocation_context.py | 15 ++- src/google/adk/flows/llm_flows/functions.py | 39 +++---- src/google/adk/runners.py | 101 ++++++++++++++---- .../runners/test_run_tool_confirmation.py | 80 ++++++++++++++ 4 files changed, 181 insertions(+), 54 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 4c75e1c4..35b8dc97 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,23 +396,20 @@ class InvocationContext(BaseModel): return False # TODO: Move this method from invocation_context to a dedicated module. - # TODO: Converge this method with find_matching_function_call in llm_flows. 
def _find_matching_function_call( self, function_response_event: Event ) -> Optional[Event]: """Finds the function call event in the current invocation that matches the function response id.""" + from ..flows.llm_flows.functions import find_event_by_function_call_id + function_responses = function_response_event.get_function_responses() if not function_responses: return None - function_call_id = function_responses[0].id - events = self._get_events(current_invocation=True) - # The last event is function_response_event, so we search backwards from the - # one before it. - for event in reversed(events[:-1]): - if any(fc.id == function_call_id for fc in event.get_function_calls()): - return event - return None + # Search backwards from the event before the current response event. + return find_event_by_function_call_id( + self._get_events(current_invocation=True)[:-1], function_responses[0].id + ) def new_invocation_context_id() -> str: diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index c228eafb..6082e1a7 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -37,7 +37,6 @@ from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.genai import types from ...agents.active_streaming_tool import ActiveStreamingTool -from ...agents.invocation_context import InvocationContext from ...agents.live_request_queue import LiveRequestQueue from ...auth.auth_tool import AuthConfig from ...auth.auth_tool import AuthToolArguments @@ -52,6 +51,7 @@ from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext from ...agents.llm_agent import LlmAgent AF_FUNCTION_CALL_ID_PREFIX = 'adk-' @@ -1157,6 +1157,18 @@ def merge_parallel_function_response_events( return merged_event +def find_event_by_function_call_id( + events: list[Event], + 
function_call_id: str, +) -> Optional[Event]: + """Finds the function call event that matches the function call id.""" + for event in reversed(events): + for function_call in event.get_function_calls(): + if function_call.id == function_call_id: + return event + return None + + def find_matching_function_call( events: list[Event], ) -> Optional[Event]: @@ -1165,25 +1177,8 @@ def find_matching_function_call( return None last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): + function_responses = last_event.get_function_responses() + if not function_responses: + return None - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long-running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None + return find_event_by_function_call_id(events[:-1], function_responses[0].id) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 736859fb..22011974 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -49,6 +49,7 @@ from .errors.session_not_found_error import SessionNotFoundError from .events.event import Event from .events.event import EventActions from .flows.llm_flows import contents +from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService @@ -70,6 +71,16 @@ def _is_tool_call_or_response(event: Event) -> bool: return bool(event.get_function_calls() or event.get_function_responses()) +def 
_get_function_responses_from_content( + content: types.Content, +) -> list[types.FunctionResponse]: + if not content: + return [] + return [ + part.function_response for part in content.parts if part.function_response + ] + + def _is_transcription(event: Event) -> bool: return ( event.input_transcription is not None @@ -341,6 +352,35 @@ class Runner: self._app_name_alignment_hint = f'{mismatch_details} {resolution}' logger.warning('App name mismatch detected. %s', mismatch_details) + def _resolve_invocation_id( + self, + session: Session, + new_message: Optional[types.Content], + invocation_id: Optional[str], + ) -> Optional[str]: + """Infers invocation_id from new_message if it is a function response.""" + function_responses = _get_function_responses_from_content(new_message) + if not function_responses: + return invocation_id + + fc_event = find_event_by_function_call_id( + session.events, function_responses[0].id + ) + if not fc_event: + raise ValueError( + 'Function call event not found for function response id:' + f' {function_responses[0].id}' + ) + + if invocation_id and invocation_id != fc_event.invocation_id: + logger.warning( + 'Provided invocation_id %s is ignored because new_message has a ' + 'function response with invocation_id %s.', + invocation_id, + fc_event.invocation_id, + ) + return fc_event.invocation_id + def _format_session_not_found_message(self, session_id: str) -> str: message = f'Session not found: {session_id}' if not self._app_name_alignment_hint: @@ -497,6 +537,7 @@ class Runner: session = await self._get_or_create_session( user_id=user_id, session_id=session_id ) + if not invocation_id and not new_message: raise ValueError( 'Running an agent requires either a new_message or an ' @@ -504,35 +545,49 @@ class Runner: f'Session: {session_id}, User: {user_id}' ) - if invocation_id: - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): - raise ValueError( - f'invocation_id: {invocation_id} is provided but 
the app is not' - ' resumable.' - ) - invocation_context = await self._setup_context_for_resumed_invocation( - session=session, - new_message=new_message, - invocation_id=invocation_id, - run_config=run_config, - state_delta=state_delta, + is_resumable = ( + self.resumability_config and self.resumability_config.is_resumable + ) + if not is_resumable and not new_message: + raise ValueError( + 'Running an agent requires a new_message or a resumable app. ' + f'Session: {session_id}, User: {user_id}' ) - if invocation_context.end_of_agents.get( - invocation_context.agent.name - ): - # Directly return if the current agent in invocation context is - # already final. - return - else: + + if not is_resumable: invocation_context = await self._setup_context_for_new_invocation( session=session, - new_message=new_message, # new_message is not None. + new_message=new_message, run_config=run_config, state_delta=state_delta, ) + else: + invocation_id = self._resolve_invocation_id( + session, new_message, invocation_id + ) + if not invocation_id: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, + run_config=run_config, + state_delta=state_delta, + ) + else: + invocation_context = ( + await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, + ) + ) + if invocation_context.end_of_agents.get( + invocation_context.agent.name + ): + # Directly return if the current agent in invocation context is + # already final. 
+ return async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_async(ctx)) as agen: diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index 08dfdd6f..6b12790d 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -502,6 +502,86 @@ class TestHITLConfirmationFlowWithResumableApp: == expected_parts_final ) + @pytest.mark.asyncio + async def test_pause_and_resume_on_request_confirmation_without_invocation_id( + self, + runner: testing_utils.InMemoryRunner, + agent: LlmAgent, + ): + """Tests HITL flow where all tool calls are confirmed.""" + events = runner.run("test user query") + + # Verify that the invocation is paused when tool confirmation is requested. + # The tool call returns error response, and summarization was skipped. + assert testing_utils.simplify_resumable_app_events( + copy.deepcopy(events) + ) == [ + ( + agent.name, + Part(function_call=FunctionCall(name=agent.tools[0].name, args={})), + ), + ( + agent.name, + Part( + function_call=FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": { + "name": agent.tools[0].name, + "id": mock.ANY, + "args": {}, + }, + "toolConfirmation": { + "hint": HINT_TEXT, + "confirmed": False, + }, + }, + ) + ), + ), + ( + agent.name, + Part( + function_response=FunctionResponse( + name=agent.tools[0].name, response=TOOL_CALL_ERROR_RESPONSE + ) + ), + ), + ] + ask_for_confirmation_function_call_id = ( + events[1].content.parts[0].function_call.id + ) + invocation_id = events[1].invocation_id + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events = await runner.run_async(user_confirmation) + expected_parts_final 
= [ + ( + agent.name, + Part( + function_response=FunctionResponse( + name=agent.tools[0].name, + response={"result": "confirmed=True"}, + ) + ), + ), + (agent.name, "test llm response after tool call"), + (agent.name, testing_utils.END_OF_AGENT), + ] + for event in events: + assert event.invocation_id == invocation_id + assert ( + testing_utils.simplify_resumable_app_events(copy.deepcopy(events)) + == expected_parts_final + ) + class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp: """Tests the HITL confirmation flow with a resumable sequential agent app.""" From 6f772d2b0841446bc168ccf405b59eb17c1d671a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 26 Feb 2026 00:22:26 -0800 Subject: [PATCH 048/102] feat: Introduce A2A request interceptors in RemoteA2aAgent This change adds a new `a2a` subpackage with configuration and utility functions for intercepting requests and responses in `RemoteA2aAgent`. The `RemoteA2aAgent` now accepts an `A2aRemoteAgentConfig` to register `RequestInterceptor` instances, allowing custom logic to be executed before and after the A2A message send. 
PiperOrigin-RevId: 875559286 --- src/google/adk/a2a/agent/__init__.py | 25 +++ src/google/adk/a2a/agent/config.py | 76 +++++++ src/google/adk/a2a/agent/utils.py | 70 ++++++ src/google/adk/agents/remote_a2a_agent.py | 36 ++- .../unittests/agents/test_remote_a2a_agent.py | 206 +++++++++++++++++- 5 files changed, 406 insertions(+), 7 deletions(-) create mode 100644 src/google/adk/a2a/agent/__init__.py create mode 100644 src/google/adk/a2a/agent/config.py create mode 100644 src/google/adk/a2a/agent/utils.py diff --git a/src/google/adk/a2a/agent/__init__.py b/src/google/adk/a2a/agent/__init__.py new file mode 100644 index 00000000..8026986e --- /dev/null +++ b/src/google/adk/a2a/agent/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A2A agents package.""" + +from .config import A2aRemoteAgentConfig +from .config import ParametersConfig +from .config import RequestInterceptor + +__all__ = [ + "A2aRemoteAgentConfig", + "ParametersConfig", + "RequestInterceptor", +] diff --git a/src/google/adk/a2a/agent/config.py b/src/google/adk/a2a/agent/config.py new file mode 100644 index 00000000..e8f012cf --- /dev/null +++ b/src/google/adk/a2a/agent/config.py @@ -0,0 +1,76 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for A2A agents.""" + +from __future__ import annotations + +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.client.middleware import ClientCallContext +from a2a.server.events import Event as A2AEvent +from a2a.types import Message as A2AMessage +from a2a.types import MessageSendConfiguration +from pydantic import BaseModel + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event + + +class ParametersConfig(BaseModel): + """Configuration for the parameters passed to the A2A send_message request.""" + + request_metadata: Optional[dict[str, Any]] = None + client_call_context: Optional[ClientCallContext] = None + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. + # + # requested_extension: Optional[list[str]] = None + # message_send_configuration: Optional[MessageSendConfiguration] = None + + +class RequestInterceptor(BaseModel): + """Interceptor for A2A requests.""" + + before_request: Optional[ + Callable[ + [InvocationContext, A2AMessage, ParametersConfig], + Awaitable[tuple[Union[A2AMessage, Event], ParametersConfig]], + ] + ] = None + """Hook executed before the agent starts processing the request. + + Returns an Event if the request should be aborted and the Event + returned to the caller. 
+ """ + + after_request: Optional[ + Callable[ + [InvocationContext, A2AEvent, Event], Awaitable[Union[Event, None]] + ] + ] = None + """Hook executed after the agent has processed the request. + + Returns None if the event should not be sent to the caller. + """ + + +class A2aRemoteAgentConfig(BaseModel): + """Configuration for the RemoteA2aAgent.""" + + request_interceptors: Optional[list[RequestInterceptor]] = None diff --git a/src/google/adk/a2a/agent/utils.py b/src/google/adk/a2a/agent/utils.py new file mode 100644 index 00000000..7cbb25eb --- /dev/null +++ b/src/google/adk/a2a/agent/utils.py @@ -0,0 +1,70 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for A2A agents.""" + +from __future__ import annotations + +from typing import Optional +from typing import Union + +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.middleware import ClientCallContext +from a2a.types import Message as A2AMessage + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from .config import ParametersConfig +from .config import RequestInterceptor + + +async def execute_before_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_request: A2AMessage, +) -> tuple[Union[A2AMessage, Event], ParametersConfig]: + """Executes registered before_request interceptors.""" + + params = ParametersConfig( + client_call_context=ClientCallContext(state=ctx.session.state) + ) + if request_interceptors: + for interceptor in request_interceptors: + if not interceptor.before_request: + continue + + result, params = await interceptor.before_request( + ctx, a2a_request, params + ) + if isinstance(result, Event): + return result, params + a2a_request = result + + return a2a_request, params + + +async def execute_after_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_response: A2AMessage | A2AClientEvent, + event: Event, +) -> Optional[Event]: + """Executes registered after_request interceptors.""" + if request_interceptors: + for interceptor in reversed(request_interceptors): + if interceptor.after_request: + event = await interceptor.after_request(ctx, a2a_response, event) + if not event: + return None + return event diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 2da7a4fa..5ffd123f 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -35,6 +35,7 @@ from a2a.client.errors import A2AClientHTTPError from a2a.client.middleware import ClientCallContext 
from a2a.types import AgentCard from a2a.types import Message as A2AMessage +from a2a.types import MessageSendConfiguration from a2a.types import Part as A2APart from a2a.types import Role from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent @@ -43,6 +44,7 @@ from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent from a2a.types import TransportProtocol as A2ATransport from google.genai import types as genai_types import httpx +from pydantic import BaseModel try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -50,6 +52,9 @@ except ImportError: # Fallback for older versions of a2a-sdk. AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" +from ..a2a.agent.config import A2aRemoteAgentConfig +from ..a2a.agent.utils import execute_after_request_interceptors +from ..a2a.agent.utils import execute_before_request_interceptors from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message @@ -127,6 +132,7 @@ class RemoteA2aAgent(BaseAgent): Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, full_history_when_stateless: bool = False, + config: Optional[A2aRemoteAgentConfig] = None, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -147,6 +153,7 @@ class RemoteA2aAgent(BaseAgent): return Tasks or context IDs) will receive all session events on every request. If False, the default behavior of sending only events since the last reply from the agent will be used. + config: Optional configuration object. 
**kwargs: Additional arguments passed to BaseAgent Raises: @@ -174,6 +181,7 @@ class RemoteA2aAgent(BaseAgent): self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider self._full_history_when_stateless = full_history_when_stateless + self._config = config or A2aRemoteAgentConfig() # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -558,14 +566,26 @@ class RemoteA2aAgent(BaseAgent): logger.debug(build_a2a_request_log(a2a_request)) try: - request_metadata = None - if self._a2a_request_meta_provider: - request_metadata = self._a2a_request_meta_provider(ctx, a2a_request) + a2a_request, parameters = await execute_before_request_interceptors( + self._config.request_interceptors, ctx, a2a_request + ) + if isinstance(a2a_request, Event): + yield a2a_request + return + + # Backward compatibility + if self._a2a_request_meta_provider: + parameters.request_metadata = self._a2a_request_meta_provider( + ctx, a2a_request + ) + + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. 
async for a2a_response in self._a2a_client.send_message( request=a2a_request, - request_metadata=request_metadata, - context=ClientCallContext(state=ctx.session.state), + request_metadata=parameters.request_metadata, + context=parameters.client_call_context, ): logger.debug(build_a2a_response_log(a2a_response)) @@ -573,6 +593,12 @@ class RemoteA2aAgent(BaseAgent): if not event: continue + event = await execute_after_request_interceptors( + self._config.request_interceptors, ctx, a2a_response, event + ) + if not event: + continue + # Add metadata about the request and response event.custom_metadata = event.custom_metadata or {} event.custom_metadata[A2A_METADATA_PREFIX + "request"] = ( diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 7643125d..fe155d30 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -21,7 +21,6 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.client.client import ClientConfig -from a2a.client.client import Consumer from a2a.client.client_factory import ClientFactory from a2a.client.middleware import ClientCallContext from a2a.types import AgentCapabilities @@ -29,13 +28,16 @@ from a2a.types import AgentCard from a2a.types import AgentSkill from a2a.types import Artifact from a2a.types import Message as A2AMessage -from a2a.types import SendMessageSuccessResponse from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus as A2ATaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.adk.a2a.agent import ParametersConfig +from google.adk.a2a.agent import RequestInterceptor +from google.adk.a2a.agent.utils import execute_after_request_interceptors +from google.adk.a2a.agent.utils import execute_before_request_interceptors from 
google.adk.agents.invocation_context import InvocationContext from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX from google.adk.agents.remote_a2a_agent import AgentCardResolutionError @@ -2432,3 +2434,203 @@ class TestRemoteA2aAgentIntegration: # Verify A2A client was called mock_a2a_client.send_message.assert_called_once() + + +class TestRemoteA2aAgentInterceptors: + + @pytest.fixture + def mock_context(self): + ctx = Mock(spec=InvocationContext) + ctx.session = Mock() + ctx.session.state = {"key": "value"} + return ctx + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_none(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + None, mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_empty(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + [], mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_success( + self, mock_context + ): + request = Mock(spec=A2AMessage) + new_request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + new_request, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is new_request + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_returns_event( + self, mock_context + ): + 
request = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + event, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.before_request = AsyncMock() + + result, params = await execute_before_request_interceptors( + [interceptor1, interceptor2], mock_context, request + ) + + assert result is event + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + interceptor2.before_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_no_before_request( + self, mock_context + ): + request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = None + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_none(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + None, mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_empty(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + [], mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_success(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + new_event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = 
AsyncMock(return_value=new_event) + + result = await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is new_event + interceptor1.after_request.assert_called_once_with( + mock_context, response, event + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_reverse_order( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + event1 = Mock(spec=Event) + event2 = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock(return_value=event1) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=event2) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is event1 + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_called_once_with( + mock_context, response, event2 + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_returns_none( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock() + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=None) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is None + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_no_after_request( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = None + + result = 
await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is event From 19718e9c174af7b1287b627e6b23a609db1ee5e2 Mon Sep 17 00:00:00 2001 From: Wiktoria Walczak Date: Thu, 26 Feb 2026 07:17:37 -0800 Subject: [PATCH 049/102] feat(otel): add experimental semantic convention and emit `gen_ai.client.inference.operation.details` event Co-authored-by: Wiktoria Walczak PiperOrigin-RevId: 875709959 --- .../adk/flows/llm_flows/base_llm_flow.py | 14 +- .../adk/telemetry/_experimental_semconv.py | 323 ++++++++++++++++++ src/google/adk/telemetry/tracing.py | 185 +++++++++- tests/unittests/telemetry/test_functional.py | 8 + tests/unittests/telemetry/test_spans.py | 218 +++++++++++- 5 files changed, 728 insertions(+), 20 deletions(-) create mode 100644 src/google/adk/telemetry/_experimental_semconv.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 424bb580..5368ca93 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -368,11 +368,17 @@ async def _run_and_handle_error( try: async with Aclosing(response_generator) as agen: - with tracing.use_generate_content_span( - llm_request, invocation_context, model_response_event - ) as span: + async with tracing.use_inference_span( + llm_request, + invocation_context, + model_response_event, + ) as gc_span: async for llm_response in agen: - tracing.trace_generate_content_result(span, llm_response) + if gc_span: + tracing.trace_inference_result( + gc_span, + llm_response, + ) yield llm_response except Exception as model_error: callback_context = CallbackContext( diff --git a/src/google/adk/telemetry/_experimental_semconv.py b/src/google/adk/telemetry/_experimental_semconv.py new file mode 100644 index 00000000..acec4437 --- /dev/null +++ b/src/google/adk/telemetry/_experimental_semconv.py @@ -0,0 +1,323 @@ +# Copyright 2026 Google LLC +# +# Licensed under 
the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Provides instrumentation for experimental semantic convention https://github.com/open-telemetry/semantic-conventions/blob/v1.39.0/docs/gen-ai/gen-ai-events.md.""" + +from __future__ import annotations + +from collections.abc import Mapping +from collections.abc import MutableMapping +import contextvars +import json +import os +from typing import Any +from typing import Literal +from typing import TypedDict + +from google.genai import types +from google.genai.models import t as transformers +from opentelemetry._logs import Logger +from opentelemetry._logs import LogRecord +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_INPUT_MESSAGES +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OUTPUT_MESSAGES +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_RESPONSE_FINISH_REASONS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM_INSTRUCTIONS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_INPUT_TOKENS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_OUTPUT_TOKENS +from opentelemetry.trace import Span +from opentelemetry.util.types import AttributeValue + +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse + +OTEL_SEMCONV_STABILITY_OPT_IN = 'OTEL_SEMCONV_STABILITY_OPT_IN' + 
+OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( + 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT' +) + + +class Text(TypedDict): + content: str + type: Literal['text'] + + +class Blob(TypedDict): + mime_type: str + data: bytes + type: Literal['blob'] + + +class FileData(TypedDict): + mime_type: str + uri: str + type: Literal['file_data'] + + +class ToolCall(TypedDict): + id: str | None + name: str + arguments: Any + type: Literal['tool_call'] + + +class ToolCallResponse(TypedDict): + id: str | None + response: Any + type: Literal['tool_call_response'] + + +Part = Text | Blob | FileData | ToolCall | ToolCallResponse + + +class InputMessage(TypedDict): + role: str + parts: list[Part] + + +class OutputMessage(TypedDict): + role: str + parts: list[Part] + finish_reason: str + + +def _safe_json_serialize_no_whitespaces(obj) -> str: + """Convert any Python object to a JSON-serializable type or string. + + Args: + obj: The object to serialize. + + Returns: + The JSON-serialized object string or if the object cannot be serialized. 
+ """ + + try: + # Try direct JSON serialization first + return json.dumps( + obj, + separators=(',', ':'), + ensure_ascii=False, + default=lambda o: '', + ) + except (TypeError, OverflowError): + return '' + + +def is_experimental_semconv() -> bool: + opt_ins = os.getenv(OTEL_SEMCONV_STABILITY_OPT_IN) + if not opt_ins: + return False + opt_ins_list = [s.strip() for s in opt_ins.split(',')] + return 'gen_ai_latest_experimental' in opt_ins_list + + +def get_content_capturing_mode() -> str: + return os.getenv( + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, '' + ).upper() + + +def _to_input_message( + content: types.Content, +) -> InputMessage: + parts = (_to_part(part, idx) for idx, part in enumerate(content.parts or [])) + return InputMessage( + role=_to_role(content.role), + parts=[part for part in parts if part is not None], + ) + + +def _to_output_message( + llm_response: LlmResponse, +) -> OutputMessage | None: + if not llm_response.content: + return None + + message = _to_input_message(llm_response.content) + return OutputMessage( + role=message['role'], + parts=message['parts'], + finish_reason=_to_finish_reason(llm_response.finish_reason), + ) + + +def _to_finish_reason( + finish_reason: types.FinishReason | None, +) -> str: + if finish_reason is None: + return '' + if ( + # Mapping unspecified and other to error, + # as JSON schema for finish_reason does not support them. 
+ finish_reason is types.FinishReason.FINISH_REASON_UNSPECIFIED + or finish_reason is types.FinishReason.OTHER + ): + return 'error' + if finish_reason is types.FinishReason.STOP: + return 'stop' + if finish_reason is types.FinishReason.MAX_TOKENS: + return 'length' + + return finish_reason.name.lower() + + +def _to_part(part: types.Part, idx: int) -> Part | None: + def tool_call_id_fallback(name: str | None) -> str: + if name: + return f'{name}_{idx}' + return f'{idx}' + + if part is None: + return None + + if (text := part.text) is not None: + return Text(content=text, type='text') + + if data := part.inline_data: + return Blob( + mime_type=data.mime_type or '', data=data.data or b'', type='blob' + ) + + if data := part.file_data: + return FileData( + mime_type=data.mime_type or '', + uri=data.file_uri or '', + type='file_data', + ) + + if call := part.function_call: + return ToolCall( + id=call.id or tool_call_id_fallback(call.name), + name=call.name or '', + arguments=call.args, + type='tool_call', + ) + + if response := part.function_response: + return ToolCallResponse( + id=response.id or tool_call_id_fallback(response.name), + response=response.response, + type='tool_call_response', + ) + + return None + + +def _to_role(role: str | None) -> str: + if role == 'user': + return 'user' + if role == 'model': + return 'assistant' + return '' + + +def _to_input_messages(contents: list[types.Content]) -> list[InputMessage]: + return [_to_input_message(content) for content in contents] + + +def _to_system_instructions( + config: types.GenerateContentConfig, +) -> list[Part]: + + if not config.system_instruction: + return [] + + transformed_contents = transformers.t_contents(config.system_instruction) + if not transformed_contents: + return [] + + sys_instr = transformed_contents[0] + + parts = ( + _to_part(part, idx) for idx, part in enumerate(sys_instr.parts or []) + ) + return [part for part in parts if part is not None] + + +def 
set_operation_details_common_attributes( + operation_details_common_attributes: MutableMapping[str, AttributeValue], + attributes: Mapping[str, AttributeValue], +): + operation_details_common_attributes.update(attributes) + + +async def set_operation_details_attributes_from_request( + operation_details_attributes: MutableMapping[str, AttributeValue], + llm_request: LlmRequest, +): + + input_messages = _to_input_messages( + transformers.t_contents(llm_request.contents) + ) + + system_instructions = _to_system_instructions(llm_request.config) + + operation_details_attributes[GEN_AI_INPUT_MESSAGES] = input_messages + operation_details_attributes[GEN_AI_SYSTEM_INSTRUCTIONS] = system_instructions + + +def set_operation_details_attributes_from_response( + llm_response: LlmResponse, + operation_details_attributes: MutableMapping[str, AttributeValue], + operation_details_common_attributes: MutableMapping[str, AttributeValue], +): + if finish_reason := llm_response.finish_reason: + operation_details_common_attributes[GEN_AI_RESPONSE_FINISH_REASONS] = [ + _to_finish_reason(finish_reason) + ] + if usage_metadata := llm_response.usage_metadata: + if usage_metadata.prompt_token_count is not None: + operation_details_common_attributes[GEN_AI_USAGE_INPUT_TOKENS] = ( + usage_metadata.prompt_token_count + ) + if usage_metadata.candidates_token_count is not None: + operation_details_common_attributes[GEN_AI_USAGE_OUTPUT_TOKENS] = ( + usage_metadata.candidates_token_count + ) + + output_message = _to_output_message(llm_response) + if output_message is not None: + operation_details_attributes[GEN_AI_OUTPUT_MESSAGES] = [output_message] + + +def maybe_log_completion_details( + span: Span | None, + otel_logger: Logger, + operation_details_attributes: Mapping[str, AttributeValue], + operation_details_common_attributes: Mapping[str, AttributeValue], +): + """Logs completion details based on the experimental semantic convention capturing mode.""" + if span is None: + return + + if not 
is_experimental_semconv(): + return + + capturing_mode = get_content_capturing_mode() + final_attributes = operation_details_common_attributes + + if capturing_mode in ['EVENT_ONLY', 'SPAN_AND_EVENT']: + final_attributes = final_attributes | operation_details_attributes + + otel_logger.emit( + LogRecord( + event_name='gen_ai.client.inference.operation.details', + attributes=final_attributes, + ) + ) + + if capturing_mode in ['SPAN_ONLY', 'SPAN_AND_EVENT']: + for key, value in operation_details_attributes.items(): + span.set_attribute(key, _safe_json_serialize_no_whitespaces(value)) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index fbb55ec9..707bc313 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -23,8 +23,11 @@ from __future__ import annotations +import asyncio +from collections.abc import AsyncIterator from collections.abc import Iterator from collections.abc import Mapping +from contextlib import asynccontextmanager from contextlib import contextmanager import json import logging @@ -58,9 +61,15 @@ from opentelemetry.trace import Span from opentelemetry.util.types import AnyValue from opentelemetry.util.types import AttributeValue from pydantic import BaseModel +from typing_extensions import deprecated from .. import version from ..utils.model_name_utils import is_gemini_model +from ._experimental_semconv import is_experimental_semconv +from ._experimental_semconv import maybe_log_completion_details +from ._experimental_semconv import set_operation_details_attributes_from_request +from ._experimental_semconv import set_operation_details_attributes_from_response +from ._experimental_semconv import set_operation_details_common_attributes # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. 
@@ -427,6 +436,7 @@ def _should_add_request_response_to_spans() -> bool: return not disabled_via_env_var +@deprecated('Replaced by use_inference_span to support experimental semconv.') @contextmanager def use_generate_content_span( llm_request: LlmRequest, @@ -453,11 +463,57 @@ def use_generate_content_span( with _use_extra_generate_content_attributes(common_attributes): yield else: - with _use_native_generate_content_span( + with _use_native_generate_content_span_stable_semconv( llm_request=llm_request, common_attributes=common_attributes, ) as span: - yield span + yield span.span + + +@asynccontextmanager +async def use_inference_span( + llm_request: LlmRequest, + invocation_context: InvocationContext, + model_response_event: Event, +) -> AsyncIterator[GenerateContentSpan | None]: + """Context manager encompassing `generate_content {model.name}` span. + + When an external library for inference instrumentation is installed (e.g. + opentelemetry-instrumentation-google-genai), + span creation is delegated to said library. 
+ """ + + common_attributes = { + GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_CONVERSATION_ID: invocation_context.session.id, + USER_ID: invocation_context.session.user_id, + 'gcp.vertex.agent.event_id': model_response_event.id, + 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + } + if ( + _is_gemini_agent(invocation_context.agent) + and _instrumented_with_opentelemetry_instrumentation_google_genai() + ): + with _use_extra_generate_content_attributes(common_attributes): + yield + else: + async with _use_native_generate_content_span( + llm_request=llm_request, + common_attributes=common_attributes, + ) as gc_span: + if is_experimental_semconv(): + set_operation_details_common_attributes( + gc_span.operation_details_common_attributes, common_attributes + ) + try: + yield gc_span + finally: + maybe_log_completion_details( + gc_span.span, + otel_logger, + gc_span.operation_details_attributes, + gc_span.operation_details_common_attributes, + ) def _should_log_prompt_response_content() -> bool: @@ -467,6 +523,8 @@ def _should_log_prompt_response_content() -> bool: def _serialize_content(content: types.ContentUnion) -> AnyValue: + if content is None: + return None if isinstance(content, BaseModel): return content.model_dump() if isinstance(content, str): @@ -540,18 +598,29 @@ def _is_gemini_agent(agent: BaseAgent) -> bool: return isinstance(agent.model, Gemini) -@contextmanager -def _use_native_generate_content_span( +def _set_common_generate_content_attributes( + span: Span, llm_request: LlmRequest, common_attributes: Mapping[str, AttributeValue], -) -> Iterator[Span]: +): + span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') + span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') + span.set_attributes(common_attributes) + + +@contextmanager +def _use_native_generate_content_span_stable_semconv( + llm_request: LlmRequest, + common_attributes: Mapping[str, AttributeValue], +) -> Iterator[GenerateContentSpan]: with 
tracer.start_as_current_span( f"generate_content {llm_request.model or ''}" ) as span: span.set_attribute(GEN_AI_SYSTEM, _guess_gemini_system_name()) - span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') - span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') - span.set_attributes(common_attributes) + _set_common_generate_content_attributes( + span, llm_request, common_attributes + ) + gc_span = GenerateContentSpan(span) otel_logger.emit( LogRecord( @@ -564,7 +633,6 @@ def _use_native_generate_content_span( attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, ) ) - for content in llm_request.contents: otel_logger.emit( LogRecord( @@ -574,9 +642,51 @@ def _use_native_generate_content_span( ) ) - yield span + yield gc_span +@asynccontextmanager +async def _use_native_generate_content_span( + llm_request: LlmRequest, + common_attributes: Mapping[str, AttributeValue], +) -> AsyncIterator[GenerateContentSpan]: + if not is_experimental_semconv(): + with _use_native_generate_content_span_stable_semconv( + llm_request, common_attributes + ) as gc_span: + yield gc_span + return + + with tracer.start_as_current_span( + f"generate_content {llm_request.model or ''}" + ) as span: + + _set_common_generate_content_attributes( + span, llm_request, common_attributes + ) + gc_span = GenerateContentSpan(span) + + await set_operation_details_attributes_from_request( + gc_span.operation_details_attributes, llm_request + ) + yield gc_span + + +class GenerateContentSpan: + """Manages tracing within a `generate_content` OpenTelemetry span. + + This class provides attributes for the experimental semantic convention. + """ + + def __init__(self, span: Span): + self.span = span + self.operation_details_attributes = {} + self.operation_details_common_attributes = {} + + +@deprecated( + 'Replaced by trace_inference_result to support experimental semconv.' 
+) def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): """Trace result of the inference in generate_content span.""" @@ -613,6 +723,61 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): ) +def trace_inference_result( + span: Span | None | GenerateContentSpan, + llm_response: LlmResponse, +): + """Trace result of the inference in generate_content span.""" + gc_span = None + if isinstance(span, GenerateContentSpan): + gc_span = span + span = gc_span.span + + if span is None: + return + + if llm_response.partial: + return + + if finish_reason := llm_response.finish_reason: + span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [finish_reason.lower()]) + if usage_metadata := llm_response.usage_metadata: + if usage_metadata.prompt_token_count is not None: + span.set_attribute( + GEN_AI_USAGE_INPUT_TOKENS, usage_metadata.prompt_token_count + ) + if usage_metadata.candidates_token_count is not None: + span.set_attribute( + GEN_AI_USAGE_OUTPUT_TOKENS, usage_metadata.candidates_token_count + ) + + if is_experimental_semconv() and isinstance(gc_span, GenerateContentSpan): + set_operation_details_attributes_from_response( + llm_response, + gc_span.operation_details_attributes, + gc_span.operation_details_common_attributes, + ) + + else: + otel_logger.emit( + LogRecord( + event_name='gen_ai.choice', + body={ + 'content': _serialize_content_with_elision( + llm_response.content + ), + 'index': 0, # ADK always returns a single candidate + } + | ( + {'finish_reason': llm_response.finish_reason.value} + if llm_response.finish_reason is not None + else {} + ), + attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, + ) + ) + + def _guess_gemini_system_name() -> str: return ( GenAiSystemValues.VERTEX_AI.name.lower() diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index f7d4b0a3..3b7d93c4 100644 --- a/tests/unittests/telemetry/test_functional.py +++ 
b/tests/unittests/telemetry/test_functional.py @@ -97,6 +97,14 @@ async def test_tracer_start_as_current_span( def wrapped_firstiter(coro): nonlocal firstiter + # Skip check for specific async context managers in tracing.py, + # as their internal generators are not expected to be Aclosing-wrapped. + if ( + coro.__name__ == 'use_inference_span' + or coro.__name__ == '_use_native_generate_content_span' + ): + firstiter(coro) + return assert any( isinstance(referrer, Aclosing) or isinstance(indirect_referrer, Aclosing) diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index bb084676..793c0bb3 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -26,20 +26,23 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.telemetry.tracing import ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS from google.adk.telemetry.tracing import trace_agent_invocation from google.adk.telemetry.tracing import trace_call_llm -from google.adk.telemetry.tracing import trace_generate_content_result +from google.adk.telemetry.tracing import trace_inference_result from google.adk.telemetry.tracing import trace_merged_tool_calls from google.adk.telemetry.tracing import trace_send_data from google.adk.telemetry.tracing import trace_tool_call -from google.adk.telemetry.tracing import use_generate_content_span +from google.adk.telemetry.tracing import use_inference_span from google.adk.tools.base_tool import BaseTool from google.genai import types from opentelemetry._logs import LogRecord from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_AGENT_NAME from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_CONVERSATION_ID +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_INPUT_MESSAGES from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OPERATION_NAME +from 
opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OUTPUT_MESSAGES from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_REQUEST_MODEL from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_RESPONSE_FINISH_REASONS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM_INSTRUCTIONS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_INPUT_TOKENS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_OUTPUT_TOKENS from opentelemetry.semconv._incubating.attributes.user_attributes import USER_ID @@ -731,12 +734,12 @@ async def test_generate_content_span( ) # Act - with use_generate_content_span( + async with use_inference_span( llm_request, invocation_context, model_response_event - ) as span: - assert span is mock_span + ) as gc_span: + assert gc_span.span is mock_span - trace_generate_content_result(span, llm_response) + trace_inference_result(gc_span, llm_response) # Assert Span mock_tracer.start_as_current_span.assert_called_once_with( @@ -810,3 +813,206 @@ async def test_generate_content_span( assert choice_log is not None assert choice_log.body == expected_choice_body assert choice_log.attributes == {GEN_AI_SYSTEM: 'test_system'} + + +@pytest.mark.asyncio +@mock.patch('google.adk.telemetry.tracing.otel_logger') +@mock.patch('google.adk.telemetry.tracing.tracer') +@mock.patch( + 'google.adk.telemetry.tracing._guess_gemini_system_name', + return_value='test_system', +) +@pytest.mark.parametrize( + 'capture_content', + ['SPAN_AND_EVENT', 'EVENT_ONLY', 'SPAN_ONLY', 'NO_CONTENT'], +) +async def test_generate_content_span_with_experimental_semconv( + mock_guess_system_name, + mock_tracer, + mock_otel_logger, + monkeypatch, + capture_content, +): + """Test native generate_content span creation with 
attributes and logs with experimental semconv enabled.""" + # Arrange + monkeypatch.setenv( + 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT', + str(capture_content).lower(), + ) + monkeypatch.setenv( + 'OTEL_SEMCONV_STABILITY_OPT_IN', + 'gen_ai_latest_experimental', + ) + monkeypatch.setattr( + 'google.adk.telemetry.tracing._instrumented_with_opentelemetry_instrumentation_google_genai', + lambda: False, + ) + + agent = LlmAgent(name='test_agent', model='not-a-gemini-model') + invocation_context = await _create_invocation_context(agent) + + system_instruction = types.Content( + parts=[types.Part.from_text(text='You are a helpful assistant.')], + ) + + user_content1 = types.Content(role='user', parts=[types.Part(text='Hello')]) + user_content2 = types.Content(role='user', parts=[types.Part(text='World')]) + + model_content = types.Content( + role='model', parts=[types.Part(text='Response')] + ) + + llm_request = LlmRequest( + model='some-model', + contents=[user_content1, user_content2], + config=types.GenerateContentConfig( + system_instruction=system_instruction, + ), + ) + llm_response = LlmResponse( + content=model_content, + finish_reason=types.FinishReason.STOP, + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=20, + ), + ) + + model_response_event = mock.MagicMock() + model_response_event.id = 'event-123' + + mock_span = ( + mock_tracer.start_as_current_span.return_value.__enter__.return_value + ) + + # Act + async with use_inference_span( + llm_request, + invocation_context, + model_response_event, + ) as gc_span: + assert gc_span.span is mock_span + + trace_inference_result(gc_span, llm_response) + + # Expected attributes + expected_system_instructions = [ + { + 'content': 'You are a helpful assistant.', + 'type': 'text', + }, + ] + expected_input_messages = [ + { + 'role': 'user', + 'parts': [ + {'content': 'Hello', 'type': 'text'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'content': 
'World', 'type': 'text'}, + ], + }, + ] + expected_output_messages = [{ + 'role': 'assistant', + 'parts': [ + {'content': 'Response', 'type': 'text'}, + ], + 'finish_reason': 'stop', + }] + # Assert Span + mock_tracer.start_as_current_span.assert_called_once_with( + 'generate_content some-model' + ) + + mock_span.set_attribute.assert_any_call( + GEN_AI_OPERATION_NAME, 'generate_content' + ) + mock_span.set_attribute.assert_any_call(GEN_AI_REQUEST_MODEL, 'some-model') + mock_span.set_attribute.assert_any_call( + GEN_AI_RESPONSE_FINISH_REASONS, ['stop'] + ) + mock_span.set_attribute.assert_any_call(GEN_AI_USAGE_INPUT_TOKENS, 10) + mock_span.set_attribute.assert_any_call(GEN_AI_USAGE_OUTPUT_TOKENS, 20) + + mock_span.set_attributes.assert_called_once_with({ + GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_CONVERSATION_ID: invocation_context.session.id, + USER_ID: invocation_context.session.user_id, + 'gcp.vertex.agent.event_id': 'event-123', + 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + }) + + if capture_content in ['SPAN_AND_EVENT', 'SPAN_ONLY']: + mock_span.set_attribute.assert_any_call( + GEN_AI_SYSTEM_INSTRUCTIONS, + '[{"content":"You are a helpful assistant.","type":"text"}]', + ) + mock_span.set_attribute.assert_any_call( + GEN_AI_INPUT_MESSAGES, + '[{"role":"user","parts":[{"content":"Hello","type":"text"}]},{"role":"user","parts":[{"content":"World","type":"text"}]}]', + ) + mock_span.set_attribute.assert_any_call( + GEN_AI_OUTPUT_MESSAGES, + '[{"role":"assistant","parts":[{"content":"Response","type":"text"}],"finish_reason":"stop"}]', + ) + + else: + all_attribute_calls = mock_span.set_attribute.call_args_list + assert GEN_AI_SYSTEM_INSTRUCTIONS not in all_attribute_calls + assert GEN_AI_INPUT_MESSAGES not in all_attribute_calls + assert GEN_AI_OUTPUT_MESSAGES not in all_attribute_calls + + # Assert Logs + assert mock_otel_logger.emit.call_count == 1 + + log_records: list[LogRecord] = [ + call.args[0] for call in 
mock_otel_logger.emit.call_args_list + ] + + operation_details_log = next( + ( + lr + for lr in log_records + if lr.event_name == 'gen_ai.client.inference.operation.details' + ), + None, + ) + + assert operation_details_log is not None + assert operation_details_log.attributes is not None + + attributes = operation_details_log.attributes + + if capture_content in ['SPAN_AND_EVENT', 'EVENT_ONLY']: + assert GEN_AI_SYSTEM_INSTRUCTIONS in attributes + assert ( + attributes[GEN_AI_SYSTEM_INSTRUCTIONS] == expected_system_instructions + ) + assert GEN_AI_INPUT_MESSAGES in attributes + assert attributes[GEN_AI_INPUT_MESSAGES] == expected_input_messages + assert GEN_AI_OUTPUT_MESSAGES in attributes + assert attributes[GEN_AI_OUTPUT_MESSAGES] == expected_output_messages + else: + assert GEN_AI_SYSTEM_INSTRUCTIONS not in attributes + assert GEN_AI_INPUT_MESSAGES not in attributes + assert GEN_AI_OUTPUT_MESSAGES not in attributes + + assert GEN_AI_USAGE_INPUT_TOKENS in attributes + assert attributes[GEN_AI_USAGE_INPUT_TOKENS] == 10 + assert GEN_AI_USAGE_OUTPUT_TOKENS in attributes + assert attributes[GEN_AI_USAGE_OUTPUT_TOKENS] == 20 + assert 'gcp.vertex.agent.event_id' in attributes + assert attributes['gcp.vertex.agent.event_id'] == 'event-123' + assert 'gcp.vertex.agent.invocation_id' in attributes + assert ( + attributes['gcp.vertex.agent.invocation_id'] + == invocation_context.invocation_id + ) + assert GEN_AI_AGENT_NAME in attributes + assert attributes[GEN_AI_AGENT_NAME] == invocation_context.agent.name + assert GEN_AI_CONVERSATION_ID in attributes + assert attributes[GEN_AI_CONVERSATION_ID] == invocation_context.session.id From b38b708e23220e2dfa2932bc46fbc6c49a2ad275 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 26 Feb 2026 08:05:48 -0800 Subject: [PATCH 050/102] chore: Update release-please to always bump minor Co-authored-by: Sasha Sobran PiperOrigin-RevId: 875727292 --- .github/release-please-config.json | 1 + 1 file changed, 1 insertion(+) diff --git 
a/.github/release-please-config.json b/.github/release-please-config.json index e7ecc230..5395e5a4 100644 --- a/.github/release-please-config.json +++ b/.github/release-please-config.json @@ -4,6 +4,7 @@ "packages": { ".": { "release-type": "python", + "versioning": "always-bump-minor", "package-name": "google-adk", "include-component-in-tag": false, "skip-github-release": true, From 7a813b0987de37d1cd834a0630ee4695091431a8 Mon Sep 17 00:00:00 2001 From: Google Admin Date: Thu, 26 Feb 2026 08:05:49 -0800 Subject: [PATCH 051/102] chore: refactor Github Action Co-authored-by: Sasha Sobran COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4535 from google:lsc-1771431513.0644026 200d04544fafa4c555f2031f66ef6a3c2ff8a774 PiperOrigin-RevId: 875727295 --- .github/workflows/check-file-contents.yml | 14 +++++++------- .github/workflows/isort.yml | 4 ++-- .github/workflows/pyink.yml | 4 ++-- .github/workflows/release-cherry-pick.yml | 6 ++++-- .github/workflows/release-finalize.yml | 4 +++- .github/workflows/release-publish.yml | 11 ++++++----- 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 7670733e..8e506d92 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -30,8 +30,8 @@ jobs: - name: Check for logger pattern in all changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -61,8 +61,8 @@ jobs: - name: Check for import pattern in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - 
CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -88,8 +88,8 @@ jobs: - name: Check for import from cli package in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -110,4 +110,4 @@ jobs: fi else echo "✅ No relevant Python files found." 
- fi \ No newline at end of file + fi diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 49536911..840d4ea8 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -42,8 +42,8 @@ jobs: - name: Run isort on changed files id: run_isort run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index d2eac1da..a2d9e6d7 100644 --- a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -42,8 +42,8 @@ jobs: - name: Run pyink on changed files id: run_pyink run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/release-cherry-pick.yml b/.github/workflows/release-cherry-pick.yml index bf2247c4..ac5e5c08 100644 --- a/.github/workflows/release-cherry-pick.yml +++ b/.github/workflows/release-cherry-pick.yml @@ -30,8 +30,10 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${{ inputs.commit_sha }} to release/candidate" - git cherry-pick ${{ inputs.commit_sha }} + echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/candidate" + git cherry-pick ${INPUTS_COMMIT_SHA} + env: + INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} - name: Push changes run: | diff --git a/.github/workflows/release-finalize.yml 
b/.github/workflows/release-finalize.yml index ade58ec2..b9d6203f 100644 --- a/.github/workflows/release-finalize.yml +++ b/.github/workflows/release-finalize.yml @@ -68,9 +68,11 @@ jobs: - name: Rename release/candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | - VERSION="v${{ steps.version.outputs.version }}" + VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" git push origin "release/candidate:refs/heads/release/$VERSION" ":release/candidate" echo "Renamed release/candidate to release/$VERSION" + env: + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} - name: Update PR label to tagged if: steps.check.outputs.is_release_pr == 'true' diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml index 5979cd9c..95ee326a 100644 --- a/.github/workflows/release-publish.yml +++ b/.github/workflows/release-publish.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Validate branch run: | - if [[ ! "${{ github.ref_name }}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + if [[ ! "${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "Error: Must run from a release/v* branch (e.g., release/v0.3.0)" exit 1 fi @@ -23,7 +23,7 @@ jobs: - name: Extract version id: version run: | - VERSION="${{ github.ref_name }}" + VERSION="${GITHUB_REF_NAME}" VERSION="${VERSION#release/v}" echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Publishing version: $VERSION" @@ -51,9 +51,10 @@ jobs: - name: Create merge-back PR env: GH_TOKEN: ${{ secrets.RELEASE_PAT }} + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} run: | gh pr create \ --base main \ - --head "${{ github.ref_name }}" \ - --title "chore: merge release v${{ steps.version.outputs.version }} to main" \ - --body "Syncs version bump and CHANGELOG from release v${{ steps.version.outputs.version }} to main." 
+ --head "${GITHUB_REF_NAME}" \ + --title "chore: merge release v${STEPS_VERSION_OUTPUTS_VERSION} to main" \ + --body "Syncs version bump and CHANGELOG from release v${STEPS_VERSION_OUTPUTS_VERSION} to main." From 4dd4d5ecb6a1dadbc41389dac208616f6d21bc6e Mon Sep 17 00:00:00 2001 From: Wiktoria Walczak Date: Thu, 26 Feb 2026 08:42:26 -0800 Subject: [PATCH 052/102] feat(otel): add `gen_ai.tool.definitions` to experimental semconv Co-authored-by: Wiktoria Walczak PiperOrigin-RevId: 875741416 --- .../adk/telemetry/_experimental_semconv.py | 195 ++++++++++++++++++ tests/unittests/telemetry/test_spans.py | 163 ++++++++++++++- 2 files changed, 356 insertions(+), 2 deletions(-) diff --git a/src/google/adk/telemetry/_experimental_semconv.py b/src/google/adk/telemetry/_experimental_semconv.py index acec4437..dbfb3f14 100644 --- a/src/google/adk/telemetry/_experimental_semconv.py +++ b/src/google/adk/telemetry/_experimental_semconv.py @@ -28,6 +28,8 @@ from typing import TypedDict from google.genai import types from google.genai.models import t as transformers +from mcp import ClientSession as McpClientSession +from mcp import Tool as McpTool from opentelemetry._logs import Logger from opentelemetry._logs import LogRecord from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_INPUT_MESSAGES @@ -42,12 +44,19 @@ from opentelemetry.util.types import AttributeValue from ..models.llm_request import LlmRequest from ..models.llm_response import LlmResponse +try: + from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_TOOL_DEFINITIONS +except ImportError: + GEN_AI_TOOL_DEFINITIONS = 'gen_ai.tool_definitions' + OTEL_SEMCONV_STABILITY_OPT_IN = 'OTEL_SEMCONV_STABILITY_OPT_IN' OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT' ) +FUNCTION_TOOL_DEFINITION_TYPE = 'function' + class Text(TypedDict): content: str @@ -93,6 +102,21 @@ class OutputMessage(TypedDict): finish_reason: 
str +class FunctionToolDefinition(TypedDict): + name: str + description: str | None + parameters: Any + type: Literal['function'] + + +class GenericToolDefinition(TypedDict): + name: str + type: str + + +ToolDefinition = FunctionToolDefinition | GenericToolDefinition + + def _safe_json_serialize_no_whitespaces(obj) -> str: """Convert any Python object to a JSON-serializable type or string. @@ -129,6 +153,158 @@ def get_content_capturing_mode() -> str: ).upper() +def _model_dump_to_tool_definition(tool: Any) -> dict[str, Any]: + model_dump = tool.model_dump(exclude_none=True) + + name = ( + model_dump.get('name') + or getattr(tool, 'name', None) + or type(tool).__name__ + ) + description = model_dump.get('description') or getattr( + tool, 'description', None + ) + parameters = model_dump.get('parameters') or model_dump.get('inputSchema') + return FunctionToolDefinition( + name=name, + description=description, + parameters=parameters, + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +def _clean_parameters(params: Any) -> Any: + """Converts parameter objects into plain dicts.""" + if params is None: + return None + if isinstance(params, dict): + return params + if hasattr(params, 'to_dict'): + return params.to_dict() + if hasattr(params, 'model_dump'): + return params.model_dump(exclude_none=True) + + try: + # Check if it's already a standard JSON type. 
+ json.dumps(params) + return params + + except (TypeError, ValueError): + return { + 'type': 'object', + 'properties': { + 'serialization_error': { + 'type': 'string', + 'description': ( + f'Failed to serialize parameters: {type(params).__name__}' + ), + } + }, + } + + +def _tool_to_tool_definition(tool: types.Tool) -> list[dict[str, Any]]: + definitions = [] + if tool.function_declarations: + for fd in tool.function_declarations: + definitions.append( + FunctionToolDefinition( + name=getattr(fd, 'name', type(fd).__name__), + description=getattr(fd, 'description', None), + parameters=_clean_parameters(getattr(fd, 'parameters', None)), + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + ) + + # Generic types + if hasattr(tool, 'model_dump'): + exclude_fields = {'function_declarations'} + fields = { + k: v + for k, v in tool.model_dump().items() + if v is not None and k not in exclude_fields + } + + for tool_type, _ in fields.items(): + definitions.append( + GenericToolDefinition( + name=tool_type, + type=tool_type, + ) + ) + + return definitions + + +def _tool_definition_from_callable_tool(tool: Any) -> dict[str, Any]: + doc = getattr(tool, '__doc__', '') or '' + return FunctionToolDefinition( + name=getattr(tool, '__name__', type(tool).__name__), + description=doc.strip(), + parameters=None, + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +def _tool_definition_from_mcp_tool(tool: McpTool) -> dict[str, Any]: + if hasattr(tool, 'model_dump'): + return _model_dump_to_tool_definition(tool) + + return FunctionToolDefinition( + name=getattr(tool, 'name', type(tool).__name__), + description=getattr(tool, 'description', None), + parameters=getattr(tool, 'input_schema', None), + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +async def _to_tool_definitions( + tool: types.ToolUnionDict, +) -> list[dict[str, Any]]: + + if isinstance(tool, types.Tool): + return _tool_to_tool_definition(tool) + + if callable(tool): + return [_tool_definition_from_callable_tool(tool)] + + if 
isinstance(tool, McpTool): + return [_tool_definition_from_mcp_tool(tool)] + + if isinstance(tool, McpClientSession): + result = await tool.list_tools() + return [_model_dump_to_tool_definition(t) for t in result.tools] + + return [ + GenericToolDefinition( + name='UnserializableTool', + type=type(tool).__name__, + ) + ] + + +def _operation_details_attributes_no_content( + operation_details_attributes: Mapping[str, AttributeValue], +) -> dict[str, AttributeValue]: + tool_def = operation_details_attributes.get(GEN_AI_TOOL_DEFINITIONS) + if not tool_def: + return {} + + return { + GEN_AI_TOOL_DEFINITIONS: [ + FunctionToolDefinition( + name=td['name'], + description=td['description'], + parameters=None, + type=td['type'], + ) + if 'parameters' in td + else td + for td in tool_def + ] + } + + def _to_input_message( content: types.Content, ) -> InputMessage: @@ -264,8 +440,17 @@ async def set_operation_details_attributes_from_request( system_instructions = _to_system_instructions(llm_request.config) + tool_definitions = [] + if tools := llm_request.config.tools: + for tool in tools: + definitions = await _to_tool_definitions(tool) + for de in definitions: + if de: + tool_definitions.append(de) + operation_details_attributes[GEN_AI_INPUT_MESSAGES] = input_messages operation_details_attributes[GEN_AI_SYSTEM_INSTRUCTIONS] = system_instructions + operation_details_attributes[GEN_AI_TOOL_DEFINITIONS] = tool_definitions def set_operation_details_attributes_from_response( @@ -310,6 +495,11 @@ def maybe_log_completion_details( if capturing_mode in ['EVENT_ONLY', 'SPAN_AND_EVENT']: final_attributes = final_attributes | operation_details_attributes + else: + final_attributes = ( + final_attributes + | _operation_details_attributes_no_content(operation_details_attributes) + ) otel_logger.emit( LogRecord( @@ -321,3 +511,8 @@ def maybe_log_completion_details( if capturing_mode in ['SPAN_ONLY', 'SPAN_AND_EVENT']: for key, value in operation_details_attributes.items(): 
span.set_attribute(key, _safe_json_serialize_no_whitespaces(value)) + else: + for key, value in _operation_details_attributes_no_content( + operation_details_attributes + ).items(): + span.set_attribute(key, _safe_json_serialize_no_whitespaces(value)) diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index 793c0bb3..3c061e42 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -33,6 +33,9 @@ from google.adk.telemetry.tracing import trace_tool_call from google.adk.telemetry.tracing import use_inference_span from google.adk.tools.base_tool import BaseTool from google.genai import types +from mcp import ClientSession as McpClientSession +from mcp import ListToolsResult as McpListToolsResult +from mcp import Tool as McpTool from opentelemetry._logs import LogRecord from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_AGENT_NAME from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_CONVERSATION_ID @@ -48,6 +51,11 @@ from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_A from opentelemetry.semconv._incubating.attributes.user_attributes import USER_ID import pytest +try: + from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_TOOL_DEFINITIONS +except ImportError: + GEN_AI_TOOL_DEFINITIONS = 'gen_ai.tool_definitions' + class Event: @@ -815,6 +823,52 @@ async def test_generate_content_span( assert choice_log.attributes == {GEN_AI_SYSTEM: 'test_system'} +def _mock_callable_tool(): + """Description of some tool.""" + return 'result' + + +def _mock_mcp_client_session() -> McpClientSession: + mock_session = mock.create_autospec(spec=McpClientSession, instance=True) + + mock_tool_obj = McpTool( + name='mcp_tool', + description='Tool from session', + inputSchema={ + 'type': 'object', + 'properties': {'query': {'type': 'string'}}, + }, + ) + mock_result = 
mock.create_autospec(McpListToolsResult, instance=True) + mock_result.tools = [mock_tool_obj] + + mock_session.list_tools = mock.AsyncMock(return_value=mock_result) + + return mock_session + + +def _mock_mcp_tool(): + return McpTool( + name='mcp_tool', + description='A standalone mcp tool', + inputSchema={ + 'type': 'object', + 'properties': {'id': {'type': 'integer'}}, + }, + ) + + +def _mock_tool_dict() -> types.ToolDict: + return types.ToolDict( + function_declarations=[ + types.FunctionDeclarationDict( + name='mock_tool', description='Description of mock tool.' + ), + ], + google_maps=types.GoogleMaps(), + ) + + @pytest.mark.asyncio @mock.patch('google.adk.telemetry.tracing.otel_logger') @mock.patch('google.adk.telemetry.tracing.tracer') @@ -862,11 +916,18 @@ async def test_generate_content_span_with_experimental_semconv( role='model', parts=[types.Part(text='Response')] ) + tools = [ + _mock_callable_tool, + _mock_tool_dict(), + _mock_mcp_client_session(), + _mock_mcp_tool(), + ] + llm_request = LlmRequest( model='some-model', contents=[user_content1, user_content2], config=types.GenerateContentConfig( - system_instruction=system_instruction, + system_instruction=system_instruction, tools=tools ), ) llm_response = LlmResponse( @@ -923,6 +984,92 @@ async def test_generate_content_span_with_experimental_semconv( ], 'finish_reason': 'stop', }] + expected_tool_definitions = [ + { + 'name': '_mock_callable_tool', + 'description': 'Description of some tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mock_tool', + 'description': 'Description of mock tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'google_maps', + 'type': 'google_maps', + }, + { + 'name': 'mcp_tool', + 'description': 'Tool from session', + 'parameters': { + 'type': 'object', + 'properties': {'query': {'type': 'string'}}, + }, + 'type': 'function', + }, + { + 'name': 'mcp_tool', + 'description': 'A standalone mcp tool', + 'parameters': { + 'type': 'object', 
+ 'properties': {'id': {'type': 'integer'}}, + }, + 'type': 'function', + }, + ] + expected_tool_definitions_no_content = [ + { + 'name': '_mock_callable_tool', + 'description': 'Description of some tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mock_tool', + 'description': 'Description of mock tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'google_maps', + 'type': 'google_maps', + }, + { + 'name': 'mcp_tool', + 'description': 'Tool from session', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mcp_tool', + 'description': 'A standalone mcp tool', + 'parameters': None, + 'type': 'function', + }, + ] + expected_tool_definitions_json = ( + '[{"name":"_mock_callable_tool","description":"Description of some' + ' tool.","parameters":null,"type":"function"},{"name":"mock_tool","description":"Description' + ' of mock' + ' tool.","parameters":null,"type":"function"},{"name":"google_maps","type":"google_maps"},{"name":"mcp_tool","description":"Tool' + ' from' + ' session","parameters":{"type":"object","properties":{"query":{"type":"string"}}},"type":"function"},{"name":"mcp_tool","description":"A' + ' standalone mcp' + ' tool","parameters":{"type":"object","properties":{"id":{"type":"integer"}}},"type":"function"}]' + ) + + expected_tool_definitions_no_content_json = ( + '[{"name":"_mock_callable_tool","description":"Description of some' + ' tool.","parameters":null,"type":"function"},{"name":"mock_tool","description":"Description' + ' of mock' + ' tool.","parameters":null,"type":"function"},{"name":"google_maps","type":"google_maps"},{"name":"mcp_tool","description":"Tool' + ' from' + ' session","parameters":null,"type":"function"},{"name":"mcp_tool","description":"A' + ' standalone mcp tool","parameters":null,"type":"function"}]' + ) # Assert Span mock_tracer.start_as_current_span.assert_called_once_with( 'generate_content some-model' @@ -959,12 +1106,17 @@ async def 
test_generate_content_span_with_experimental_semconv( GEN_AI_OUTPUT_MESSAGES, '[{"role":"assistant","parts":[{"content":"Response","type":"text"}],"finish_reason":"stop"}]', ) - + mock_span.set_attribute.assert_any_call( + GEN_AI_TOOL_DEFINITIONS, expected_tool_definitions_json + ) else: all_attribute_calls = mock_span.set_attribute.call_args_list assert GEN_AI_SYSTEM_INSTRUCTIONS not in all_attribute_calls assert GEN_AI_INPUT_MESSAGES not in all_attribute_calls assert GEN_AI_OUTPUT_MESSAGES not in all_attribute_calls + mock_span.set_attribute.assert_any_call( + GEN_AI_TOOL_DEFINITIONS, expected_tool_definitions_no_content_json + ) # Assert Logs assert mock_otel_logger.emit.call_count == 1 @@ -996,10 +1148,17 @@ async def test_generate_content_span_with_experimental_semconv( assert attributes[GEN_AI_INPUT_MESSAGES] == expected_input_messages assert GEN_AI_OUTPUT_MESSAGES in attributes assert attributes[GEN_AI_OUTPUT_MESSAGES] == expected_output_messages + assert GEN_AI_TOOL_DEFINITIONS in attributes + assert attributes[GEN_AI_TOOL_DEFINITIONS] == expected_tool_definitions else: assert GEN_AI_SYSTEM_INSTRUCTIONS not in attributes assert GEN_AI_INPUT_MESSAGES not in attributes assert GEN_AI_OUTPUT_MESSAGES not in attributes + assert GEN_AI_TOOL_DEFINITIONS in attributes + assert ( + attributes[GEN_AI_TOOL_DEFINITIONS] + == expected_tool_definitions_no_content + ) assert GEN_AI_USAGE_INPUT_TOKENS in attributes assert attributes[GEN_AI_USAGE_INPUT_TOKENS] == 10 From 65d9a726c578b2231d540b868c415d51c8f9337b Mon Sep 17 00:00:00 2001 From: Carlos Chinchilla Corbacho <188046461+cchinchilla-dev@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:27:18 -0800 Subject: [PATCH 053/102] chore: add @override decorators to LoggingPlugin callback methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/4572 ### Link to Issue or Description of Change **1. 
Link to an existing issue (if applicable):** - Closes: #4496 **2. Or, if no issue exists, describe the change:** **Problem:** `LoggingPlugin` overrides 12 callback methods from `BasePlugin` but none use the `@override` decorator. Every other plugin in the package (`DebugLoggingPlugin`, `ReplayPlugin`, `RecordingsPlugin`, `EnsureRetryOptionsPlugin`, `_RequestIntercepterPlugin`) already follows this practice. With `mypy --strict` enabled in `pyproject.toml`, missing `@override` means renamed or removed base-class methods would be silently missed in `LoggingPlugin` while being caught everywhere else. **Solution:** Import `override` from `typing_extensions` and decorate all 12 overridden callbacks. Purely additive: one import line and 12 decorators. No behavioral, API, or runtime change. ### Testing Plan **Unit Tests:** - [ ] I have added or updated unit tests for my change. - [x] All unit tests pass locally. No new tests are required — `@override` is a static-analysis-only decorator with no runtime effect. Ran `mypy` on the file before and after the change: same preexisting warnings, no new errors introduced. CI will validate via the existing test suite and linting checks. **Manual End-to-End (E2E) Tests:** Not applicable. This change adds only decorators with no runtime behavior. Verified by comparing `mypy` output before and after — no new errors. ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [ ] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [ ] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. ### Additional context I am the author of the original issue (#4496). 
A previous PR (#4544) was opened but is pending clarification, so I'm submitting this complete PR as the original issue author. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4572 from cchinchilla-dev:feat/add_override_decorators_to_loggingplugin_4496 142ed872bee39db782c1dccb84f40906ee849bb8 PiperOrigin-RevId: 875811966 --- src/google/adk/plugins/logging_plugin.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/google/adk/plugins/logging_plugin.py b/src/google/adk/plugins/logging_plugin.py index df37ee7e..b95e178d 100644 --- a/src/google/adk/plugins/logging_plugin.py +++ b/src/google/adk/plugins/logging_plugin.py @@ -19,6 +19,7 @@ from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from typing_extensions import override from ..agents.base_agent import BaseAgent from ..agents.callback_context import CallbackContext @@ -66,6 +67,7 @@ class LoggingPlugin(BasePlugin): """ super().__init__(name) + @override async def on_user_message_callback( self, *, @@ -87,6 +89,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Branch: {invocation_context.branch}") return None + @override async def before_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: @@ -99,6 +102,7 @@ class LoggingPlugin(BasePlugin): ) return None + @override async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: @@ -122,6 +126,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def after_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[None]: @@ -134,6 +139,7 @@ class LoggingPlugin(BasePlugin): ) return None + @override async def before_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: @@ -145,6 +151,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Branch: {callback_context._invocation_context.branch}") return None + @override async 
def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: @@ -154,6 +161,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Invocation ID: {callback_context.invocation_id}") return None + @override async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: @@ -179,6 +187,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse ) -> Optional[LlmResponse]: @@ -206,6 +215,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def before_tool_callback( self, *, @@ -221,6 +231,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Arguments: {self._format_args(tool_args)}") return None + @override async def after_tool_callback( self, *, @@ -237,6 +248,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Result: {self._format_args(result)}") return None + @override async def on_model_error_callback( self, *, @@ -251,6 +263,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def on_tool_error_callback( self, *, From 7b7ddda46ca701952f002b2807b89dbef5322414 Mon Sep 17 00:00:00 2001 From: Keyur Joshi Date: Thu, 26 Feb 2026 11:40:14 -0800 Subject: [PATCH 054/102] feat: Add interface between optimization infra and LocalEvalService details: * Enables the use of ADK evaluations via LocalEvalService for optimizing agents. * Provides flexibility in choosing eval sets and eval cases for training and validation. * Converts ADK eval results into a compact format useful for whitebox agent optimization. 
Co-authored-by: Keyur Joshi PiperOrigin-RevId: 875818012 --- .../adk/optimization/local_eval_sampler.py | 367 +++++++++++++++++ src/google/adk/optimization/sampler.py | 5 +- .../optimization/local_eval_sampler_test.py | 383 ++++++++++++++++++ 3 files changed, 754 insertions(+), 1 deletion(-) create mode 100644 src/google/adk/optimization/local_eval_sampler.py create mode 100644 tests/unittests/optimization/local_eval_sampler_test.py diff --git a/src/google/adk/optimization/local_eval_sampler.py b/src/google/adk/optimization/local_eval_sampler.py new file mode 100644 index 00000000..b00c3428 --- /dev/null +++ b/src/google/adk/optimization/local_eval_sampler.py @@ -0,0 +1,367 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import logging +from typing import Any +from typing import Literal +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + +from ..agents.llm_agent import Agent +from ..evaluation.base_eval_service import EvaluateConfig +from ..evaluation.base_eval_service import EvaluateRequest +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.base_eval_service import InferenceResult +from ..evaluation.eval_case import get_all_tool_calls_with_responses +from ..evaluation.eval_case import IntermediateData +from ..evaluation.eval_case import Invocation +from ..evaluation.eval_case import InvocationEvents +from ..evaluation.eval_config import EvalConfig +from ..evaluation.eval_config import get_eval_metrics_from_config +from ..evaluation.eval_metrics import EvalStatus +from ..evaluation.eval_result import EvalCaseResult +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..evaluation.local_eval_service import LocalEvalService +from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider +from ..utils.context_utils import Aclosing +from .data_types import UnstructuredSamplingResult +from .sampler import Sampler + +logger = logging.getLogger("google_adk." 
+ __name__) + + +def _log_eval_summary(eval_results: list[EvalCaseResult]): + """Logs a summary of eval results.""" + num_pass, num_fail, num_other = 0, 0, 0 + for eval_result in eval_results: + eval_result: EvalCaseResult + if eval_result.final_eval_status == EvalStatus.PASSED: + num_pass += 1 + elif eval_result.final_eval_status == EvalStatus.FAILED: + num_fail += 1 + else: + num_other += 1 + log_str = f"Evaluation summary: {num_pass} PASSED, {num_fail} FAILED" + if num_other: + log_str += f", {num_other} OTHER" + logger.info(log_str) + + +def extract_tool_call_data( + intermediate_data: IntermediateData | InvocationEvents, +) -> list[dict[str, Any]]: + """Extracts tool calls and their responses from intermediate data.""" + call_response_pairs = get_all_tool_calls_with_responses(intermediate_data) + result = [] + for tool_call, tool_response in call_response_pairs: + result.append({ + "name": tool_call.name, + "args": tool_call.args, + "response": tool_response.response if tool_response else None, + }) + return result + + +def extract_single_invocation_info( + invocation: Invocation, +) -> dict[str, Any]: + """Extracts useful information from a single invocation.""" + user_prompt = "" + for part in invocation.user_content.parts: + if part.text and not part.thought: + user_prompt += part.text + agent_response = "" + if invocation.final_response: + for part in invocation.final_response.parts: + if part.text and not part.thought: + agent_response += part.text + result = {"user_prompt": user_prompt, "agent_response": agent_response} + if invocation.intermediate_data: + tool_call_data = extract_tool_call_data(invocation.intermediate_data) + result["tool_calls"] = tool_call_data + return result + + +class LocalEvalSamplerConfig(BaseModel): + """Contains configuration options required by the LocalEvalServiceInterface.""" + + eval_config: EvalConfig = Field( + required=True, + description="The configuration for the evaluation.", + ) + + app_name: str = Field( + 
required=True, + description="The app name to use for evaluation.", + ) + + train_eval_set: str = Field( + required=True, + description="The name of the eval set to use for optimization.", + ) + + train_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for optimization. If not provided," + " all eval cases in the train_eval_set will be used." + ), + ) + + validation_eval_set: Optional[str] = Field( + default=None, + description=( + "The name of the eval set to use for validating the optimized agent." + " If not provided, the train_eval_set will also be used for" + " validation." + ), + ) + + validation_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for validating the optimized agent." + " If not provided, all eval cases in the validation_eval_set will be" + " used. If validation_eval_set is also not provided, all train eval" + " cases will be used." + ), + ) + + +class LocalEvalSampler(Sampler[UnstructuredSamplingResult]): + """Evaluates candidate agents with the ADK's LocalEvalService.""" + + def __init__( + self, + config: LocalEvalSamplerConfig, + eval_sets_manager: EvalSetsManager, + ): + self._config = config + self._eval_sets_manager = eval_sets_manager + + self._train_eval_set = self._config.train_eval_set + self._train_eval_case_ids = ( + self._config.train_eval_case_ids + or self._get_eval_case_ids(self._train_eval_set) + ) + + self._validation_eval_set = ( + self._config.validation_eval_set or self._train_eval_set + ) + if self._config.validation_eval_case_ids: + self._validation_eval_case_ids = self._config.validation_eval_case_ids + elif self._config.validation_eval_set: + self._validation_eval_case_ids = self._get_eval_case_ids( + self._validation_eval_set + ) + else: + self._validation_eval_case_ids = self._train_eval_case_ids + + def _get_selected_example_set_id( + self, example_set: Literal[Sampler.TRAIN_SET, 
Sampler.VALIDATION_SET] + ) -> str: + """Returns the ID of the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_set, + Sampler.VALIDATION_SET: self._validation_eval_set, + }[example_set] + + def _get_all_example_ids( + self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET] + ) -> list[str]: + """Returns the IDs of all examples in the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_case_ids, + Sampler.VALIDATION_SET: self._validation_eval_case_ids, + }[example_set] + + def _get_eval_case_ids(self, eval_set_id: str) -> list[str]: + """Returns the ids of eval cases in the given eval set.""" + eval_set = self._eval_sets_manager.get_eval_set( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + ) + if eval_set: + return [eval_case.eval_id for eval_case in eval_set.eval_cases] + else: + raise ValueError( + f"Eval set `{eval_set_id}` does not exist for app" + f" `{self._config.app_name}`." + ) + + async def _evaluate_agent( + self, + agent: Agent, + eval_set_id: str, + eval_case_ids: list[str], + ) -> list[EvalCaseResult]: + """Evaluates the agent on the requested eval cases and returns the results. + + Args: + agent: The agent to evaluate. + eval_set_id: The id of the eval set to use for evaluation. + eval_case_ids: The ids of the eval cases to use for evaluation. + + Returns: + A list of EvalCaseResult, one per eval case. 
+ """ + # create the inference request + inference_request = InferenceRequest( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_ids=eval_case_ids, + inference_config=InferenceConfig(), + ) + + # create the LocalEvalService + user_simulator_provider = UserSimulatorProvider( + self._config.eval_config.user_simulator_config + ) + eval_service = LocalEvalService( + root_agent=agent, + eval_sets_manager=self._eval_sets_manager, + user_simulator_provider=user_simulator_provider, + ) + + # inference/sampling + async with Aclosing( + eval_service.perform_inference(inference_request=inference_request) + ) as agen: + inference_results: list[InferenceResult] = [ + inference_result async for inference_result in agen + ] + + # evaluation + eval_metrics = get_eval_metrics_from_config(self._config.eval_config) + evaluate_request = EvaluateRequest( + inference_results=inference_results, + evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), + ) + async with Aclosing( + eval_service.evaluate(evaluate_request=evaluate_request) + ) as agen: + eval_results: list[EvalCaseResult] = [ + eval_result async for eval_result in agen + ] + + return eval_results + + def _extract_eval_data( + self, + eval_set_id: str, + eval_results: list[EvalCaseResult], + ) -> dict[str, dict[str, Any]]: + """Extracts evaluation data from the eval results.""" + eval_data = {} + for eval_result in eval_results: + eval_result_dict = {} + eval_case = self._eval_sets_manager.get_eval_case( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_id=eval_result.eval_id, + ) + if eval_case and eval_case.conversation_scenario: + eval_result_dict["conversation_scenario"] = ( + eval_case.conversation_scenario + ) + + per_invocation_results = [] + for ( + per_invocation_result + ) in eval_result.eval_metric_result_per_invocation: + eval_metric_results = [] + for eval_metric_result in per_invocation_result.eval_metric_results: + eval_metric_results.append({ + "metric_name": 
eval_metric_result.metric_name, + "score": round(eval_metric_result.score, 2), # accurate enough + "eval_status": eval_metric_result.eval_status.name, + }) + per_invocation_result_dict = { + "actual_invocation": extract_single_invocation_info( + per_invocation_result.actual_invocation + ), + "eval_metric_results": eval_metric_results, + } + if per_invocation_result.expected_invocation: + per_invocation_result_dict["expected_invocation"] = ( + extract_single_invocation_info( + per_invocation_result.expected_invocation + ) + ) + per_invocation_results.append(per_invocation_result_dict) + eval_result_dict["invocations"] = per_invocation_results + eval_data[eval_result.eval_id] = eval_result_dict + + return eval_data + + def get_train_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for training the agent.""" + return self._train_eval_case_ids + + def get_validation_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for validating the optimized agent.""" + return self._validation_eval_case_ids + + async def sample_and_score( + self, + candidate: Agent, + example_set: Literal[ + Sampler.TRAIN_SET, Sampler.VALIDATION_SET + ] = Sampler.VALIDATION_SET, + batch: Optional[list[str]] = None, + capture_full_eval_data: bool = False, + ) -> UnstructuredSamplingResult: + """Evaluates the candidate agent on the batch of examples using the ADK LocalEvalService. + + Args: + candidate: The candidate agent to be evaluated. + example_set: The set of examples to evaluate the candidate agent on. + Possible values are "train" and "validation". + batch: UIDs of examples to evaluate the candidate agent on. If not + provided, all examples from the chosen set will be used. + capture_full_eval_data: If false, it is enough to only calculate the + scores for each example. If true, this method should also capture all + other data required for optimizing the agent (e.g., outputs, + trajectories, and tool calls). 
+ + Returns: + The evaluation results, containing the scores for each example and (if + requested) other data required for optimization. + """ + eval_set_id = self._get_selected_example_set_id(example_set) + if batch is None: + batch = self._get_all_example_ids(example_set) + + eval_results = await self._evaluate_agent(candidate, eval_set_id, batch) + _log_eval_summary(eval_results) + + scores = { + eval_result.eval_id: ( + 1.0 if eval_result.final_eval_status == EvalStatus.PASSED else 0.0 + ) + for eval_result in eval_results + } + + eval_data = ( + self._extract_eval_data(eval_set_id, eval_results) + if capture_full_eval_data + else None + ) + + return UnstructuredSamplingResult(scores=scores, data=eval_data) diff --git a/src/google/adk/optimization/sampler.py b/src/google/adk/optimization/sampler.py index 0a0ff45d..632e5d3d 100644 --- a/src/google/adk/optimization/sampler.py +++ b/src/google/adk/optimization/sampler.py @@ -32,6 +32,9 @@ class Sampler(ABC, Generic[SamplingResult]): to get evaluation results for the candidate agent on the batch of examples. 
""" + TRAIN_SET = "train" + VALIDATION_SET = "validation" + @abstractmethod def get_train_example_ids(self) -> list[str]: """Returns the UIDs of examples to use for training the agent.""" @@ -46,7 +49,7 @@ class Sampler(ABC, Generic[SamplingResult]): async def sample_and_score( self, candidate: Agent, - example_set: Literal["train", "validation"] = "validation", + example_set: Literal[TRAIN_SET, VALIDATION_SET] = VALIDATION_SET, batch: Optional[list[str]] = None, capture_full_eval_data: bool = False, ) -> SamplingResult: diff --git a/tests/unittests/optimization/local_eval_sampler_test.py b/tests/unittests/optimization/local_eval_sampler_test.py new file mode 100644 index 00000000..6ebd99cb --- /dev/null +++ b/tests/unittests/optimization/local_eval_sampler_test.py @@ -0,0 +1,383 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.agents.llm_agent import Agent +from google.adk.evaluation.base_eval_service import EvaluateConfig +from google.adk.evaluation.base_eval_service import EvaluateRequest +from google.adk.evaluation.base_eval_service import InferenceConfig +from google.adk.evaluation.base_eval_service import InferenceRequest +from google.adk.evaluation.base_eval_service import InferenceResult +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_case import InvocationEvent +from google.adk.evaluation.eval_case import InvocationEvents +from google.adk.evaluation.eval_config import EvalConfig +from google.adk.evaluation.eval_config import EvalMetric +from google.adk.evaluation.eval_metrics import EvalMetricResult +from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation +from google.adk.evaluation.eval_metrics import EvalStatus +from google.adk.evaluation.eval_result import EvalCaseResult +from google.adk.evaluation.eval_sets_manager import EvalSetsManager +from google.adk.optimization.local_eval_sampler import _log_eval_summary +from google.adk.optimization.local_eval_sampler import extract_single_invocation_info +from google.adk.optimization.local_eval_sampler import extract_tool_call_data +from google.adk.optimization.local_eval_sampler import LocalEvalSampler +from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig +from google.genai import types +import pytest + + +def test_log_eval_summary(mocker): + statuses = ( + [EvalStatus.PASSED] * 3 + + [EvalStatus.FAILED] * 2 + + [EvalStatus.NOT_EVALUATED] + ) + expected_log = "Evaluation summary: 3 PASSED, 2 FAILED, 1 OTHER" + + eval_results = [ + mocker.MagicMock(spec=EvalCaseResult, final_eval_status=status) + for status in statuses + ] + mock_logger = mocker.patch( + "google.adk.optimization.local_eval_sampler.logger" + ) + + _log_eval_summary(eval_results) + + 
mock_logger.info.assert_called_once_with(expected_log) + + +def test_extract_tool_call_data(): + # omitting IntermediateData tests as it is no longer used + # case 1: empty invocation events + assert not extract_tool_call_data(InvocationEvents()) + # case 2: multi call invocation events + multi_call_invocation_events = InvocationEvents( + invocation_events=[ + InvocationEvent( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + id="call_1", + name="tool_1", + args={"a": 1}, + ) + ), + types.Part( + function_call=types.FunctionCall( + id="call_2", + name="tool_2", + args={"b": 2}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_1", + name="tool_1", + response={"result_1": "done"}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_2", + name="tool_2", + response={"result_2": "done"}, + ) + ), + ] + ), + ) + ] + ) + expected_entries = [ + { + "name": "tool_1", + "args": {"a": 1}, + "response": {"result_1": "done"}, + }, + { + "name": "tool_2", + "args": {"b": 2}, + "response": {"result_2": "done"}, + }, + ] + result = extract_tool_call_data(multi_call_invocation_events) + # order is not guaranteed + for expected_entry in expected_entries: + assert expected_entry in result + assert len(result) == len(expected_entries) + + +def test_extract_single_invocation_info(): + invocation = Invocation( + user_content=types.Content( + parts=[ + types.Part(text="user thought", thought=True), + types.Part(text="Hello agent!"), + ] + ), + final_response=types.Content( + parts=[ + types.Part(text="agent thought", thought=True), + types.Part(text="Hello user!"), + ] + ), + ) + + result = extract_single_invocation_info(invocation) + + assert result == { + "user_prompt": "Hello agent!", + "agent_response": "Hello user!", + } + + +@pytest.mark.parametrize( + "config_kwargs, expected_attrs", + [ + ( + {"train_eval_set": "train_set"}, + { + "_train_eval_set": "train_set", + 
"_train_eval_case_ids": ["train_set_1", "train_set_2"], + "_validation_eval_set": "train_set", + "_validation_eval_case_ids": ["train_set_1", "train_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "train_eval_case_ids": ["t1"]}, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_case_ids": ["t1"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_set": "val_set"}, + { + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["val_set_1", "val_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_case_ids": ["v1"]}, + { + "_validation_eval_case_ids": ["v1"], + }, + ), + ( + { + "train_eval_set": "train_set", + "train_eval_case_ids": ["t1"], + "validation_eval_set": "val_set", + "validation_eval_case_ids": ["v1"], + }, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["v1"], + }, + ), + ], +) +def test_local_eval_service_interface_init( + mocker, config_kwargs, expected_attrs +): + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + + def mock_get_eval_case_ids(self, eval_set_id): + return [f"{eval_set_id}_1", f"{eval_set_id}_2"] + + mocker.patch.object( + LocalEvalSampler, + "_get_eval_case_ids", + autospec=True, + side_effect=mock_get_eval_case_ids, + ) + + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), app_name="test_app", **config_kwargs + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + for attr, expected_value in expected_attrs.items(): + assert getattr(interface, attr) == expected_value + + +@pytest.mark.asyncio +async def test_evaluate_agent(mocker): + # Mocking LocalEvalService and its methods + mock_eval_service_cls = mocker.patch( + "google.adk.optimization.local_eval_sampler.LocalEvalService" + ) + mock_eval_service = mock_eval_service_cls.return_value + + # mocking inference + mock_inference_result = mocker.MagicMock(spec=InferenceResult) + + async def mock_perform_inference(*args, 
**kwargs): + yield mock_inference_result + + mock_eval_service.perform_inference.side_effect = mock_perform_inference + + # mocking evaluate + mock_eval_case_result = mocker.MagicMock(spec=EvalCaseResult) + + async def mock_evaluate(*args, **kwargs): + yield mock_eval_case_result + + mock_eval_service.evaluate.side_effect = mock_evaluate + + # mocking get_eval_metrics_from_config + mock_metrics = [EvalMetric(metric_name="test_metric")] + mocker.patch( + "google.adk.optimization.local_eval_sampler.get_eval_metrics_from_config", + return_value=mock_metrics, + ) + + mocker.patch("google.adk.evaluation.base_eval_service.EvaluateConfig") + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Call _evaluate_agent + results = await interface._evaluate_agent( + mocker.MagicMock(spec=Agent), "train_set", ["t1"] + ) + + # Assertions + mock_eval_service.perform_inference.assert_called_once_with( + inference_request=InferenceRequest( + app_name="test_app", + eval_set_id="train_set", + eval_case_ids=["t1"], + inference_config=InferenceConfig(), + ) + ) + mock_eval_service.evaluate.assert_called_once_with( + evaluate_request=EvaluateRequest( + inference_results=[mock_inference_result], + evaluate_config=EvaluateConfig(eval_metrics=mock_metrics), + ) + ) + assert results == [mock_eval_case_result] + + +@pytest.mark.asyncio +async def test_extract_eval_data(mocker): + # Mock components + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + mock_eval_case = mocker.MagicMock() + mock_eval_case.conversation_scenario = "test_scenario" + mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case + + # Mock per invocation result + mock_actual_invocation = mocker.MagicMock(spec=Invocation) + mock_expected_invocation = mocker.MagicMock(spec=Invocation) + 
mock_metric_result = mocker.MagicMock(spec=EvalMetricResult) + mock_metric_result.metric_name = "test_metric" + mock_metric_result.score = 0.854 # should be rounded to 0.85 + mock_metric_result.eval_status = EvalStatus.PASSED + + mock_per_inv_result = mocker.MagicMock(spec=EvalMetricResultPerInvocation) + mock_per_inv_result.actual_invocation = mock_actual_invocation + mock_per_inv_result.expected_invocation = mock_expected_invocation + mock_per_inv_result.eval_metric_results = [mock_metric_result] + + mock_eval_result = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result.eval_id = "t1" + mock_eval_result.eval_metric_result_per_invocation = [mock_per_inv_result] + + # Mock extract_single_invocation_info + mocker.patch( + "google.adk.optimization.local_eval_sampler.extract_single_invocation_info", + side_effect=[{"info": "actual"}, {"info": "expected"}], + ) + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + # Call _extract_eval_data + eval_data = interface._extract_eval_data("train_set", [mock_eval_result]) + + # Assertions + assert "t1" in eval_data + assert eval_data["t1"]["conversation_scenario"] == "test_scenario" + assert len(eval_data["t1"]["invocations"]) == 1 + inv = eval_data["t1"]["invocations"][0] + assert inv["actual_invocation"] == {"info": "actual"} + assert inv["expected_invocation"] == {"info": "expected"} + assert inv["eval_metric_results"] == [ + {"metric_name": "test_metric", "score": 0.85, "eval_status": "PASSED"} + ] + + +@pytest.mark.asyncio +async def test_sample_and_score(mocker): + # Mock results + mock_eval_result_1 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_1.eval_id = "t1" + mock_eval_result_1.final_eval_status = EvalStatus.PASSED + + mock_eval_result_2 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_2.eval_id = "t2" 
+ mock_eval_result_2.final_eval_status = EvalStatus.FAILED + + eval_results = [mock_eval_result_1, mock_eval_result_2] + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1", "t2"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Patch internal methods + mocker.patch.object(interface, "_evaluate_agent", return_value=eval_results) + mock_log_summary = mocker.patch( + "google.adk.optimization.local_eval_sampler._log_eval_summary" + ) + mock_extract_data = mocker.patch.object( + interface, "_extract_eval_data", return_value={"t1": {}, "t2": {}} + ) + + # Call sample_and_score + result = await interface.sample_and_score( + mocker.MagicMock(spec=Agent), + example_set="train", + capture_full_eval_data=True, + ) + + # Assertions + assert result.scores == {"t1": 1.0, "t2": 0.0} + assert result.data == {"t1": {}, "t2": {}} + mock_log_summary.assert_called_once_with(eval_results) + mock_extract_data.assert_called_once_with("train_set", eval_results) From b4610fe1c6e3986b93b8f2fcae79ae436c4118cb Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 26 Feb 2026 12:23:58 -0800 Subject: [PATCH 055/102] chore: Add README for integrations folder Co-authored-by: Kathy Wu PiperOrigin-RevId: 875836816 --- src/google/adk/integrations/README.md | 35 +++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/google/adk/integrations/README.md diff --git a/src/google/adk/integrations/README.md b/src/google/adk/integrations/README.md new file mode 100644 index 00000000..56ab2b33 --- /dev/null +++ b/src/google/adk/integrations/README.md @@ -0,0 +1,35 @@ +# ADK Integrations + +This directory houses modules that integrate ADK with external tools and +services. The goal is to provide an organized and scalable way to extend ADK's +capabilities. 
+ +Integrations with external systems, such as the Agent Registry, BigQuery, +ApiHub, etc., should be developed within sub-packages in this folder. This +centralization makes it easier for developers to find, use, and contribute to +various integrations. + +## What Belongs Here? + +* Code that connects ADK to other services, APIs, or tools. +* Modules that depend on third-party libraries not included in the core ADK + dependencies. + +## Guidelines for Contributions + +1. **Self-Contained Packages:** Each integration should reside in its own + sub-directory (e.g., `integrations/my_service/`). +2. **Internal Structure:** Integration sub-packages are free to manage their + own internal code structure and design patterns. They do not need to + strictly follow the core ADK framework's structure. +3. **Dependencies:** To keep the core ADK lightweight, dependencies required + for a specific integration must be optional. These should be defined as + "extras" in the `pyproject.toml`. Users will install them using commands + like `pip install "google-adk[my_service]"`. The extra name should match the + integration directory name. +4. **Lazy Importing:** Implement lazy importing within the integration code. If + a user tries to use an integration without installing the necessary extras, + catch the `ModuleNotFoundError` and raise a descriptive error message + guiding the user to the correct installation command. +5. **Documentation:** Ensure clear documentation is provided for each + integration, including setup, configuration, and usage examples. 
From d55afede1b367cf7ee051b609c8adec863d61565 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Thu, 26 Feb 2026 13:33:15 -0800 Subject: [PATCH 056/102] chore: Stop auto-triggering Release Please after cherry-picks Co-authored-by: Xuan Yang PiperOrigin-RevId: 875867583 --- .github/workflows/release-cherry-pick.yml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/release-cherry-pick.yml b/.github/workflows/release-cherry-pick.yml index ac5e5c08..ad324a08 100644 --- a/.github/workflows/release-cherry-pick.yml +++ b/.github/workflows/release-cherry-pick.yml @@ -1,5 +1,6 @@ # Step 3 (optional): Cherry-picks a commit from main to the release/candidate branch. # Use between step 1 and step 4 to include bug fixes in an in-progress release. +# Note: Does NOT auto-trigger release-please to preserve manual changelog edits. name: "Release: Cherry-pick" on: @@ -12,7 +13,6 @@ on: permissions: contents: write - actions: write jobs: cherry-pick: @@ -39,10 +39,5 @@ jobs: run: | git push origin release/candidate echo "Successfully cherry-picked commit to release/candidate" - - - name: Trigger Release Please - env: - GH_TOKEN: ${{ github.token }} - run: | - gh workflow run release-please.yml --repo ${{ github.repository }} --ref release/candidate - echo "Triggered Release Please workflow" + echo "Note: Release Please is NOT auto-triggered to preserve manual changelog edits." + echo "Run release-please.yml manually if you want to regenerate the changelog." 
From ebbc1147863956e85931f8d46abb0632e3d1cf67 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 26 Feb 2026 14:30:20 -0800 Subject: [PATCH 057/102] fix: Validate session before streaming instead of eagerly advancing the runner generator Co-authored-by: George Weale PiperOrigin-RevId: 875892569 --- src/google/adk/cli/adk_web_server.py | 63 +++++++++++++--------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index e032178e..469b33fe 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1668,46 +1668,39 @@ class AdkWebServer: async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE runner = await self.get_runner_async(req.app_name) - agen = runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - invocation_id=req.invocation_id, - ) - # Eagerly advance the generator to trigger session validation - # before the streaming response is created. This lets us return - # a proper HTTP 404 for missing sessions without a redundant - # get_session call — the Runner's single _get_or_create_session - # call is the only one that runs. - first_event = None - first_error = None - try: - first_event = await anext(agen) - except SessionNotFoundError as e: - await agen.aclose() - raise HTTPException(status_code=404, detail=str(e)) from e - except StopAsyncIteration: - await agen.aclose() - except Exception as e: - first_error = e + # Validate session existence before starting the stream. 
+ # We check directly here instead of eagerly advancing the + # runner's async generator with anext(), because splitting + # generator consumption across two asyncio Tasks (request + # handler vs StreamingResponse) breaks OpenTelemetry context + # detachment. + if not runner.auto_create_session: + session = await self.session_service.get_session( + app_name=req.app_name, + user_id=req.user_id, + session_id=req.session_id, + ) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session not found: {req.session_id}", + ) # Convert the events to properly formatted SSE async def event_generator(): - async with Aclosing(agen): + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + invocation_id=req.invocation_id, + ) + ) as agen: try: - if first_error: - raise first_error - - async def all_events(): - if first_event is not None: - yield first_event - async for event in agen: - yield event - - async for event in all_events(): + async for event in agen: # ADK Web renders artifacts from `actions.artifactDelta` # during part processing *and* during action processing # 1) the original event with `artifactDelta` cleared (content) From 8a3161202e4bac0bb8e8801b100f4403c1c75646 Mon Sep 17 00:00:00 2001 From: Ke Wang Date: Thu, 26 Feb 2026 14:46:00 -0800 Subject: [PATCH 058/102] feat(skill): Add BashTool PiperOrigin-RevId: 875899505 --- src/google/adk/tools/bash_tool.py | 150 ++++++++++++++++ tests/unittests/tools/test_bash_tool.py | 229 ++++++++++++++++++++++++ 2 files changed, 379 insertions(+) create mode 100644 src/google/adk/tools/bash_tool.py create mode 100644 tests/unittests/tools/test_bash_tool.py diff --git a/src/google/adk/tools/bash_tool.py b/src/google/adk/tools/bash_tool.py new file mode 100644 index 00000000..38e99643 --- /dev/null +++ b/src/google/adk/tools/bash_tool.py @@ -0,0 +1,150 @@ +# 
Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool to execute bash commands.""" + +from __future__ import annotations + +import dataclasses +import pathlib +import shlex +import subprocess +from typing import Any +from typing import Optional + +from google.genai import types + +from .. import features +from .base_tool import BaseTool +from .tool_context import ToolContext + + +@dataclasses.dataclass(frozen=True) +class BashToolPolicy: + """Configuration for allowed bash commands based on prefix matching. + + Set allowed_command_prefixes to ("*",) to allow all commands (default), + or explicitly list allowed prefixes. + """ + + allowed_command_prefixes: tuple[str, ...] = ("*",) + + +def _validate_command(command: str, policy: BashToolPolicy) -> Optional[str]: + """Validates a bash command against the permitted prefixes.""" + stripped = command.strip() + if not stripped: + return "Command is required." + + if "*" in policy.allowed_command_prefixes: + return None + + for prefix in policy.allowed_command_prefixes: + if stripped.startswith(prefix): + return None + + allowed = ", ".join(policy.allowed_command_prefixes) + return f"Command blocked. 
Permitted prefixes are: {allowed}" + + +@features.experimental(features.FeatureName.SKILL_TOOLSET) +class ExecuteBashTool(BaseTool): + """Tool to execute a validated bash command within a workspace directory.""" + + def __init__( + self, + *, + workspace: pathlib.Path | None = None, + policy: Optional[BashToolPolicy] = None, + ): + if workspace is None: + workspace = pathlib.Path.cwd() + policy = policy or BashToolPolicy() + allowed_hint = ( + "any command" + if "*" in policy.allowed_command_prefixes + else ( + "commands matching prefixes:" + f" {', '.join(policy.allowed_command_prefixes)}" + ) + ) + super().__init__( + name="execute_bash", + description=( + "Executes a bash command with the working directory set to the" + f" workspace. Allowed: {allowed_hint}. All commands require user" + " confirmation." + ), + ) + self._workspace = workspace + self._policy = policy + + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute.", + }, + }, + "required": ["command"], + }, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + command = args.get("command") + if not command: + return {"error": "Command is required."} + + # Static validation. + error = _validate_command(command, self._policy) + if error: + return {"error": error} + + # Always request user confirmation. + if not tool_context.tool_confirmation: + tool_context.request_confirmation( + hint=f"Please approve or reject the bash command: {command}", + ) + tool_context.actions.skip_summarization = True + return { + "error": ( + "This tool call requires confirmation, please approve or reject." 
+ ) + } + elif not tool_context.tool_confirmation.confirmed: + return {"error": "This tool call is rejected."} + + try: + result = subprocess.run( + shlex.split(command), + shell=False, + cwd=str(self._workspace), + capture_output=True, + text=True, + timeout=30, + ) + return { + "stdout": result.stdout, + "stderr": result.stderr, + "returncode": result.returncode, + } + except subprocess.TimeoutExpired: + return {"error": "Command timed out after 30 seconds."} diff --git a/tests/unittests/tools/test_bash_tool.py b/tests/unittests/tools/test_bash_tool.py new file mode 100644 index 00000000..e35c32b6 --- /dev/null +++ b/tests/unittests/tools/test_bash_tool.py @@ -0,0 +1,229 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools import bash_tool +from google.adk.tools import tool_context +from google.adk.tools.tool_confirmation import ToolConfirmation +import pytest + + +@pytest.fixture +def workspace(tmp_path): + """Creates a workspace mirroring the anthropics/skills PDF skill layout.""" + # Skill: pdf/ + skill_dir = tmp_path / "pdf" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: pdf\n" + "description: Use this skill whenever the user wants to do" + " anything with PDF files.\n" + "---\n# PDF Processing Guide\n\n## Overview\n" + "This guide covers PDF processing operations." 
+ ) + scripts = skill_dir / "scripts" + scripts.mkdir() + (scripts / "extract_form_structure.py").write_text( + "import sys; print(f'extracting from {sys.argv[1]}')" + ) + (scripts / "fill_pdf_form_with_annotations.py").write_text( + "print('filling form')" + ) + references = skill_dir / "references" + references.mkdir() + (references / "REFERENCE.md").write_text("# Reference\nDetailed docs.") + # A loose file at workspace root (not inside a skill). + (tmp_path / "sample.pdf").write_bytes(b"%PDF-1.4 fake") + return tmp_path + + +@pytest.fixture +def tool_context_no_confirmation(): + """ToolContext with no confirmation (initial call).""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + ctx.tool_confirmation = None + ctx.actions = mock.MagicMock() + return ctx + + +@pytest.fixture +def tool_context_confirmed(): + """ToolContext with confirmation approved.""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + confirmation = mock.create_autospec(ToolConfirmation, instance=True) + confirmation.confirmed = True + ctx.tool_confirmation = confirmation + ctx.actions = mock.MagicMock() + return ctx + + +@pytest.fixture +def tool_context_rejected(): + """ToolContext with confirmation rejected.""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + confirmation = mock.create_autospec(ToolConfirmation, instance=True) + confirmation.confirmed = False + ctx.tool_confirmation = confirmation + ctx.actions = mock.MagicMock() + return ctx + + +# --- _validate_command tests --- + + +class TestValidateCommand: + + def test_empty_command(self): + policy = bash_tool.BashToolPolicy() + assert bash_tool._validate_command("", policy) is not None + assert bash_tool._validate_command(" ", policy) is not None + + def test_default_policy_allows_everything(self): + policy = bash_tool.BashToolPolicy() + assert bash_tool._validate_command("rm -rf /", policy) is None + assert bash_tool._validate_command("cat /etc/passwd", policy) is 
None + assert bash_tool._validate_command("sudo curl", policy) is None + + def test_restricted_policy_allows_prefixes(self): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat")) + assert bash_tool._validate_command("ls -la", policy) is None + assert bash_tool._validate_command("cat file.txt", policy) is None + + def test_restricted_policy_blocks_others(self): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat")) + assert bash_tool._validate_command("rm -rf .", policy) is not None + assert bash_tool._validate_command("tree", policy) is not None + assert "Permitted prefixes are: ls, cat" in bash_tool._validate_command( + "tree", policy + ) + + +class TestExecuteBashTool: + + @pytest.mark.asyncio + async def test_requests_confirmation( + self, workspace, tool_context_no_confirmation + ): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, + tool_context=tool_context_no_confirmation, + ) + assert "error" in result + assert "requires confirmation" in result["error"] + tool_context_no_confirmation.request_confirmation.assert_called_once() + + @pytest.mark.asyncio + async def test_rejected(self, workspace, tool_context_rejected): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, tool_context=tool_context_rejected + ) + assert result == {"error": "This tool call is rejected."} + + @pytest.mark.asyncio + async def test_executes_when_confirmed( + self, workspace, tool_context_confirmed + ): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, + tool_context=tool_context_confirmed, + ) + assert result["returncode"] == 0 + assert "pdf" in result["stdout"] + + @pytest.mark.asyncio + async def test_cat_skill_md(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + 
args={"command": "cat pdf/SKILL.md"}, + tool_context=tool_context_confirmed, + ) + assert "PDF Processing Guide" in result["stdout"] + + @pytest.mark.asyncio + async def test_python_script(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={ + "command": "python3 pdf/scripts/extract_form_structure.py test.pdf" + }, + tool_context=tool_context_confirmed, + ) + assert "extracting from test.pdf" in result["stdout"] + assert result["returncode"] == 0 + + @pytest.mark.asyncio + async def test_blocks_disallowed_by_policy( + self, workspace, tool_context_no_confirmation + ): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls",)) + tool = bash_tool.ExecuteBashTool(workspace=workspace, policy=policy) + result = await tool.run_async( + args={"command": "rm -rf ."}, + tool_context=tool_context_no_confirmation, + ) + assert "error" in result + assert "Permitted prefixes are: ls" in result["error"] + tool_context_no_confirmation.request_confirmation.assert_not_called() + + @pytest.mark.asyncio + async def test_captures_stderr(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'import sys; sys.stderr.write(\"err\")'"}, + tool_context=tool_context_confirmed, + ) + assert "err" in result["stderr"] + + @pytest.mark.asyncio + async def test_nonzero_returncode(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'exit(42)'"}, + tool_context=tool_context_confirmed, + ) + assert result["returncode"] == 42 + + @pytest.mark.asyncio + async def test_timeout(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + with mock.patch( + "google.adk.tools.bash_tool.subprocess.run", + 
side_effect=__import__("subprocess").TimeoutExpired("cmd", 30), + ): + result = await tool.run_async( + args={"command": "python scripts/do_thing.py"}, + tool_context=tool_context_confirmed, + ) + assert "error" in result + assert "timed out" in result["error"].lower() + + @pytest.mark.asyncio + async def test_cwd_is_workspace(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'import os; print(os.getcwd())'"}, + tool_context=tool_context_confirmed, + ) + assert result["stdout"].strip() == str(workspace) + + @pytest.mark.asyncio + async def test_no_command(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async(args={}, tool_context=tool_context_confirmed) + assert "error" in result + assert "required" in result["error"].lower() From 8ad8bc9b69e27f09ec67db14fc92d15c3c78bd53 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 26 Feb 2026 14:49:40 -0800 Subject: [PATCH 059/102] fix: Add a script to the sample skills agent Added a get_humidity script to demo the new RunSkillScript tool Co-authored-by: Kathy Wu PiperOrigin-RevId: 875901416 --- contributing/samples/skills_agent/agent.py | 8 ++++- .../skills/weather-skill/SKILL.md | 3 +- .../weather-skill/scripts/get_humidity.py | 29 +++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 9caf0ad7..9232545a 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -17,6 +17,7 @@ import pathlib from google.adk import Agent +from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor from google.adk.skills import load_skill_from_dir from 
google.adk.skills import models from google.adk.tools.skill_toolset import SkillToolset @@ -44,7 +45,12 @@ weather_skill = load_skill_from_dir( pathlib.Path(__file__).parent / "skills" / "weather-skill" ) -my_skill_toolset = SkillToolset(skills=[greeting_skill, weather_skill]) +# WARNING: UnsafeLocalCodeExecutor has security concerns and should NOT +# be used in production environments. +my_skill_toolset = SkillToolset( + skills=[greeting_skill, weather_skill], + code_executor=UnsafeLocalCodeExecutor(), +) root_agent = Agent( model="gemini-2.5-flash", diff --git a/contributing/samples/skills_agent/skills/weather-skill/SKILL.md b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md index 67d87105..6893ef67 100644 --- a/contributing/samples/skills_agent/skills/weather-skill/SKILL.md +++ b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md @@ -4,4 +4,5 @@ description: A skill that provides weather information based on reference data. --- Step 1: Check 'references/weather_info.md' for the current weather. -Step 2: Provide the weather update to the user. +Step 2: If humidity is requested, use run 'scripts/get_humidity.py' with the `location` argument. +Step 3: Provide the update to the user. diff --git a/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py b/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py new file mode 100644 index 00000000..a2e1dc47 --- /dev/null +++ b/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + + +def get_humidity(location: str) -> str: + """Fetch live humidity for a given location. (Simulated)""" + print(f"Fetching live humidity for {location}...") + return "45% (Simulated)" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--location", type=str, default="Mountain View") + args = parser.parse_args() + + print(get_humidity(args.location)) From 1206addd6e4a806caffb1a1ba1523adb92f72098 Mon Sep 17 00:00:00 2001 From: "Wei (Jack) Sun" Date: Thu, 26 Feb 2026 16:00:16 -0800 Subject: [PATCH 060/102] chore: merge release v1.26.0 to main Merge https://github.com/google/adk-python/pull/4637 Syncs version bump and CHANGELOG from release v1.26.0 to main. 
Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4637 from google:release/v1.26.0 427a983b18088bdc22272d02714393b0a779ecdf PiperOrigin-RevId: 875930970 --- .github/.release-please-manifest.json | 2 +- .github/release-please-config.json | 57 +++++++++++++---- CHANGELOG.md | 92 +++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 4 files changed, 140 insertions(+), 13 deletions(-) diff --git a/.github/.release-please-manifest.json b/.github/.release-please-manifest.json index 661ffa45..f97891a6 100644 --- a/.github/.release-please-manifest.json +++ b/.github/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.25.1" + ".": "1.26.0" } diff --git a/.github/release-please-config.json b/.github/release-please-config.json index 5395e5a4..053aab23 100644 --- a/.github/release-please-config.json +++ b/.github/release-please-config.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", - "last-release-sha": "9f7d5b3f1476234e552b783415527cc4bac55b39", + "last-release-sha": "8f5428150d18ed732b66379c0acb806a9121c3cb", "packages": { ".": { "release-type": "python", @@ -10,16 +10,51 @@ "skip-github-release": true, "changelog-path": "CHANGELOG.md", "changelog-sections": [ - {"type": "feat", "section": "Features"}, - {"type": "fix", "section": "Bug Fixes"}, - {"type": "perf", "section": "Performance Improvements"}, - {"type": "refactor", "section": "Code Refactoring"}, - {"type": "docs", "section": "Documentation"}, - {"type": "test", "section": "Tests", "hidden": true}, - {"type": "build", "section": "Build System", "hidden": true}, - {"type": "ci", "section": "CI/CD", "hidden": true}, - {"type": "style", "section": "Styles", "hidden": true}, - {"type": "chore", "section": "Miscellaneous Chores", "hidden": true} + { + "type": "feat", + "section": "Features" + }, + { + "type": "fix", + "section": "Bug Fixes" + }, + { + "type": "perf", + "section": 
"Performance Improvements" + }, + { + "type": "refactor", + "section": "Code Refactoring" + }, + { + "type": "docs", + "section": "Documentation" + }, + { + "type": "test", + "section": "Tests", + "hidden": true + }, + { + "type": "build", + "section": "Build System", + "hidden": true + }, + { + "type": "ci", + "section": "CI/CD", + "hidden": true + }, + { + "type": "style", + "section": "Styles", + "hidden": true + }, + { + "type": "chore", + "section": "Miscellaneous Chores", + "hidden": true + } ] } } diff --git a/CHANGELOG.md b/CHANGELOG.md index c4867801..92a8197b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,97 @@ # Changelog +## [1.26.0](https://github.com/google/adk-python/compare/v1.25.1...v1.26.0) (2026-02-26) + + +### Features + +* **[Core]** + * Add intra-invocation compaction and token compaction pre-request ([485fcb8](https://github.com/google/adk-python/commit/485fcb84e3ca351f83416c012edcafcec479c1db)) + * Use `--memory_service_uri` in ADK CLI run command ([a7b5097](https://github.com/google/adk-python/commit/a7b509763c1732f0363e90952bb4c2672572d542)) + +* **[Models]** + * Add `/chat/completions` integration to `ApigeeLlm` ([9c4c445](https://github.com/google/adk-python/commit/9c4c44536904f5cf3301a5abb910a5666344a8c5)) + * Add `/chat/completions` streaming support to Apigee LLM ([121d277](https://github.com/google/adk-python/commit/121d27741684685c564e484704ae949c5f0807b1)) + * Expand LiteLlm supported models and add registry tests ([d5332f4](https://github.com/google/adk-python/commit/d5332f44347f44d60360e14205a2342a0c990d66)) + +* **[Tools]** + * Add `load_skill_from_dir()` method ([9f7d5b3](https://github.com/google/adk-python/commit/9f7d5b3f1476234e552b783415527cc4bac55b39)) + * Agent Skills spec compliance — validation, aliases, scripts, and auto-injection ([223d9a7](https://github.com/google/adk-python/commit/223d9a7ff52d8da702f1f436bd22e94ad78bd5da)) + * BigQuery ADK support for search catalog tool 
([bef3f11](https://github.com/google/adk-python/commit/bef3f117b4842ce62760328304484cd26a1ec30a)) + * Make skill instruction optimizable and can adapt to user tasks ([21be6ad](https://github.com/google/adk-python/commit/21be6adcb86722a585b26f600c45c85e593b4ee0)) + * Pass trace context in MCP tool call's `_meta` field with OpenTelemetry propagator ([bcbfeba](https://github.com/google/adk-python/commit/bcbfeba953d46fca731b11542a00103cef374e57)) + +* **[Evals]** + * Introduce User Personas to the ADK evaluation framework ([6a808c6](https://github.com/google/adk-python/commit/6a808c60b38ad7140ddeb222887c6accc63edce9)) + +* **[Services]** + * Add generate/create modes for Vertex AI Memory Bank writes ([811e50a](https://github.com/google/adk-python/commit/811e50a0cbb181d502b9837711431ef78fca3f34)) + * Add support for memory consolidation via Vertex AI Memory Bank ([4a88804](https://github.com/google/adk-python/commit/4a88804ec7d17fb4031b238c362f27d240df0a13)) + +* **[A2A]** + * Add interceptor framework to `A2aAgentExecutor` ([87fcd77](https://github.com/google/adk-python/commit/87fcd77caa9672f219c12e5a0e2ff65cbbaaf6f3)) + +* **[Auth]** + * Add native support for `id_token` in OAuth2 credentials ([33f7d11](https://github.com/google/adk-python/commit/33f7d118b377b60f998c92944d2673679fddbc6e)) + * Support ID token exchange in `ServiceAccountCredentialExchanger` ([7be90db](https://github.com/google/adk-python/commit/7be90db24b41f1830e39ca3d7e15bf4dbfa5a304)), closes [#4458](https://github.com/google/adk-python/issues/4458) + +* **[Integrations]** + * Agent Registry in ADK ([abaa929](https://github.com/google/adk-python/commit/abaa92944c4cd43d206e2986d405d4ee07d45afe)) + * Add schema auto-upgrade, tool provenance, HITL tracing, and span hierarchy fix to BigQuery Agent Analytics plugin ([4260ef0](https://github.com/google/adk-python/commit/4260ef0c7c37ecdfea295fb0e1a933bb0df78bea)) + * Change default BigQuery table ID and update docstring 
([7557a92](https://github.com/google/adk-python/commit/7557a929398ec2a1f946500d906cef5a4f86b5d1)) + * Update Agent Registry to create AgentCard from info in get agents endpoint ([c33d614](https://github.com/google/adk-python/commit/c33d614004a47d1a74951dd13628fd2300aeb9ef)) + +* **[Web]** + * Enable dependency injection for agent loader in FastAPI app gen ([34da2d5](https://github.com/google/adk-python/commit/34da2d5b26e82f96f1951334fe974a0444843720)) + + +### Bug Fixes + +* Add OpenAI strict JSON schema enforcement in LiteLLM ([2dbd1f2](https://github.com/google/adk-python/commit/2dbd1f25bdb1d88a6873d824b81b3dd5243332a4)), closes [#4573](https://github.com/google/adk-python/issues/4573) +* Add push notification config store to agent_to_a2a ([4ca904f](https://github.com/google/adk-python/commit/4ca904f11113c4faa3e17bb4a9662dca1f936e2e)), closes [#4126](https://github.com/google/adk-python/issues/4126) +* Add support for injecting a custom google.genai.Client into Gemini models ([48105b4](https://github.com/google/adk-python/commit/48105b49c5ab8e4719a66e7219f731b2cd293b00)), closes [#2560](https://github.com/google/adk-python/issues/2560) +* Add support for injecting a custom google.genai.Client into Gemini models ([c615757](https://github.com/google/adk-python/commit/c615757ba12093ba4a2ba19bee3f498fef91584c)), closes [#2560](https://github.com/google/adk-python/issues/2560) +* Check both `input_stream` parameter name and its annotation to decide whether it's a streaming tool that accept input stream ([d56cb41](https://github.com/google/adk-python/commit/d56cb4142c5040b6e7d13beb09123b8a59341384)) +* **deps:** Increase pydantic lower version to 2.7.0 ([dbd6420](https://github.com/google/adk-python/commit/dbd64207aebea8c5af19830a9a02d4c05d1d9469)) +* edit copybara and BUILD config for new adk/integrations folder (added with Agent Registry) ([37d52b4](https://github.com/google/adk-python/commit/37d52b4caf6738437e62fe804103efe4bde363a1)) +* Expand add_memory to accept 
MemoryEntry ([f27a9cf](https://github.com/google/adk-python/commit/f27a9cfb87caecb8d52967c50637ed5ad541cd07)) +* Fix pickling lock errors in McpSessionManager ([4e2d615](https://github.com/google/adk-python/commit/4e2d6159ae3552954aaae295fef3e09118502898)) +* fix typo in PlanReActPlanner instruction ([6d53d80](https://github.com/google/adk-python/commit/6d53d800d5f6dc5d4a3a75300e34d5a9b0f006f5)) +* handle UnicodeDecodeError when loading skills in ADK ([3fbc27f](https://github.com/google/adk-python/commit/3fbc27fa4ddb58b2b69ee1bea1e3a7b2514bd725)) +* Improve BigQuery Agent Analytics plugin reliability and code quality ([ea03487](https://github.com/google/adk-python/commit/ea034877ec15eef1be8f9a4be9fcd95446a3dc21)) +* Include list of skills in every message and remove list_skills tool from system instruction ([4285f85](https://github.com/google/adk-python/commit/4285f852d54670390b19302ed38306bccc0a7cee)) +* Invoke on_tool_error_callback for missing tools in live mode ([e6b601a](https://github.com/google/adk-python/commit/e6b601a2ab71b7e2df0240fd55550dca1eba8397)) +* Keep query params embedded in OpenAPI paths when using httpx ([ffbcc0a](https://github.com/google/adk-python/commit/ffbcc0a626deb24fe38eab402b3d6ace484115df)), closes [#4555](https://github.com/google/adk-python/issues/4555) +* Only relay the LiveRequest after tools is invoked ([b53bc55](https://github.com/google/adk-python/commit/b53bc555cceaa11dc53b42c9ca1d650592fb4365)) +* Parallelize tool resolution in LlmAgent.canonical_tools() ([7478bda](https://github.com/google/adk-python/commit/7478bdaa9817b0285b4119e8c739d7520373f719)) +* race condition in table creation for `DatabaseSessionService` ([fbe9ecc](https://github.com/google/adk-python/commit/fbe9eccd05e628daa67059ba2e6a0d03966b240d)) +* Re-export DEFAULT_SKILL_SYSTEM_INSTRUCTION to skills and skill/prompt.py to avoid breaking current users ([40ec134](https://github.com/google/adk-python/commit/40ec1343c2708e1cf0d39cd8b8a96f3729f843de)) +* Refactor 
LiteLLM streaming response parsing for compatibility with LiteLLM 1.81+ ([e8019b1](https://github.com/google/adk-python/commit/e8019b1b1b0b43dcc5fa23075942b31db502ffdd)), closes [#4225](https://github.com/google/adk-python/issues/4225) +* remove duplicate session GET when using API server, unbreak auto_session_create when using API server ([445dc18](https://github.com/google/adk-python/commit/445dc189e915ce5198e822ad7fadd6bb0880a95e)) +* Remove experimental decorators from user persona data models ([eccdf6d](https://github.com/google/adk-python/commit/eccdf6d01e70c37a1e5aa47c40d74469580365d2)) +* Replace the global DEFAULT_USER_PERSONA_REGISTRY with a function call to get_default_persona_registry ([2703613](https://github.com/google/adk-python/commit/2703613572a38bf4f9e25569be2ee678dc91b5b5)) +* **skill:** coloate default skill SI with skilltoolset ([fc1f1db](https://github.com/google/adk-python/commit/fc1f1db00562a79cd6c742cfd00f6267295c29a8)) +* Update agent_engine_sandbox_code_executor in ADK ([ee8d956](https://github.com/google/adk-python/commit/ee8d956413473d1bbbb025a470ad882c1487d8b8)) +* Update agent_engine_sandbox_code_executor in ADK ([dab80e4](https://github.com/google/adk-python/commit/dab80e4a8f3c5476f731335724bff5df3e6f3650)) +* Update sample skills agent to use weather-skill instead of weather_skill ([8f54281](https://github.com/google/adk-python/commit/8f5428150d18ed732b66379c0acb806a9121c3cb)) +* update Spanner query tools to async functions ([1dbcecc](https://github.com/google/adk-python/commit/1dbceccf36c28d693b0982b531a99877a3e75169)) +* use correct msg_out/msg_err keys for Agent Engine sandbox output ([b1e33a9](https://github.com/google/adk-python/commit/b1e33a90b4ba716d717e0488b84892b8a7f42aac)) +* Validate session before streaming instead of eagerly advancing the runner generator ([ab32f33](https://github.com/google/adk-python/commit/ab32f33e7418d452e65cf6f5b6cbfe1371600323)) +* **web:** allow session resume without new message 
([30b2ed3](https://github.com/google/adk-python/commit/30b2ed3ef8ee6d3633743c0db00533683d3342d8)) + + +### Code Refactoring + +* Extract reusable function for building agent transfer instructions ([e1e0d63](https://github.com/google/adk-python/commit/e1e0d6361675e7b9a2c9b2523e3a72e2e5e7ce05)) +* Extract reusable private methods ([976a238](https://github.com/google/adk-python/commit/976a238544330528b4f9f4bea6c4e75ec13b33e1)) +* Extract reusable private methods ([42eeaef](https://github.com/google/adk-python/commit/42eeaef2b34c860f126c79c552435458614255ad)) +* Extract reusable private methods ([706f9fe](https://github.com/google/adk-python/commit/706f9fe74db0197e19790ca542d372ce46d0ae87)) + + +### Documentation + +* add `thinking_config` in `generate_content_config` in example agent ([c6b1c74](https://github.com/google/adk-python/commit/c6b1c74321faf62cc52d2518eb9ea0dcef050cde)) + ## [1.25.1](https://github.com/google/adk-python/compare/v1.25.0...v1.25.1) (2026-02-18) ### Bug Fixes diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 1ce0bf5e..2e373f50 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.25.1" +__version__ = "1.26.0" From dff4c4404051b711c8be437ba0ae26ca2763df7d Mon Sep 17 00:00:00 2001 From: Lusha Wang Date: Fri, 27 Feb 2026 15:58:54 -0800 Subject: [PATCH 061/102] fix: Update agent_engine_sandbox_code_executor in ADK 1. For prototyping and testing purposes, sandbox name can be provided, and it will be used for all requests across the lifecycle of an agent 2. If no sandbox name is provided, agent engine name will be provided, and we will automatically create one sandbox per session, and the sandbox has TTL set for a year. If the sandbox stored in the session hits the TTL, it will not be in "STATE_RUNNING" so a new sandbox will be created. 
Co-authored-by: Lusha Wang PiperOrigin-RevId: 876450610 --- .../agent_engine_code_execution/README | 4 +- .../agent_engine_code_execution/agent.py | 7 +- .../agent_engine_sandbox_code_executor.py | 61 ++++++-- ...test_agent_engine_sandbox_code_executor.py | 131 ++++++++++++++++++ 4 files changed, 183 insertions(+), 20 deletions(-) diff --git a/contributing/samples/agent_engine_code_execution/README b/contributing/samples/agent_engine_code_execution/README index 8d5a4442..b0443ae2 100644 --- a/contributing/samples/agent_engine_code_execution/README +++ b/contributing/samples/agent_engine_code_execution/README @@ -7,9 +7,9 @@ This sample data science agent uses Agent Engine Code Execution Sandbox to execu ## How to use -* 1. Follow https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/code-execution/overview to create a code execution sandbox environment. +* 1. Follow https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create-an-agent-engine-instance to create an agent engine instance. Replace the AGENT_ENGINE_RESOURCE_NAME with the one you just created. A new sandbox environment under this agent engine instance will be created for each session with TTL of 1 year. But sandbox can only maintain its state for up to 14 days. This is the recommended usage for production environments. -* 2. Replace the SANDBOX_RESOURCE_NAME with the one you just created. If you dont want to create a new sandbox environment directly, the Agent Engine Code Execution Sandbox will create one for you by default using the AGENT_ENGINE_RESOURCE_NAME you specified, however, please ensure to clean up sandboxes after use; otherwise, it will consume quotas. +* 2. For testing or prototyping purposes, create a sandbox environment by following this guide: https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create_a_sandbox. Replace the SANDBOX_RESOURCE_NAME with the one you just created.
This will be used as the default sandbox environment for all the code executions throughout the lifetime of the agent. As the sandbox is re-used across sessions, all sessions will share the same Python environment and variable values. ## Sample prompt diff --git a/contributing/samples/agent_engine_code_execution/agent.py b/contributing/samples/agent_engine_code_execution/agent.py index d85989eb..a32e4ca4 100644 --- a/contributing/samples/agent_engine_code_execution/agent.py +++ b/contributing/samples/agent_engine_code_execution/agent.py @@ -85,11 +85,10 @@ When plotting trends, you should make sure to sort and order the data by the x-a """, code_executor=AgentEngineSandboxCodeExecutor( - # Replace with your sandbox resource name if you already have one. - sandbox_resource_name="SANDBOX_RESOURCE_NAME", + # Replace with your sandbox resource name if you already have one. Only use it for testing or prototyping purposes, because this will use the same sandbox for all requests. # "projects/vertex-agent-loadtest/locations/us-central1/reasoningEngines/6842889780301135872/sandboxEnvironments/6545148628569161728", + sandbox_resource_name=None, + # Replace with agent engine resource name used for creating sandbox environment. + agent_engine_resource_name="AGENT_ENGINE_RESOURCE_NAME", ), ) diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 69d1778a..071d59dc 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -38,10 +38,15 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): sandbox_resource_name: If set, load the existing resource name of the code interpreter extension instead of creating a new one.
Format: projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789 + agent_engine_resource_name: The resource name of the agent engine to use + to create the code execution sandbox. Format: + projects/123/locations/us-central1/reasoningEngines/456 """ sandbox_resource_name: str = None + agent_engine_resource_name: str = None + def __init__( self, sandbox_resource_name: Optional[str] = None, @@ -67,30 +72,19 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): agent_engine_resource_name_pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' if sandbox_resource_name is not None: - self.sandbox_resource_name = sandbox_resource_name self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( sandbox_resource_name, sandbox_resource_name_pattern ) ) + self.sandbox_resource_name = sandbox_resource_name elif agent_engine_resource_name is not None: - from vertexai import types - self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( agent_engine_resource_name, agent_engine_resource_name_pattern ) ) - # @TODO - Add TTL for sandbox creation after it is available - # in SDK. - operation = self._get_api_client().agent_engines.sandboxes.create( - spec={'code_execution_environment': {}}, - name=agent_engine_resource_name, - config=types.CreateAgentEngineSandboxConfig( - display_name='default_sandbox' - ), - ) - self.sandbox_resource_name = operation.response.name + self.agent_engine_resource_name = agent_engine_resource_name else: raise ValueError( 'Either sandbox_resource_name or agent_engine_resource_name must be' @@ -103,6 +97,45 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: + # default to the sandbox resource name if set. 
+ sandbox_name = self.sandbox_resource_name + if self.sandbox_resource_name is None: + from google.api_core import exceptions + from vertexai import types + + # use sandbox name stored in session if available. + sandbox_name = invocation_context.session.state.get('sandbox_name', None) + create_new_sandbox = False + if sandbox_name is None: + create_new_sandbox = True + else: + # Check if the sandbox is still running OR already expired due to ttl. + try: + sandbox = self._get_api_client().agent_engines.sandboxes.get( + name=sandbox_name + ) + if sandbox is None or sandbox.state != 'STATE_RUNNING': + create_new_sandbox = True + except exceptions.NotFound: + create_new_sandbox = True + + if create_new_sandbox: + # Create a new sandbox and assign it to sandbox_name. + operation = self._get_api_client().agent_engines.sandboxes.create( + spec={'code_execution_environment': {}}, + name=self.agent_engine_resource_name, + config=types.CreateAgentEngineSandboxConfig( + # VertexAiSessionService has a default TTL of 1 year, so we set + # the sandbox TTL to 1 year as well. For the current code + # execution sandbox, if it hasn't been used for 14 days, the + # state will be lost. + display_name='default_sandbox', + ttl='31536000s', + ), + ) + sandbox_name = operation.response.name + invocation_context.session.state['sandbox_name'] = sandbox_name + # Execute the code. 
input_data = { 'code': code_execution_input.code, @@ -119,7 +152,7 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): code_execution_response = ( self._get_api_client().agent_engines.sandboxes.execute_code( - name=self.sandbox_resource_name, + name=sandbox_name, input_data=input_data, ) ) diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index 6022527f..9b27b82c 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -19,6 +19,7 @@ from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.code_executors.agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionInput +from google.adk.sessions.session import Session import pytest @@ -27,6 +28,10 @@ def mock_invocation_context() -> InvocationContext: """Fixture for a mock InvocationContext.""" mock = MagicMock(spec=InvocationContext) mock.invocation_id = "test-invocation-123" + session = MagicMock(spec=Session) + mock.session = session + session.state = {} + return mock @@ -118,3 +123,129 @@ class TestAgentEngineSandboxCodeExecutor: name="projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789", input_data={"code": 'print("hello world")'}, ) + + @patch("vertexai.Client") + def test_execute_code_recreates_sandbox_when_get_returns_none( + self, + mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Existing sandbox name stored in session, but get() will return None + existing_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/old" + mock_invocation_context.session.state = { + 
"sandbox_name": existing_sandbox_name + } + + # Mock get to return None (simulating missing/expired sandbox) + mock_api_client.agent_engines.sandboxes.get.return_value = None + + # Mock create operation to return a new sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "recreated sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Execute using agent_engine_resource_name so a sandbox can be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + ) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert get was called for the existing sandbox + mock_api_client.agent_engines.sandboxes.get.assert_called_once_with( + name=existing_sandbox_name + ) + + # Assert create was called and session updated with new sandbox + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) + + @patch("vertexai.Client") + def test_execute_code_creates_sandbox_if_missing( + self, + 
mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Mock create operation to return a sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "created sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Ensure session.state behaves like a dict for storing sandbox_name + mock_invocation_context.session.state = {} + + # Execute using agent_engine_resource_name so a sandbox will be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ), + sandbox_resource_name=None, + ) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert sandbox creation was called and session state updated + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + create_call_kwargs = ( + mock_api_client.agent_engines.sandboxes.create.call_args.kwargs + ) + assert create_call_kwargs["name"] == ( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + 
name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) From 8ddddc040ca10c75eca6752154773862069d9a1a Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 27 Feb 2026 17:11:46 -0800 Subject: [PATCH 062/102] chore: Use factory method to create invocation context in the runner Co-authored-by: Shangjie Chen PiperOrigin-RevId: 876474267 --- src/google/adk/runners.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 22011974..d6230752 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1381,6 +1381,10 @@ class Runner: return event.content return None + def _create_invocation_context(self, **kwargs) -> InvocationContext: + """Creates an InvocationContext instance.""" + return InvocationContext(**kwargs) + def _new_invocation_context( self, session: Session, @@ -1415,7 +1419,7 @@ class Runner: if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): self.agent.code_executor = BuiltInCodeExecutor() - return InvocationContext( + return self._create_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, From eb55eb7e7f0fa647d762205225c333dcd8a08dd0 Mon Sep 17 00:00:00 2001 From: nikkie Date: Sun, 1 Mar 2026 14:52:13 -0800 Subject: [PATCH 063/102] fix: typo in A2A EXPERIMENTAL warning Merge https://github.com/google/adk-python/pull/4462 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **2. Or, if no issue exists, describe the change:** **Problem:** I found a typo: arethemselves >/.../adk-python/src/google/adk/a2a/converters/event_converter.py:245: UserWarning: [EXPERIMENTAL] convert_a2a_message_to_event: ADK Implementation for A2A support (A2aAgentExecutor, RemoteA2aAgent and corresponding supporting components etc.) 
is in experimental mode and is subjected to breaking changes. A2A protocol and SDK arethemselves not experimental. Once it's stable enough the experimental mode will be removed. Your feedback is welcome. **Solution:** Just fix ### Testing Plan This is typo fix **Unit Tests:** N/A **Manual End-to-End (E2E) Tests:** N/A ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [ ] I have commented my code, particularly in hard-to-understand areas. - [ ] I have added tests that prove my fix is effective or that my feature works. - [ ] New and existing unit tests pass locally with my changes. - [ ] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4462 from ftnext:fix-typo-a2a-experimental-warning a117314eb524dd93351e17f2183eea080225e43e PiperOrigin-RevId: 877095788 --- src/google/adk/a2a/experimental.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/a2a/experimental.py b/src/google/adk/a2a/experimental.py index 77c31fde..7f331eb7 100644 --- a/src/google/adk/a2a/experimental.py +++ b/src/google/adk/a2a/experimental.py @@ -23,7 +23,7 @@ a2a_experimental = _make_feature_decorator( default_message=( "ADK Implementation for A2A support (A2aAgentExecutor, RemoteA2aAgent " "and corresponding supporting components etc.) is in experimental mode " - "and is subjected to breaking changes. A2A protocol and SDK are" + "and is subject to breaking changes. A2A protocol and SDK are " "themselves not experimental. Once it's stable enough the experimental " "mode will be removed. Your feedback is welcome." 
), From 991abd44e94324093e72530b462ac80385841ee9 Mon Sep 17 00:00:00 2001 From: nikkie Date: Sun, 1 Mar 2026 22:31:51 -0800 Subject: [PATCH 064/102] =?UTF-8?q?chore:=20escape=20Click's=20wrapping=20?= =?UTF-8?q?Escape=20Click=E2=80=99s=20Wrapping=20in=20`adk=20deploy=20agen?= =?UTF-8?q?t=5Fengine`=20example?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/4337 ### Link to Issue or Description of Change **2. Or, if no issue exists, describe the change:** **Problem:** `adk deploy agent_engine` help message loses newlines in docstring formatting スクリーンショット 2026-01-31 13 29 18 **Solution:** Add `\b` on a line by itself before the formatted block ### Testing Plan This is format improvement of help message, so I think there is no need to add test case. **Unit Tests:** - [ ] I have added or updated unit tests for my change. - [x] All unit tests pass locally. ``` % pytest tests/unittests/cli # Python 3.13.8 ======================= 260 passed, 140 warnings in 7.73s ======================== ``` **Manual End-to-End (E2E) Tests:** Ran `adk deploy agent_engine --help`, then saw ``` Example: # With Express Mode API Key adk deploy agent_engine --api_key=[api_key] my_agent # With Google Cloud Project and Region adk deploy agent_engine --project=[project] --region=[region] --display_name=[app_name] my_agent ``` ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. 
### Additional context Same solution as #4258 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4337 from ftnext:escape-wrapping-deploy-agent-engine-example 09440380dd5b1e14a48151eba1b808def8ae3b6a PiperOrigin-RevId: 877205878 --- src/google/adk/cli/cli_tools_click.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index f55a8f10..b817d4b4 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1967,9 +1967,11 @@ def cli_deploy_agent_engine( Example: + \b # With Express Mode API Key adk deploy agent_engine --api_key=[api_key] my_agent + \b # With Google Cloud Project and Region adk deploy agent_engine --project=[project] --region=[region] --display_name=[app_name] my_agent From 6a929af718fa77199d1eecc62b16c54beb1c8d84 Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 2 Mar 2026 09:25:00 -0800 Subject: [PATCH 065/102] fix: Prevent splitting of SSE events with artifactDelta for function resume requests When handling a /run_sse request that includes a functionCallEventId, do not split events that contain both content and artifactDelta Close #4487 Co-authored-by: George Weale PiperOrigin-RevId: 877435561 --- src/google/adk/cli/adk_web_server.py | 5 ++- tests/unittests/cli/test_fast_api.py | 51 ++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 469b33fe..afedb738 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -208,6 +208,8 @@ class RunAgentRequest(common.BaseModel): new_message: Optional[types.Content] = None streaming: bool = False state_delta: Optional[dict[str, Any]] = None + # for long-running function resume requests (e.g., OAuth callback) + function_call_event_id: Optional[str] = None # for resume long-running functions invocation_id: Optional[str] = 
None @@ -1707,7 +1709,8 @@ class AdkWebServer: # 2) a content-less "action-only" event carrying `artifactDelta` events_to_stream = [event] if ( - event.actions.artifact_delta + not req.function_call_event_id + and event.actions.artifact_delta and event.content and event.content.parts ): diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d6ccf6e2..0ea28e66 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1116,6 +1116,57 @@ def test_agent_run_sse_splits_artifact_delta( assert sse_events[1]["actions"]["artifactDelta"] == {"artifact.txt": 0} +def test_agent_run_sse_does_not_split_artifact_delta_for_function_resume( + test_app, create_test_session, monkeypatch +): + """Test /run_sse keeps artifactDelta with content for function resume flow.""" + info = create_test_session + + async def run_async_with_artifact_delta( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ): + del user_id, session_id, invocation_id, new_message, state_delta, run_config + yield Event( + author="dummy agent", + invocation_id="invocation_id", + content=types.Content( + role="model", parts=[types.Part(text="LLM reply")] + ), + actions=EventActions(artifact_delta={"artifact.txt": 0}), + ) + + monkeypatch.setattr(Runner, "run_async", run_async_with_artifact_delta) + + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello agent"}]}, + "streaming": True, + "functionCallEventId": "function-call-event-id", + } + + response = test_app.post("/run_sse", json=payload) + assert response.status_code == 200 + + sse_events = [ + json.loads(line.removeprefix("data: ")) + for line in response.text.splitlines() + if 
line.startswith("data: ") + ] + + assert len(sse_events) == 1 + assert sse_events[0]["content"]["parts"][0]["text"] == "LLM reply" + assert sse_events[0]["actions"]["artifactDelta"] == {"artifact.txt": 0} + + def test_agent_run_sse_yields_error_object_on_exception( test_app, create_test_session, monkeypatch ): From a61c7e388049ce647c41835fa5d123aab0c8208e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:21:08 -0800 Subject: [PATCH 066/102] chore(deps): bump flask from 3.1.1 to 3.1.3 in /contributing/samples/authn-adk-all-in-one in the pip group across 1 directory Merge https://github.com/google/adk-python/pull/4571 Bumps the pip group with 1 update in the /contributing/samples/authn-adk-all-in-one directory: [flask](https://github.com/pallets/flask). Updates `flask` from 3.1.1 to 3.1.3
Release notes

Sourced from flask's releases.

3.1.3

This is the Flask 3.1.3 security fix release, which fixes a security issue but does not otherwise change behavior and should not result in breaking changes compared to the latest feature release.

PyPI: https://pypi.org/project/Flask/3.1.3/ Changes: https://flask.palletsprojects.com/page/changes/#version-3-1-3

  • The session is marked as accessed for operations that only access the keys but not the values, such as in and len. GHSA-68rp-wp8r-4726

3.1.2

This is the Flask 3.1.2 fix release, which fixes bugs but does not otherwise change behavior and should not result in breaking changes compared to the latest feature release.

PyPI: https://pypi.org/project/Flask/3.1.2/ Changes: https://flask.palletsprojects.com/page/changes/#version-3-1-2 Milestone: https://github.com/pallets/flask/milestone/38?closed=1

  • stream_with_context does not fail inside async views. #5774
  • When using follow_redirects in the test client, the final state of session is correct. #5786
  • Relax type hint for passing bytes IO to send_file. #5776
Changelog

Sourced from flask's changelog.

Version 3.1.3

Released 2026-02-18

  • The session is marked as accessed for operations that only access the keys but not the values, such as in and len. :ghsa:68rp-wp8r-4726

Version 3.1.2

Released 2025-08-19

  • stream_with_context does not fail inside async views. :issue:5774
  • When using follow_redirects in the test client, the final state of session is correct. :issue:5786
  • Relax type hint for passing bytes IO to send_file. :issue:5776
Commits
  • 22d9247 release version 3.1.3
  • 089cb86 Merge commit from fork
  • c17f379 request context tracks session access
  • 27be933 start version 3.1.3
  • 4e652d3 Abort if the instance folder cannot be created (#5903)
  • 3d03098 Abort if the instance folder cannot be created
  • 407eb76 document using gevent for async (#5900)
  • ac5664d document using gevent for async
  • 4f79d5b Increase required flit_core version to 3.11 (#5865)
  • fe3b215 Increase required flit_core version to 3.11
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=flask&package-manager=pip&previous-version=3.1.1&new-version=3.1.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore major version` will close this group update PR and stop Dependabot creating any more for the specific dependency's major version (unless you unignore this specific dependency's major version or upgrade to it yourself) - `@dependabot ignore minor version` will close this group update PR and stop Dependabot creating any more for the specific dependency's minor version (unless you unignore this specific dependency's minor version or upgrade to it yourself) - `@dependabot ignore ` will close this group update PR and stop Dependabot creating any more for the specific dependency (unless you unignore this specific dependency or upgrade to it yourself) - `@dependabot unignore ` will remove all of the ignore conditions of the specified dependency - `@dependabot unignore ` will remove the ignore condition of the specified dependency and ignore conditions You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/google/adk-python/network/alerts).
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4571 from google:dependabot/pip/contributing/samples/authn-adk-all-in-one/pip-81068183c6 6f577df16cb78beec2d26d873761173811f6ef5b PiperOrigin-RevId: 877461954 --- contributing/samples/authn-adk-all-in-one/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing/samples/authn-adk-all-in-one/requirements.txt b/contributing/samples/authn-adk-all-in-one/requirements.txt index 6cd3c4bb..777d8d52 100644 --- a/contributing/samples/authn-adk-all-in-one/requirements.txt +++ b/contributing/samples/authn-adk-all-in-one/requirements.txt @@ -1,5 +1,5 @@ google-adk==1.12 -Flask==3.1.1 +Flask==3.1.3 flask-cors==6.0.1 python-dotenv==1.1.1 PyJWT[crypto]==2.10.1 From f9c104faf73e2a002bb3092b50fb88f4eed78163 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 2 Mar 2026 10:28:46 -0800 Subject: [PATCH 067/102] fix: Preserve thought_signature in FunctionCall conversions between GenAI and A2A Close: https://github.com/google/adk-python/issues/4311 Co-authored-by: Xuan Yang PiperOrigin-RevId: 877465519 --- .../adk/a2a/converters/part_converter.py | 33 ++- .../a2a/converters/test_part_converter.py | 204 +++++++++++++++++- 2 files changed, 229 insertions(+), 8 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 7b501f75..ce65a3de 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -104,10 +104,25 @@ def convert_a2a_part_to_genai_part( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ): + # Restore thought_signature if present + thought_signature = None + thought_sig_key = _get_adk_metadata_key('thought_signature') + if thought_sig_key in part.metadata: + sig_value = part.metadata[thought_sig_key] + if isinstance(sig_value, bytes): + thought_signature = sig_value + elif 
isinstance(sig_value, str): + try: + thought_signature = base64.b64decode(sig_value) + except Exception: + logger.warning( + 'Failed to decode thought_signature: %s', sig_value + ) return genai_types.Part( function_call=genai_types.FunctionCall.model_validate( part.data, by_alias=True - ) + ), + thought_signature=thought_signature, ) if ( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] @@ -214,16 +229,22 @@ def convert_genai_part_to_a2a_part( # TODO once A2A defined how to service such information, migrate below # logic accordingly if part.function_call: + fc_metadata = { + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + } + # Preserve thought_signature if present + if part.thought_signature is not None: + fc_metadata[_get_adk_metadata_key('thought_signature')] = ( + base64.b64encode(part.thought_signature).decode('utf-8') + ) return a2a_types.Part( root=a2a_types.DataPart( data=part.function_call.model_dump( by_alias=True, exclude_none=True ), - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - }, + metadata=fc_metadata, ) ) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index ec611fba..057b6c9e 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import json from unittest.mock import Mock from unittest.mock import patch @@ -74,7 +75,6 @@ class TestConvertA2aPartToGenaiPart: # Arrange test_bytes = b"test file content" # A2A FileWithBytes expects base64-encoded string - import base64 base64_encoded = base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( @@ -328,7 +328,6 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root, a2a_types.FilePart) assert isinstance(result.root.file, a2a_types.FileWithBytes) # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility - import base64 expected_base64 = base64.b64encode(test_bytes).decode("utf-8") assert result.root.file.bytes == expected_base64 @@ -841,3 +840,204 @@ class TestNewConstants: assert result.executable_code is not None assert result.executable_code.language == genai_types.Language.PYTHON assert result.executable_code.code == "print('Hello, World!')" + + +class TestThoughtSignaturePreservation: + """Tests for thought_signature preservation in function call conversions.""" + + def test_genai_function_call_with_thought_signature_to_a2a(self): + """Test that thought_signature is preserved when converting GenAI to A2A.""" + # Arrange + function_call = genai_types.FunctionCall( + id="fc_gemini3", + name="my_tool", + args={"document": "test content"}, + ) + genai_part = genai_types.Part( + function_call=function_call, + thought_signature=b"gemini3_signature_bytes", + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result.root, a2a_types.DataPart) + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + # thought_signature should be base64 encoded in metadata + thought_sig_key = _get_adk_metadata_key("thought_signature") + assert thought_sig_key in result.root.metadata + assert ( + 
base64.b64decode(result.root.metadata[thought_sig_key]) + == b"gemini3_signature_bytes" + ) + + def test_genai_function_call_without_thought_signature_to_a2a(self): + """Test function call without thought_signature doesn't add metadata key.""" + # Arrange + function_call = genai_types.FunctionCall( + id="fc_regular", + name="regular_tool", + args={}, + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result.root, a2a_types.DataPart) + # thought_signature key should not be present + thought_sig_key = _get_adk_metadata_key("thought_signature") + assert thought_sig_key not in result.root.metadata + + def test_a2a_function_call_with_thought_signature_to_genai(self): + """Test that thought_signature is restored when converting A2A to GenAI.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_gemini3", + "name": "my_tool", + "args": {"document": "test content"}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key("thought_signature"): ( + base64.b64encode(b"restored_signature").decode("utf-8") + ), + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "my_tool" + # thought_signature should be decoded back to bytes + assert result.thought_signature == b"restored_signature" + + def test_a2a_function_call_without_thought_signature_to_genai(self): + """Test function call without thought_signature returns None for it.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_regular", + "name": "regular_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): 
A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "regular_tool" + # thought_signature should be None + assert result.thought_signature is None + + def test_function_call_with_thought_signature_round_trip(self): + """Test thought_signature is preserved in GenAI -> A2A -> GenAI round trip.""" + # Arrange + original_signature = b"round_trip_signature_test" + function_call = genai_types.FunctionCall( + id="fc_round_trip", + name="round_trip_tool", + args={"key": "value"}, + ) + original_part = genai_types.Part( + function_call=function_call, + thought_signature=original_signature, + ) + + # Act - Convert GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(original_part) + restored_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert restored_part is not None + assert restored_part.function_call is not None + assert restored_part.function_call.name == "round_trip_tool" + assert restored_part.thought_signature == original_signature + + def test_a2a_function_call_with_bytes_thought_signature_to_genai(self): + """Test that bytes thought_signature is used directly without decoding.""" + # Arrange - metadata contains raw bytes (not base64 encoded) + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_bytes", + "name": "bytes_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key( + "thought_signature" + ): b"raw_bytes_signature", + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + # bytes should be used directly + assert result.thought_signature == b"raw_bytes_signature" + + def 
test_a2a_function_call_with_invalid_base64_thought_signature(self): + """Test that invalid base64 thought_signature logs warning and returns None.""" + # Arrange - metadata contains invalid base64 string + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_invalid", + "name": "invalid_sig_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key( + "thought_signature" + ): "not_valid_base64!!!", + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "invalid_sig_tool" + # thought_signature should be None due to decode failure + assert result.thought_signature is None From 89df5fcf883b599cf7bfe40bde35b8d86ab0146b Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 2 Mar 2026 10:39:54 -0800 Subject: [PATCH 068/102] feat: Enable output schema with tools for LiteLlm models LiteLlm provides built-in handling for tool and response format compatibility across different providers, allowing output schemas to be used reliably with tools for any LiteLlm instance Close #3969 Co-authored-by: George Weale PiperOrigin-RevId: 877471153 --- src/google/adk/utils/output_schema_utils.py | 11 +++++++++++ tests/unittests/utils/test_output_schema_utils.py | 6 ++++++ 2 files changed, 17 insertions(+) diff --git a/src/google/adk/utils/output_schema_utils.py b/src/google/adk/utils/output_schema_utils.py index 7c494f92..bb31d098 100644 --- a/src/google/adk/utils/output_schema_utils.py +++ b/src/google/adk/utils/output_schema_utils.py @@ -30,6 +30,17 @@ from .variant_utils import GoogleLLMVariant def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool: """Returns True if output schema with tools is supported.""" + # LiteLLM handles tools + response_format compatibility per-provider: + # - Providers 
with native support (OpenAI, Azure): both passed directly + # - Providers without (Fireworks): auto-converted to json_tool_call + + # tool_choice enforcement + # This is strictly more reliable than the SetModelResponseTool + # prompt-based workaround. + from ..models.lite_llm import LiteLlm + + if isinstance(model, LiteLlm): + return True + model_string = model if isinstance(model, str) else model.model return ( diff --git a/tests/unittests/utils/test_output_schema_utils.py b/tests/unittests/utils/test_output_schema_utils.py index fc2f6fb5..cf759c99 100644 --- a/tests/unittests/utils/test_output_schema_utils.py +++ b/tests/unittests/utils/test_output_schema_utils.py @@ -15,6 +15,7 @@ from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.utils.output_schema_utils import can_use_output_schema_with_tools import pytest @@ -37,6 +38,11 @@ import pytest (Claude(model="claude-3.7-sonnet"), "1", False), (Claude(model="claude-3.7-sonnet"), "0", False), (Claude(model="claude-3.7-sonnet"), None, False), + (LiteLlm(model="openai/gpt-4o"), "1", True), + (LiteLlm(model="openai/gpt-4o"), "0", True), + (LiteLlm(model="openai/gpt-4o"), None, True), + (LiteLlm(model="anthropic/claude-3.7-sonnet"), None, True), + (LiteLlm(model="fireworks_ai/llama-v3p1-70b"), None, True), ], ) def test_can_use_output_schema_with_tools( From 9c451662819a6c7de71be71d12ea715b2fe74135 Mon Sep 17 00:00:00 2001 From: Shruti Nair Date: Mon, 2 Mar 2026 10:50:18 -0800 Subject: [PATCH 069/102] feat: execute-type param addition in GkeCodeExecutor COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4111 from SHRUTI6991:execute-type/param-addition b1ec403e0927767d17c11cb9e894f6ccb4f08dd2 PiperOrigin-RevId: 877476098 --- pyproject.toml | 1 + .../adk/code_executors/gke_code_executor.py | 94 +++++++- .../code_executors/test_gke_code_executor.py | 224 +++++++++++++++++- 3 files changed, 309 
insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46064446..a3df7d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ extensions = [ "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor "kubernetes>=29.0.0", # For GkeCodeExecutor + "k8s-agent-sandbox>=0.1.1.post2", # For GkeCodeExecutor sandbox mode "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent "litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. diff --git a/src/google/adk/code_executors/gke_code_executor.py b/src/google/adk/code_executors/gke_code_executor.py index 1dc46878..b44aa193 100644 --- a/src/google/adk/code_executors/gke_code_executor.py +++ b/src/google/adk/code_executors/gke_code_executor.py @@ -19,12 +19,24 @@ import uuid import kubernetes as k8s from kubernetes.watch import Watch +from pydantic import field_validator +from typing_extensions import Literal +from typing_extensions import override +from typing_extensions import TYPE_CHECKING from ..agents.invocation_context import InvocationContext from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult +try: + from agentic_sandbox import SandboxClient +except ImportError: + SandboxClient = None + +if TYPE_CHECKING: + from agentic_sandbox import SandboxClient + # Expose these for tests to monkeypatch. client = k8s.client config = k8s.config @@ -36,9 +48,19 @@ logger = logging.getLogger("google_adk." + __name__) class GkeCodeExecutor(BaseCodeExecutor): """Executes Python code in a secure gVisor-sandboxed Pod on GKE. - This executor securely runs code by dynamically creating a Kubernetes Job for - each execution request. 
The user's code is mounted via a ConfigMap, and the - Pod is hardened with a strict security context and resource limits. + This executor supports two modes of execution: 'job' and 'sandbox'. + + Job Mode (default): + Securely runs code by dynamically creating a Kubernetes Job for each execution + request. The user's code is mounted via a ConfigMap, and the Pod is hardened + with a strict security context and resource limits. + + Sandbox Mode: + Executes code using the Agent Sandbox Client. This mode requires additional + infrastructure to be deployed in the cluster, specifically: + - Agent-sandbox controller + - Sandbox templates (e.g., python-sandbox-template) + - Sandbox router and gateway Key Features: - Sandboxed execution using the gVisor runtime. @@ -70,6 +92,7 @@ class GkeCodeExecutor(BaseCodeExecutor): namespace: str = "default" image: str = "python:3.11-slim" timeout_seconds: int = 300 + executor_type: Literal["job", "sandbox"] = "job" cpu_requested: str = "200m" mem_requested: str = "256Mi" # The maximum CPU the container can use, in "millicores". 1000m is 1 full CPU core. @@ -79,6 +102,10 @@ class GkeCodeExecutor(BaseCodeExecutor): kubeconfig_path: str | None = None kubeconfig_context: str | None = None + # Sandbox constants + sandbox_gateway_name: str | None = None + sandbox_template: str | None = "python-sandbox-template" + _batch_v1: k8s.client.BatchV1Api _core_v1: k8s.client.CoreV1Api @@ -136,10 +163,46 @@ class GkeCodeExecutor(BaseCodeExecutor): self._batch_v1 = client.BatchV1Api() self._core_v1 = client.CoreV1Api() - def execute_code( - self, - invocation_context: InvocationContext, - code_execution_input: CodeExecutionInput, + @field_validator("executor_type") + @classmethod + def _check_sandbox_dependency(cls, v: str) -> str: + if v == "sandbox" and SandboxClient is None: + raise ImportError( + "k8s-agent-sandbox not found. 
To use Agent Sandbox, please install" + " google-adk with the extensions extra: pip install" + " google-adk[extensions]" + ) + return v + + def _execute_in_sandbox(self, code: str) -> CodeExecutionResult: + """Executes code using Agent Sandbox Client.""" + try: + with SandboxClient( + template_name=self.sandbox_template, + gateway_name=self.sandbox_gateway_name, + namespace=self.namespace, + ) as sandbox: + # Execute the code as a python script + sandbox.write("script.py", code) + result = sandbox.run("python3 script.py") + + return CodeExecutionResult(stdout=result.stdout, stderr=result.stderr) + except RuntimeError as e: + logger.error( + "SandboxClient failed to initialize or find gateway", exc_info=True + ) + raise RuntimeError(f"Sandbox infrastructure error: {e}") from e + except TimeoutError as e: + logger.error("Sandbox timed out", exc_info=True) + # Returning a result instead of raising allows the Agent to process + # the error gracefully. + return CodeExecutionResult(stderr=f"Sandbox timed out: {e}") + except Exception as e: + logger.error("Sandbox execution failed: %s", e, exc_info=True) + raise + + def _execute_as_job( + self, code: str, invocation_context: InvocationContext ) -> CodeExecutionResult: """Orchestrates the secure execution of a code snippet on GKE.""" job_name = f"adk-exec-{uuid.uuid4().hex[:10]}" @@ -150,7 +213,7 @@ class GkeCodeExecutor(BaseCodeExecutor): # 1. Create a ConfigMap to mount LLM-generated code into the Pod. # 2. Create a Job that runs the code from the ConfigMap. # 3. Set the Job as the ConfigMap's owner for automatic cleanup. - self._create_code_configmap(configmap_name, code_execution_input.code) + self._create_code_configmap(configmap_name, code) job_manifest = self._create_job_manifest( job_name, configmap_name, invocation_context ) @@ -162,7 +225,6 @@ class GkeCodeExecutor(BaseCodeExecutor): logger.info( f"Submitted Job '{job_name}' to namespace '{self.namespace}'." 
) - logger.debug("Executing code:\n```\n%s\n```", code_execution_input.code) return self._watch_job_completion(job_name) except ApiException as e: @@ -186,6 +248,20 @@ class GkeCodeExecutor(BaseCodeExecutor): stderr=f"An unexpected executor error occurred: {e}" ) + @override + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Overrides the base method to route execution based on executor_type.""" + code = code_execution_input.code + if self.executor_type == "sandbox": + return self._execute_in_sandbox(code) + else: + # Fallback to existing GKE Job logic + return self._execute_as_job(code, invocation_context) + def _create_job_manifest( self, job_name: str, diff --git a/tests/unittests/code_executors/test_gke_code_executor.py b/tests/unittests/code_executors/test_gke_code_executor.py index 3d62fd8d..300780ca 100644 --- a/tests/unittests/code_executors/test_gke_code_executor.py +++ b/tests/unittests/code_executors/test_gke_code_executor.py @@ -71,19 +71,74 @@ class TestGkeCodeExecutor: assert executor.timeout_seconds == 300 assert executor.cpu_requested == "200m" assert executor.mem_limit == "512Mi" + assert executor.executor_type == "job" - def test_init_with_overrides(self): + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_init_with_overrides(self, mock_sandbox_client): """Tests that class attributes can be overridden at instantiation.""" executor = GkeCodeExecutor( namespace="test-ns", image="custom-python:latest", timeout_seconds=60, cpu_limit="1000m", + executor_type="sandbox", ) assert executor.namespace == "test-ns" assert executor.image == "custom-python:latest" assert executor.timeout_seconds == 60 assert executor.cpu_limit == "1000m" + assert executor.executor_type == "sandbox" + assert executor.sandbox_template == "python-sandbox-template" + + def test_init_backward_compatibility(self): + """Tests that the executor can be 
initialized with positional arguments.""" + executor = GkeCodeExecutor( + "/path/to/kubeconfig", + "test-context", + namespace="test-ns", + image="test-image", + timeout_seconds=100, + executor_type="job", + cpu_requested="100m", + mem_requested="128Mi", + cpu_limit="200m", + mem_limit="256Mi", + ) + assert executor.namespace == "test-ns" + assert executor.image == "test-image" + assert executor.timeout_seconds == 100 + assert executor.executor_type == "job" + assert executor.cpu_requested == "100m" + assert executor.mem_requested == "128Mi" + assert executor.cpu_limit == "200m" + assert executor.mem_limit == "256Mi" + assert executor.kubeconfig_path == "/path/to/kubeconfig" + assert executor.kubeconfig_context == "test-context" + + def test_init_partial_positional_args(self): + """Tests initialization with partial positional arguments.""" + executor = GkeCodeExecutor("/path/to/kubeconfig") + assert executor.kubeconfig_path == "/path/to/kubeconfig" + assert executor.kubeconfig_context is None + + def test_init_mixed_args(self): + """Tests initialization with mixed positional and keyword arguments.""" + executor = GkeCodeExecutor( + "/path/to/kubeconfig", + kubeconfig_context="test-context", + namespace="test-ns", + ) + assert executor.kubeconfig_path == "/path/to/kubeconfig" + + def test_init_sandbox_missing_dependency(self): + """Tests that init raises ImportError if k8s-agent-sandbox is missing.""" + with patch( + "google.adk.code_executors.gke_code_executor.SandboxClient", None + ): + with pytest.raises(ImportError, match="k8s-agent-sandbox not found"): + GkeCodeExecutor(executor_type="sandbox") + + GkeCodeExecutor(executor_type="sandbox") @patch("google.adk.code_executors.gke_code_executor.Watch") def test_execute_code_success( @@ -225,3 +280,170 @@ class TestGkeCodeExecutor: assert sec_context.allow_privilege_escalation is False assert sec_context.read_only_root_filesystem is True assert sec_context.capabilities.drop == ["ALL"] + + 
@patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_forks_to_sandbox( + self, + mock_sandbox_client, + mock_invocation_context, + mock_k8s_clients, + ): + """Tests execute_code with executor_type='sandbox'. + + Verifies that execute_code uses SandboxClient when executor_type is set to + 'sandbox'. + """ + # Setup Sandbox mock + mock_sandbox_instance = ( + mock_sandbox_client.return_value.__enter__.return_value + ) + mock_run_result = MagicMock() + mock_run_result.stdout = "sandbox stdout" + mock_run_result.stderr = None + mock_sandbox_instance.run.return_value = mock_run_result + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + # Execute + result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "sandbox stdout" + + # Verify SandboxClient was used + mock_sandbox_client.assert_called_once() + mock_sandbox_instance.run.assert_called_once() + + # Verify Job path was NOT taken + mock_k8s_clients["batch_v1"].create_namespaced_job.assert_not_called() + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_connection_error( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests handling of exceptions from SandboxClient.""" + # Setup Sandbox mock to raise exception + mock_sandbox_client.return_value.__enter__.side_effect = Exception( + "Connection failed" + ) + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + # Execute & Assert + with pytest.raises(Exception, match="Connection failed"): + executor.execute_code(mock_invocation_context, code_input) + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_runtime_error( + self, + mock_sandbox_client, + 
mock_invocation_context, + ): + """Tests handling of RuntimeError from SandboxClient.""" + mock_sandbox_client.return_value.__enter__.side_effect = RuntimeError( + "Gateway not found" + ) + + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + with pytest.raises( + RuntimeError, match="Sandbox infrastructure error: Gateway not found" + ): + executor.execute_code(mock_invocation_context, code_input) + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_timeout_error( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests handling of TimeoutError from SandboxClient.""" + mock_sandbox_client.return_value.__enter__.side_effect = TimeoutError( + "Execution timed out" + ) + + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + result = executor.execute_code(mock_invocation_context, code_input) + + assert result.stdout == "" + assert "Sandbox timed out: Execution timed out" in result.stderr + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + @patch("google.adk.code_executors.gke_code_executor.Watch") + def test_execute_code_forks_to_job( + self, + mock_watch, + mock_sandbox_client, + mock_invocation_context, + mock_k8s_clients, + ): + """Tests that execute_code uses K8s Job when executor_type='job'.""" + # Setup K8s Job mocks (success path) + mock_job = MagicMock() + mock_job.status.succeeded = True + mock_watch.return_value.stream.return_value = [{"object": mock_job}] + + mock_pod = MagicMock() + mock_pod.metadata.name = "pod-1" + mock_k8s_clients["core_v1"].list_namespaced_pod.return_value.items = [ + mock_pod + ] + mock_k8s_clients["core_v1"].read_namespaced_pod_log.return_value = ( + "job stdout" + ) + + # Instantiate with job type + executor = GkeCodeExecutor(executor_type="job") + code_input = CodeExecutionInput(code='print("job")') + + # Execute + 
result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "job stdout" + + # Verify Job path WAS taken + mock_k8s_clients["batch_v1"].create_namespaced_job.assert_called_once() + + # Verify SandboxClient was NOT used + mock_sandbox_client.assert_not_called() + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_in_sandbox_returns_stderr( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests that stderr from the sandbox run is propagated to the result.""" + # Setup Sandbox mock + mock_sandbox_instance = ( + mock_sandbox_client.return_value.__enter__.return_value + ) + mock_run_result = MagicMock() + mock_run_result.stdout = "" + mock_run_result.stderr = "oops\n" + mock_sandbox_instance.run.return_value = mock_run_result + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput( + code="import sys; print('oops', file=sys.stderr)" + ) + + # Execute + result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "" + assert result.stderr == "oops\n" + mock_sandbox_instance.write.assert_called_with("script.py", code_input.code) + mock_sandbox_instance.run.assert_called_with("python3 script.py") From 90f28deea5cb9d564d3a7d2443c2adc96d0165a1 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 2 Mar 2026 11:41:12 -0800 Subject: [PATCH 070/102] chore: Allow custom parameter names for ToolContext injection Co-authored-by: Xuan Yang PiperOrigin-RevId: 877502827 --- src/google/adk/flows/llm_flows/functions.py | 4 +- .../tools/_automatic_function_calling_util.py | 8 +- src/google/adk/tools/crewai_tool.py | 10 +- src/google/adk/tools/function_tool.py | 17 +-- src/google/adk/tools/mcp_tool/mcp_tool.py | 19 ++-- src/google/adk/utils/context_utils.py | 60 ++++++++++ .../unittests/tools/mcp_tool/test_mcp_tool.py | 43 +++++++ 
tests/unittests/tools/test_crewai_tool.py | 32 ++++++ tests/unittests/tools/test_function_tool.py | 89 +++++++++++++++ tests/unittests/utils/test_context_utils.py | 107 ++++++++++++++++++ 10 files changed, 364 insertions(+), 25 deletions(-) create mode 100644 tests/unittests/utils/test_context_utils.py diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6082e1a7..66274d3d 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -150,8 +150,8 @@ async def _call_tool_in_thread_pool( args_to_call = tool._preprocess_args(args) signature = inspect.signature(tool.func) valid_params = {param for param in signature.parameters} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if tool._context_param_name in valid_params: + args_to_call[tool._context_param_name] = tool_context args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index a3097ad4..392e256b 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -151,9 +151,13 @@ def _remove_title(schema: Dict): def _get_pydantic_schema(func: Callable) -> Dict: + from ..utils.context_utils import find_context_parameter + fields_dict = _get_fields_dict(func) - if 'tool_context' in fields_dict.keys(): - fields_dict.pop('tool_context') + # Remove context parameter (detected by type or fallback to 'tool_context' name) + context_param = find_context_parameter(func) or 'tool_context' + if context_param in fields_dict.keys(): + fields_dict.pop(context_param) return pydantic.create_model(func.__name__, **fields_dict).model_json_schema() diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index f8022e11..fca8ba9f 100644 --- 
a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -90,19 +90,19 @@ class CrewaiTool(FunctionTool): # remove arguments like `self` that are managed by the framework and not # intended to be passed through **kwargs. args_to_call.pop('self', None) - # We also remove `tool_context` that might have been passed in `args`, + # We also remove context param that might have been passed in `args`, # as it will be explicitly injected later if it's a valid parameter. - args_to_call.pop('tool_context', None) + args_to_call.pop(self._context_param_name, None) else: # For functions without **kwargs, use the original filtering. args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } - # Inject tool_context if it's an explicit parameter. This will add it + # Inject context if it's an explicit parameter. This will add it # or overwrite any value that might have been passed in `args`. - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if self._context_param_name in valid_params: + args_to_call[self._context_param_name] = tool_context # Check for missing mandatory arguments mandatory_args = self._get_mandatory_args() diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 6b8496dc..10e32a54 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -28,6 +28,7 @@ import pydantic from typing_extensions import override from ..utils.context_utils import Aclosing +from ..utils.context_utils import find_context_parameter from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool from .tool_context import ToolContext @@ -80,7 +81,9 @@ class FunctionTool(BaseTool): super().__init__(name=name, description=doc) self.func = func - self._ignore_params = ['tool_context', 'input_stream'] + # Detect context parameter by type annotation, fallback to 'tool_context' name + 
self._context_param_name = find_context_parameter(func) or 'tool_context' + self._ignore_params = [self._context_param_name, 'input_stream'] self._require_confirmation = require_confirmation @override @@ -162,8 +165,8 @@ class FunctionTool(BaseTool): signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if self._context_param_name in valid_params: + args_to_call[self._context_param_name] = tool_context # Filter args_to_call to only include valid parameters for the function args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} @@ -195,8 +198,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th if require_confirmation: if not tool_context.tool_confirmation: args_to_show = args_to_call.copy() - if 'tool_context' in args_to_show: - args_to_show.pop('tool_context') + if self._context_param_name in args_to_show: + args_to_show.pop(self._context_param_name) tool_context.request_confirmation( hint=( @@ -254,8 +257,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th args_to_call['input_stream'] = invocation_context.active_streaming_tools[ self.name ].stream - if 'tool_context' in signature.parameters: - args_to_call['tool_context'] = tool_context + if self._context_param_name in signature.parameters: + args_to_call[self._context_param_name] = tool_context # TODO: support tool confirmation for live mode. 
async with Aclosing(self.func(**args_to_call)) as agen: diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index f31768a0..bf279f52 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -41,6 +41,7 @@ from ...auth.auth_schemes import AuthScheme from ...auth.auth_tool import AuthConfig from ...features import FeatureName from ...features import is_feature_enabled +from ...utils.context_utils import find_context_parameter from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool from ..tool_context import ToolContext @@ -242,14 +243,18 @@ class McpTool(BaseAuthenticatedTool): for param in signature.parameters.values() ) - if "tool_context" in valid_params or has_kwargs: - args_to_call["tool_context"] = tool_context + # Detect context parameter by type or fallback to 'tool_context' name + context_param = ( + find_context_parameter(self._require_confirmation) or "tool_context" + ) + if context_param in valid_params or has_kwargs: + args_to_call[context_param] = tool_context # Filter args_to_call only if there's no **kwargs if not has_kwargs: - # Add tool_context to valid_params if it was added to args_to_call - if "tool_context" in args_to_call: - valid_params.add("tool_context") + # Add context param to valid_params if it was added to args_to_call + if context_param in args_to_call: + valid_params.add(context_param) args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } @@ -264,10 +269,6 @@ class McpTool(BaseAuthenticatedTool): if require_confirmation: if not tool_context.tool_confirmation: - args_to_show = args.copy() - if "tool_context" in args_to_show: - args_to_show.pop("tool_context") - tool_context.request_confirmation( hint=( f"Please approve or reject the tool call {self.name}() by" diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index 
b47180cd..cb68d800 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -21,6 +21,66 @@ Please do not rely on the implementation details. from __future__ import annotations from contextlib import aclosing +import inspect +from typing import Any +from typing import Callable +from typing import get_args +from typing import get_origin +from typing import Union # Re-export aclosing for backward compatibility Aclosing = aclosing + + +def _is_context_type(annotation: Any) -> bool: + """Check if an annotation is the Context type. + + This checks if the annotation is exactly Context or a type alias of Context + (e.g., ToolContext, CallbackContext). Also handles Optional[Context] types. + + Args: + annotation: The type annotation to check. + + Returns: + True if the annotation is the Context type, False otherwise. + """ + from ..agents.context import Context + + if annotation is inspect.Parameter.empty: + return False + + # Handle Optional[Context] and Union types + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + return any( + _is_context_type(arg) for arg in args if not isinstance(arg, type(None)) + ) + + # Check if it's exactly the Context type (or an alias like ToolContext) + return annotation is Context + + +def find_context_parameter(func: Callable[..., Any]) -> str | None: + """Find the parameter name that has a Context type annotation. + + This function inspects the signature of a callable and returns the name + of the first parameter that is annotated with Context or a type alias of + Context (e.g., ToolContext, CallbackContext). + + Args: + func: The callable to inspect. + + Returns: + The parameter name if found, None otherwise. 
+ """ + if func is None: + return None + try: + signature = inspect.signature(func) + except (ValueError, TypeError): + return None + for name, param in signature.parameters.items(): + if _is_context_type(param.annotation): + return name + return None diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index f38a8bbc..6d7fa0a3 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -17,6 +17,7 @@ from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from google.adk.agents.context import Context from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import HttpAuth @@ -972,3 +973,45 @@ class TestMCPTool: assert factory_calls[0][0] == "test_tool" # callback_context is the tool_context itself (ToolContext extends CallbackContext) assert factory_calls[0][1] is tool_context + + @pytest.mark.asyncio + async def test_run_async_require_confirmation_callable_with_context_type( + self, + ): + """Test require_confirmation callable with Context type annotation.""" + + async def _require_confirmation_func(param1: str, ctx: Context): + return True + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + require_confirmation=_require_confirmation_func, + ) + tool_context = Mock(spec=ToolContext) + tool_context.tool_confirmation = None + tool_context.request_confirmation = Mock() + args = {"param1": "test_value", "extra_arg": 123} + + with patch.object( + tool, "_invoke_callable", new_callable=AsyncMock + ) as mock_invoke_callable: + mock_invoke_callable.return_value = True + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Verify context is passed with detected parameter name 'ctx' + expected_args_to_call = { + "param1": "test_value", + 
"ctx": tool_context, + } + mock_invoke_callable.assert_called_once_with( + _require_confirmation_func, expected_args_to_call + ) + + assert result == { + "error": ( + "This tool call requires confirmation, please approve or reject." + ) + } + tool_context.request_confirmation.assert_called_once() diff --git a/tests/unittests/tools/test_crewai_tool.py b/tests/unittests/tools/test_crewai_tool.py index 9feb094b..a0028233 100644 --- a/tests/unittests/tools/test_crewai_tool.py +++ b/tests/unittests/tools/test_crewai_tool.py @@ -21,6 +21,7 @@ pytest.importorskip( "google.adk.tools.crewai_tool", reason="Requires Python 3.10+" ) +from google.adk.agents.context import Context from google.adk.agents.invocation_context import InvocationContext from google.adk.sessions.session import Session from google.adk.tools.crewai_tool import CrewaiTool @@ -52,6 +53,14 @@ def _crewai_tool_with_context(tool_context: ToolContext, *args, **kwargs): } +def _crewai_tool_with_context_type(ctx: Context, *args, **kwargs): + """CrewAI tool with Context type annotation.""" + return { + "search_query": kwargs.get("search_query"), + "context_present": bool(ctx), + } + + class MockCrewaiBaseTool: """Mock CrewAI BaseTool for testing.""" @@ -180,3 +189,26 @@ async def test_crewai_tool_get_declaration(): # Verify that the args_schema was used to build the declaration mock_crewai_tool.args_schema.model_json_schema.assert_called_once() + + +@pytest.mark.asyncio +async def test_crewai_tool_with_context_type_annotation(mock_tool_context): + """Test CrewaiTool with Context type annotation and custom parameter name.""" + mock_crewai_tool = MockCrewaiBaseTool(_crewai_tool_with_context_type) + tool = CrewaiTool( + mock_crewai_tool, + name="context_type_tool", + description="Context type tool", + ) + + # Verify the context parameter is detected by type + assert tool._context_param_name == "ctx" + + # Test that context is properly injected + result = await tool.run_async( + args={"search_query": "test query"}, + 
tool_context=mock_tool_context, + ) + + assert result["search_query"] == "test query" + assert result["context_present"] diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 9b1d1abd..9c76529f 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock +from google.adk.agents.context import Context from google.adk.agents.invocation_context import InvocationContext from google.adk.sessions.session import Session from google.adk.tools.function_tool import FunctionTool @@ -440,3 +441,91 @@ async def test_run_async_parameter_filtering(mock_tool_context): assert result == {"arg1": "test", "arg2": 42} # Explicitly verify that unexpected_param was filtered out and not passed to the function assert "unexpected_param" not in result + + +def test_context_param_detection_with_context_type(): + """Test that FunctionTool detects context parameter by Context type annotation.""" + + def my_tool(query: str, ctx: Context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "ctx" + assert tool._ignore_params == ["ctx", "input_stream"] + + +def test_context_param_detection_with_tool_context_type(): + """Test that FunctionTool detects context parameter by ToolContext type annotation.""" + + def my_tool(query: str, tool_context: ToolContext) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +def test_context_param_detection_with_custom_name(): + """Test that FunctionTool detects context parameter with any name if type is Context.""" + + def my_tool(query: str, my_custom_context: Context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "my_custom_context" + assert tool._ignore_params == ["my_custom_context", "input_stream"] 
+ + +def test_context_param_detection_fallback_to_name(): + """Test that FunctionTool falls back to 'tool_context' name when no type annotation.""" + + def my_tool(query: str, tool_context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +def test_context_param_detection_no_context(): + """Test that FunctionTool defaults to 'tool_context' when no context param exists.""" + + def my_tool(query: str, count: int) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +@pytest.mark.asyncio +async def test_run_async_with_custom_context_param_name(mock_tool_context): + """Test that run_async correctly injects context with custom parameter name.""" + + def my_tool(query: str, ctx: Context) -> dict: + return {"query": query, "has_context": ctx is not None} + + tool = FunctionTool(my_tool) + result = await tool.run_async( + args={"query": "test"}, + tool_context=mock_tool_context, + ) + + assert result == {"query": "test", "has_context": True} + + +@pytest.mark.asyncio +async def test_run_async_with_context_type_annotation(mock_tool_context): + """Test that run_async works with Context type annotation.""" + + async def async_tool(query: str, context: Context) -> dict: + return {"query": query, "context_type": type(context).__name__} + + tool = FunctionTool(async_tool) + result = await tool.run_async( + args={"query": "hello"}, + tool_context=mock_tool_context, + ) + + assert result["query"] == "hello" + assert result["context_type"] == "Context" diff --git a/tests/unittests/utils/test_context_utils.py b/tests/unittests/utils/test_context_utils.py new file mode 100644 index 00000000..a5e2d656 --- /dev/null +++ b/tests/unittests/utils/test_context_utils.py @@ -0,0 +1,107 @@ +# Copyright 2026 Google LLC +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for context_utils module.""" + +from typing import Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.context import Context +from google.adk.tools.tool_context import ToolContext +from google.adk.utils.context_utils import find_context_parameter + + +class TestFindContextParameter: + """Tests for find_context_parameter function.""" + + def test_find_context_parameter_with_context_type(self): + """Test detection of Context type annotation.""" + + def my_tool(query: str, ctx: Context) -> str: + return query + + assert find_context_parameter(my_tool) == 'ctx' + + def test_find_context_parameter_with_tool_context_type(self): + """Test detection of ToolContext type annotation.""" + + def my_tool(query: str, tool_context: ToolContext) -> str: + return query + + assert find_context_parameter(my_tool) == 'tool_context' + + def test_find_context_parameter_with_callback_context_type(self): + """Test detection of CallbackContext type annotation.""" + + def my_callback(ctx: CallbackContext) -> None: + pass + + assert find_context_parameter(my_callback) == 'ctx' + + def test_find_context_parameter_with_optional_context(self): + """Test detection of Optional[Context] type annotation.""" + + def my_tool(query: str, context: Optional[Context] = None) -> str: + return query + + assert find_context_parameter(my_tool) == 'context' + + def test_find_context_parameter_with_custom_name(self): + """Test that 
any parameter name works with Context type.""" + + def my_tool(query: str, my_custom_ctx: Context) -> str: + return query + + assert find_context_parameter(my_tool) == 'my_custom_ctx' + + def test_find_context_parameter_no_context(self): + """Test function without context parameter returns None.""" + + def my_tool(query: str, count: int) -> str: + return query + + assert find_context_parameter(my_tool) is None + + def test_find_context_parameter_no_annotations(self): + """Test function without type annotations returns None.""" + + def my_tool(query, ctx): + return query + + assert find_context_parameter(my_tool) is None + + def test_find_context_parameter_with_none_func(self): + """Test that None function returns None.""" + assert find_context_parameter(None) is None + + def test_find_context_parameter_returns_first_match(self): + """Test that first context parameter is returned if multiple exist.""" + + def my_tool(first_ctx: Context, second_ctx: Context) -> str: + return 'test' + + assert find_context_parameter(my_tool) == 'first_ctx' + + def test_find_context_parameter_with_mixed_params(self): + """Test context parameter detection with various other parameters.""" + + def my_tool( + query: str, + count: int, + ctx: Context, + optional_param: Optional[str] = None, + ) -> str: + return query + + assert find_context_parameter(my_tool) == 'ctx' From f324fa2d62442301ebb2e7974eb97ea870471410 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Mar 2026 11:50:54 -0800 Subject: [PATCH 071/102] fix: Propagate file names during A2A to/from Genai Part conversion This change updates the part_converter to ensure that the name field in a2a_types.FileWithUri and a2a_types.FileWithBytes is correctly mapped to the display_name field in genai_types.FileData and genai_types.Blob, respectively, during conversions between A2A and Genai Part types. Tests are updated to verify this propagation in both directions. 
PiperOrigin-RevId: 877507283 --- .../adk/a2a/converters/part_converter.py | 7 +++++- .../a2a/converters/test_part_converter.py | 22 +++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index ce65a3de..ef4a94fd 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -70,7 +70,9 @@ def convert_a2a_part_to_genai_part( if isinstance(part.file, a2a_types.FileWithUri): return genai_types.Part( file_data=genai_types.FileData( - file_uri=part.file.uri, mime_type=part.file.mime_type + file_uri=part.file.uri, + mime_type=part.file.mime_type, + display_name=part.file.name, ) ) @@ -79,6 +81,7 @@ def convert_a2a_part_to_genai_part( inline_data=genai_types.Blob( data=base64.b64decode(part.file.bytes), mime_type=part.file.mime_type, + display_name=part.file.name, ) ) else: @@ -188,6 +191,7 @@ def convert_genai_part_to_a2a_part( file=a2a_types.FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, + name=part.file_data.display_name, ) ) ) @@ -211,6 +215,7 @@ def convert_genai_part_to_a2a_part( file=a2a_types.FileWithBytes( bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), mime_type=part.inline_data.mime_type, + name=part.inline_data.display_name, ) ) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 057b6c9e..446e1185 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -55,7 +55,9 @@ class TestConvertA2aPartToGenaiPart: a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithUri( - uri="gs://bucket/file.txt", mime_type="text/plain" + uri="gs://bucket/file.txt", + mime_type="text/plain", + name="my_file.txt", ) ) ) @@ -69,6 +71,7 @@ class TestConvertA2aPartToGenaiPart: assert 
result.file_data is not None assert result.file_data.file_uri == "gs://bucket/file.txt" assert result.file_data.mime_type == "text/plain" + assert result.file_data.display_name == "my_file.txt" def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" @@ -80,7 +83,9 @@ class TestConvertA2aPartToGenaiPart: a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=base64_encoded, mime_type="text/plain" + bytes=base64_encoded, + mime_type="text/plain", + name="my_bytes.txt", ) ) ) @@ -95,6 +100,7 @@ class TestConvertA2aPartToGenaiPart: # The converter decodes base64 back to original bytes assert result.inline_data.data == test_bytes assert result.inline_data.mime_type == "text/plain" + assert result.inline_data.display_name == "my_bytes.txt" def test_convert_data_part_function_call(self): """Test conversion of A2A DataPart with function call metadata.""" @@ -296,7 +302,9 @@ class TestConvertGenaiPartToA2aPart: # Arrange genai_part = genai_types.Part( file_data=genai_types.FileData( - file_uri="gs://bucket/file.txt", mime_type="text/plain" + file_uri="gs://bucket/file.txt", + mime_type="text/plain", + display_name="my_file.txt", ) ) @@ -310,13 +318,18 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root.file, a2a_types.FileWithUri) assert result.root.file.uri == "gs://bucket/file.txt" assert result.root.file.mime_type == "text/plain" + assert result.root.file.name == "my_file.txt" def test_convert_inline_data_part(self): """Test conversion of GenAI inline_data Part to A2A Part.""" # Arrange test_bytes = b"test file content" genai_part = genai_types.Part( - inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + inline_data=genai_types.Blob( + data=test_bytes, + mime_type="text/plain", + display_name="my_bytes.txt", + ) ) # Act @@ -332,6 +345,7 @@ class TestConvertGenaiPartToA2aPart: expected_base64 = base64.b64encode(test_bytes).decode("utf-8") assert 
result.root.file.bytes == expected_base64 assert result.root.file.mime_type == "text/plain" + assert result.root.file.name == "my_bytes.txt" def test_convert_inline_data_part_with_video_metadata(self): """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" From b4bad26720ff6b5f6e7393e67611628b1c3f98a8 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 2 Mar 2026 13:32:48 -0800 Subject: [PATCH 072/102] chore: Update pydantic versions Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 877553194 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a3df7d63..0441c72d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dependencies = [ "opentelemetry-resourcedetector-gcp>=1.9.0a0, <2.0.0", "opentelemetry-sdk>=1.36.0, <1.39.0", "pyarrow>=14.0.0", - "pydantic>=2.7.0, <3.0.0", # For data validation/models + "pydantic>=2.12.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service "python-dotenv>=1.0.0, <2.0.0", # To manage environment variables "requests>=2.32.4, <3.0.0", From 72f3e7e1e00d93c632883027bf6d31a9095cd6c2 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Mar 2026 13:43:32 -0800 Subject: [PATCH 073/102] fix: update Bigtable query tools to async functions PiperOrigin-RevId: 877558234 --- src/google/adk/tools/bigtable/query_tool.py | 98 +++--- .../bigtable/test_bigtable_query_tool.py | 309 +++++++++--------- 2 files changed, 210 insertions(+), 197 deletions(-) diff --git a/src/google/adk/tools/bigtable/query_tool.py b/src/google/adk/tools/bigtable/query_tool.py index 63267f01..bf64b282 100644 --- a/src/google/adk/tools/bigtable/query_tool.py +++ b/src/google/adk/tools/bigtable/query_tool.py @@ -15,6 +15,7 @@ from __future__ import annotations """Tool to execute SQL queries against Bigtable.""" +import asyncio import json import logging from typing import Any @@ -32,7 +33,7 
@@ logger = logging.getLogger("google_adk." + __name__) DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50 -def execute_sql( +async def execute_sql( project_id: str, instance_id: str, query: str, @@ -65,7 +66,7 @@ def execute_sql( Examples: Fetch data or insights from a table: - >>> execute_sql("my_project", "my_instance", + >>> await execute_sql("my_project", "my_instance", ... "SELECT * from mytable", credentials, config, tool_context) { "status": "SUCCESS", @@ -80,51 +81,54 @@ def execute_sql( """ del tool_context # Unused for now - try: - bt_client = client.get_bigtable_data_client( - project=project_id, credentials=credentials - ) - eqi = bt_client.execute_query( - query=query, - instance_id=instance_id, - parameters=parameters, - parameter_types=parameter_types, - ) - - rows: List[Dict[str, Any]] = [] - max_rows = ( - settings.max_query_result_rows - if settings and settings.max_query_result_rows > 0 - else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS - ) - counter = max_rows - truncated = False + def _execute_sql(): try: - for row in eqi: - if counter <= 0: - truncated = True - break - row_values = {} - for key, val in dict(row.fields).items(): - try: - # if the json serialization of the value succeeds, use it as is - json.dumps(val) - except (TypeError, ValueError, OverflowError): - val = str(val) - row_values[key] = val - rows.append(row_values) - counter -= 1 - finally: - eqi.close() + bt_client = client.get_bigtable_data_client( + project=project_id, credentials=credentials + ) + eqi = bt_client.execute_query( + query=query, + instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, + ) - result = {"status": "SUCCESS", "rows": rows} - if truncated: - result["result_is_likely_truncated"] = True - return result + rows: List[Dict[str, Any]] = [] + max_rows = ( + settings.max_query_result_rows + if settings and settings.max_query_result_rows > 0 + else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS + ) + counter = max_rows + truncated = False + try: + 
for row in eqi: + if counter <= 0: + truncated = True + break + row_values = {} + for key, val in dict(row.fields).items(): + try: + # if the json serialization of the value succeeds, use it as is + json.dumps(val) + except (TypeError, ValueError, OverflowError): + val = str(val) + row_values[key] = val + rows.append(row_values) + counter -= 1 + finally: + eqi.close() - except Exception as ex: - logger.error("Bigtable query failed: %s", ex) - return { - "status": "ERROR", - "error_details": str(ex), - } + result = {"status": "SUCCESS", "rows": rows} + if truncated: + result["result_is_likely_truncated"] = True + return result + + except Exception as ex: + logger.exception("Bigtable query failed") + return { + "status": "ERROR", + "error_details": str(ex), + } + + return await asyncio.to_thread(_execute_sql) diff --git a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py index abcef88e..46b65a3a 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py @@ -14,136 +14,191 @@ from __future__ import annotations -from typing import Optional from unittest import mock -from google.adk.tools.base_tool import BaseTool -from google.adk.tools.bigtable import BigtableCredentialsConfig -from google.adk.tools.bigtable.bigtable_toolset import BigtableToolset +from google.adk.tools.bigtable import client from google.adk.tools.bigtable.query_tool import execute_sql from google.adk.tools.bigtable.settings import BigtableToolSettings from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials -from google.cloud import bigtable from google.cloud.bigtable.data.execute_query import ExecuteQueryIterator import pytest -def test_execute_sql_basic(): - """Test execute_sql tool basic functionality.""" +@pytest.mark.asyncio +@pytest.mark.parametrize( + ( + "query", + "settings", + "parameters", + 
"parameter_types", + "execute_query_side_effect", + "iterator_yield_values", + "expected_result", + ), + [ + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + None, + [{"col1": "val1", "col2": 123}], + {"status": "SUCCESS", "rows": [{"col1": "val1", "col2": 123}]}, + id="basic", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(max_query_result_rows=1), + None, + None, + None, + [{"col1": "val1"}, {"col1": "val2"}], + { + "status": "SUCCESS", + "rows": [{"col1": "val1"}], + "result_is_likely_truncated": True, + }, + id="truncated", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + Exception("Test error"), + None, + {"status": "ERROR", "error_details": "Test error"}, + id="error", + ), + pytest.param( + "SELECT * FROM my_table WHERE col1 = @param1", + BigtableToolSettings(), + {"param1": "val1"}, + {"param1": "string"}, + None, + [{"col1": "val1"}], + {"status": "SUCCESS", "rows": [{"col1": "val1"}]}, + id="with_parameters", + ), + pytest.param( + "SELECT * FROM my_table WHERE 1=0", + BigtableToolSettings(), + None, + None, + None, + [], + {"status": "SUCCESS", "rows": []}, + id="empty_results", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(max_query_result_rows=10), + None, + None, + None, + [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + { + "status": "SUCCESS", + "rows": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + }, + id="multiple_rows", + ), + pytest.param( + "SELECT * FROM my_table", + None, + None, + None, + None, + [{"id": i} for i in range(51)], + { + "status": "SUCCESS", + "rows": [{"id": i} for i in range(50)], + "result_is_likely_truncated": True, + }, + id="settings_none_uses_default", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + None, + Exception("Iteration failed"), + {"status": "ERROR", "error_details": "Iteration failed"}, + 
id="iteration_error_calls_close", + ), + ], +) +async def test_execute_sql( + query, + settings, + parameters, + parameter_types, + execute_query_side_effect, + iterator_yield_values, + expected_result, +): + """Test execute_sql tool functionality.""" project = "my_project" instance_id = "my_instance" - query = "SELECT * FROM my_table" credentials = mock.create_autospec(Credentials, instance=True) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: + with mock.patch.object(client, "get_bigtable_data_client") as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client - mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) - mock_client.execute_query.return_value = mock_iterator - # Mock row data - mock_row = mock.MagicMock() - mock_row.fields = {"col1": "val1", "col2": 123} - mock_iterator.__iter__.return_value = [mock_row] + if execute_query_side_effect: + mock_client.execute_query.side_effect = execute_query_side_effect + else: + mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) + mock_client.execute_query.return_value = mock_iterator - result = execute_sql( + if isinstance(iterator_yield_values, Exception): + + def raise_error(): + yield mock.MagicMock() + raise iterator_yield_values + + mock_iterator.__iter__.side_effect = raise_error + else: + mock_rows = [] + for fields in iterator_yield_values: + mock_row = mock.MagicMock() + mock_row.fields = fields + mock_rows.append(mock_row) + mock_iterator.__iter__.return_value = mock_rows + + result = await execute_sql( project_id=project, instance_id=instance_id, credentials=credentials, query=query, - settings=BigtableToolSettings(), + settings=settings, tool_context=tool_context, + parameters=parameters, + parameter_types=parameter_types, ) - expected_rows = [{"col1": "val1", "col2": 123}] - assert result == {"status": 
"SUCCESS", "rows": expected_rows} - mock_client.execute_query.assert_called_once_with( - query=query, - instance_id=instance_id, - parameters=None, - parameter_types=None, - ) - mock_iterator.close.assert_called_once() + if expected_result["status"] == "ERROR": + assert result["status"] == "ERROR" + assert expected_result["error_details"] in result["error_details"] + else: + assert result == expected_result + + if not execute_query_side_effect: + mock_client.execute_query.assert_called_once_with( + query=query, + instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, + ) + mock_iterator.close.assert_called_once() -def test_execute_sql_truncated(): - """Test execute_sql tool truncation functionality.""" - project = "my_project" - instance_id = "my_instance" - query = "SELECT * FROM my_table" - credentials = mock.create_autospec(Credentials, instance=True) - tool_context = mock.create_autospec(ToolContext, instance=True) - - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) - mock_client.execute_query.return_value = mock_iterator - - # Mock row data - mock_row1 = mock.MagicMock() - mock_row1.fields = {"col1": "val1"} - mock_row2 = mock.MagicMock() - mock_row2.fields = {"col1": "val2"} - mock_iterator.__iter__.return_value = [mock_row1, mock_row2] - - result = execute_sql( - project_id=project, - instance_id=instance_id, - credentials=credentials, - query=query, - settings=BigtableToolSettings(max_query_result_rows=1), - tool_context=tool_context, - ) - - expected_rows = [{"col1": "val1"}] - assert result == { - "status": "SUCCESS", - "rows": expected_rows, - "result_is_likely_truncated": True, - } - mock_client.execute_query.assert_called_once_with( - query=query, - instance_id=instance_id, - parameters=None, - parameter_types=None, - 
) - mock_iterator.close.assert_called_once() - - -def test_execute_sql_error(): - """Test execute_sql tool error handling.""" - project = "my_project" - instance_id = "my_instance" - query = "SELECT * FROM my_table" - credentials = mock.create_autospec(Credentials, instance=True) - tool_context = mock.create_autospec(ToolContext, instance=True) - - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_client.execute_query.side_effect = Exception("Test error") - - result = execute_sql( - project_id=project, - instance_id=instance_id, - credentials=credentials, - query=query, - settings=BigtableToolSettings(), - tool_context=tool_context, - ) - assert result == {"status": "ERROR", "error_details": "Test error"} - - -def test_execute_sql_row_value_circular_reference_fallback(): +@pytest.mark.asyncio +async def test_execute_sql_row_value_circular_reference_fallback(): """Test execute_sql converts circular row values to strings.""" project = "my_project" instance_id = "my_instance" @@ -151,9 +206,7 @@ def test_execute_sql_row_value_circular_reference_fallback(): credentials = mock.create_autospec(Credentials, instance=True) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: + with mock.patch.object(client, "get_bigtable_data_client") as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) @@ -164,7 +217,7 @@ def test_execute_sql_row_value_circular_reference_fallback(): mock_row.fields = {"col1": circular_value} mock_iterator.__iter__.return_value = [mock_row] - result = execute_sql( + result = await execute_sql( project_id=project, instance_id=instance_id, credentials=credentials, @@ -175,47 +228,3 @@ def 
test_execute_sql_row_value_circular_reference_fallback(): assert result["status"] == "SUCCESS" assert result["rows"][0]["col1"] == str(circular_value) - - -def test_execute_sql_with_parameters(): - """Test execute_sql tool with parameters and parameter_types.""" - project = "my_project" - instance_id = "my_instance" - query = "SELECT * FROM my_table WHERE col1 = @param1" - credentials = mock.create_autospec(Credentials, instance=True) - tool_context = mock.create_autospec(ToolContext, instance=True) - parameters = {"param1": "val1"} - parameter_types = {"param1": "string"} - - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) - mock_client.execute_query.return_value = mock_iterator - - # Mock row data - mock_row = mock.MagicMock() - mock_row.fields = {"col1": "val1"} - mock_iterator.__iter__.return_value = [mock_row] - - result = execute_sql( - project_id=project, - instance_id=instance_id, - credentials=credentials, - query=query, - settings=BigtableToolSettings(), - tool_context=tool_context, - parameters=parameters, - parameter_types=parameter_types, - ) - - assert result["status"] == "SUCCESS" - mock_client.execute_query.assert_called_once_with( - query=query, - instance_id=instance_id, - parameters=parameters, - parameter_types=parameter_types, - ) - mock_iterator.close.assert_called_once() From 8e79a12d6bcde43cc33247b7ee6cc9e929fa6288 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Mar 2026 13:59:18 -0800 Subject: [PATCH 074/102] fix: Make invocation_context optional in convert_event_to_a2a_message PiperOrigin-RevId: 877565033 --- src/google/adk/a2a/converters/event_converter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/google/adk/a2a/converters/event_converter.py 
b/src/google/adk/a2a/converters/event_converter.py index 59bbefa1..a2a0ee75 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -370,7 +370,7 @@ def convert_a2a_message_to_event( @a2a_experimental def convert_event_to_a2a_message( event: Event, - invocation_context: InvocationContext, + invocation_context: InvocationContext | None = None, role: Role = Role.agent, part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, ) -> Optional[Message]: @@ -390,8 +390,6 @@ def convert_event_to_a2a_message( """ if not event: raise ValueError("Event cannot be None") - if not invocation_context: - raise ValueError("Invocation context cannot be None") if not event.content or not event.content.parts: return None From c59afc21cbed27d1328872cdc2b0e182ab2ca6c8 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 2 Mar 2026 14:27:07 -0800 Subject: [PATCH 075/102] refactor: extract reusable functions from hitl and auth preprocessor Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 877578253 --- src/google/adk/auth/auth_preprocessor.py | 215 +++++++++++------- .../flows/llm_flows/request_confirmation.py | 189 ++++++++------- .../unittests/auth/test_auth_preprocessor.py | 7 +- 3 files changed, 235 insertions(+), 176 deletions(-) diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 37ad6745..76dd2dda 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator from typing_extensions import override @@ -25,6 +26,7 @@ from ..flows.llm_flows import functions from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ..models.llm_request import LlmRequest +from ..sessions.state import State from 
.auth_handler import AuthHandler from .auth_tool import AuthConfig from .auth_tool import AuthToolArguments @@ -35,6 +37,93 @@ from .auth_tool import AuthToolArguments TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' +async def _store_auth_and_collect_resume_targets( + events: list[Event], + auth_fc_ids: set[str], + auth_responses: dict[str, Any], + state: State, +) -> set[str]: + """Store auth credentials and return original function call IDs to resume. + + Scans session events for ``adk_request_credential`` function calls whose + IDs are in *auth_fc_ids*, extracts ``credential_key`` from their + ``AuthToolArguments`` args, merges ``credential_key`` into the + corresponding auth response, stores credentials via ``AuthHandler``, + and returns the set of original function call IDs that should be + re-executed (excluding toolset auth). + + Args: + events: Session events to scan. + auth_fc_ids: IDs of ``adk_request_credential`` function calls to match. + auth_responses: Mapping of FC ID -> auth config response dict from the + client. + state: Session state for temporary credential storage. + + Returns: + Set of original function call IDs to resume. + """ + # Step 1: Scan events for matching adk_request_credential function calls + # to extract AuthToolArguments (contains credential_key). + requested_auth_config_by_id: dict[str, AuthConfig] = {} + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + try: + for function_call in event_function_calls: + if ( + function_call.id in auth_fc_ids + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + requested_auth_config_by_id[function_call.id] = args.auth_config + except TypeError: + continue + + # Step 2: Store credentials. Merge credential_key from the original + # request into the client's auth response before storing. 
+ for fc_id in auth_fc_ids: + if fc_id not in auth_responses: + continue + auth_config = AuthConfig.model_validate(auth_responses[fc_id]) + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if ( + requested_auth_config + and requested_auth_config.credential_key is not None + ): + auth_config.credential_key = requested_auth_config.credential_key + await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( + state=state + ) + + # Step 3: Collect original function call IDs to resume, skipping + # toolset auth entries which don't map to a resumable function call. + tools_to_resume: set[str] = set() + for fc_id in auth_fc_ids: + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if not requested_auth_config: + continue + # Re-parse to get function_call_id (AuthConfig doesn't carry it; + # AuthToolArguments does). + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + for function_call in event_function_calls: + if ( + function_call.id == fc_id + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + if args.function_call_id.startswith( + TOOLSET_AUTH_CREDENTIAL_ID_PREFIX + ): + continue + tools_to_resume.add(args.function_call_id) + + return tools_to_resume + + class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): """Handles auth information to build the LLM request.""" @@ -49,8 +138,8 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_euc_function_call_ids = set() - # find the last event with non-None content + # Find the last user-authored event with function responses to + # identify adk_request_credential responses. 
last_event_with_content = None for i in range(len(events) - 1, -1, -1): event = events[i] @@ -58,7 +147,6 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): last_event_with_content = event break - # check if the last event with content is authored by user if not last_event_with_content or last_event_with_content.author != 'user': return @@ -66,104 +154,55 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not responses: return - requested_auth_config_by_request_id = {} - # look for auth response + # Collect adk_request_credential function response IDs and their + # response dicts. + auth_fc_ids: set[str] = set() + auth_responses: dict[str, Any] = {} for function_call_response in responses: if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: continue - # found the function call response for the system long running request euc - # function call - request_euc_function_call_ids.add(function_call_response.id) - - if request_euc_function_call_ids: - for event in events: - function_calls = event.get_function_calls() - if not function_calls: - continue - try: - for function_call in function_calls: - if ( - function_call.id in request_euc_function_call_ids - and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): - args = AuthToolArguments.model_validate(function_call.args) - requested_auth_config_by_request_id[function_call.id] = ( - args.auth_config - ) - except TypeError: - continue - - for function_call_response in responses: - if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: - continue - - auth_config = AuthConfig.model_validate(function_call_response.response) - requested_auth_config = requested_auth_config_by_request_id.get( - function_call_response.id - ) - if ( - requested_auth_config - and requested_auth_config.credential_key is not None - ): - auth_config.credential_key = requested_auth_config.credential_key - await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - 
state=invocation_context.session.state + auth_fc_ids.add(function_call_response.id) + auth_responses[function_call_response.id] = ( + function_call_response.response ) - if not request_euc_function_call_ids: + if not auth_fc_ids: return + # Store credentials and collect tools to resume. + tools_to_resume = await _store_auth_and_collect_resume_targets( + events, auth_fc_ids, auth_responses, invocation_context.session.state + ) + + if not tools_to_resume: + return + + # Find the original function call event and re-execute the tools + # that needed auth. for i in range(len(events) - 2, -1, -1): event = events[i] - # looking for the system long running request euc function call function_calls = event.get_function_calls() if not function_calls: continue - tools_to_resume = set() - - for function_call in function_calls: - if function_call.id not in request_euc_function_call_ids: - continue - args = AuthToolArguments.model_validate(function_call.args) - - # Skip toolset auth - auth response is already stored in session state - # and we don't need to resume a function call for toolsets - if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX): - continue - - tools_to_resume.add(args.function_call_id) - if not tools_to_resume: - continue - - # found the system long running request euc function call - # looking for original function call that requests euc - for j in range(i - 1, -1, -1): - event = events[j] - function_calls = event.get_function_calls() - if not function_calls: - continue - - if any([ - function_call.id in tools_to_resume - for function_call in function_calls - ]): - if function_response_event := await functions.handle_function_calls_async( - invocation_context, - event, - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # there could be parallel function calls that require auth - # auth response would be a dict keyed by function call id - tools_to_resume, - ): - yield 
function_response_event - return - return + if any([ + function_call.id in tools_to_resume + for function_call in function_calls + ]): + if function_response_event := await functions.handle_function_calls_async( + invocation_context, + event, + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume, + ): + yield function_response_event + return + return request_processor = _AuthLlmRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index f7b7f7f6..d066db79 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -15,6 +15,7 @@ from __future__ import annotations import json import logging +from typing import Any from typing import AsyncGenerator from typing import TYPE_CHECKING @@ -37,6 +38,65 @@ if TYPE_CHECKING: logger = logging.getLogger('google_adk.' + __name__) +def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation: + """Parse ToolConfirmation from a function response dict. + + Handles both the direct dict format and the ADK client's + ``{'response': json_string}`` wrapper format. + + """ + if response and len(response.values()) == 1 and 'response' in response.keys(): + return ToolConfirmation.model_validate(json.loads(response['response'])) + return ToolConfirmation.model_validate(response) + + +def _resolve_confirmation_targets( + events: list[Event], + confirmation_fc_ids: set[str], + confirmations_by_fc_id: dict[str, ToolConfirmation], +) -> tuple[dict[str, ToolConfirmation], dict[str, types.FunctionCall]]: + """Find original function calls for confirmed tools. + + Scans events for ``adk_request_confirmation`` function calls whose IDs + are in *confirmation_fc_ids*, extracts the ``originalFunctionCall`` from + their args, and maps each confirmation to the original FC ID. 
+ + Args: + events: Session events to scan. + confirmation_fc_ids: IDs of ``adk_request_confirmation`` function calls. + confirmations_by_fc_id: Mapping of confirmation FC ID -> + ``ToolConfirmation``. + + Returns: + Tuple of ``(tool_confirmation_dict, original_fcs_dict)`` where both + are keyed by the ORIGINAL function call IDs. + """ + tool_confirmation_dict: dict[str, ToolConfirmation] = {} + original_fcs_dict: dict[str, types.FunctionCall] = {} + + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + + for function_call in event_function_calls: + if function_call.id not in confirmation_fc_ids: + continue + + args = function_call.args + if 'originalFunctionCall' not in args: + continue + original_function_call = types.FunctionCall( + **args['originalFunctionCall'] + ) + tool_confirmation_dict[original_function_call.id] = ( + confirmations_by_fc_id[function_call.id] + ) + original_fcs_dict[original_function_call.id] = original_function_call + + return tool_confirmation_dict, original_fcs_dict + + class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): """Handles tool confirmation information to build the LLM request.""" @@ -53,14 +113,12 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_confirmation_function_responses = ( - dict() - ) # {function call id, tool confirmation} - + # Step 1: Find the last user-authored event and parse confirmation + # responses from it. 
+ confirmations_by_fc_id: dict[str, ToolConfirmation] = {} confirmation_event_index = -1 for k in range(len(events) - 1, -1, -1): event = events[k] - # Find the first event authored by user if not event.author or event.author != 'user': continue responses = event.get_function_responses() @@ -70,101 +128,58 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): for function_response in responses: if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: continue - - # Find the FunctionResponse event that contains the user provided tool - # confirmation - if ( + confirmations_by_fc_id[function_response.id] = _parse_tool_confirmation( function_response.response - and len(function_response.response.values()) == 1 - and 'response' in function_response.response.keys() - ): - # ADK client must send a resuming run request with a function response - # that always encapsulate the confirmation result with a 'response' - # key - tool_confirmation = ToolConfirmation.model_validate( - json.loads(function_response.response['response']) - ) - else: - tool_confirmation = ToolConfirmation.model_validate( - function_response.response - ) - request_confirmation_function_responses[function_response.id] = ( - tool_confirmation ) confirmation_event_index = k break - if not request_confirmation_function_responses: + if not confirmations_by_fc_id: return - for i in range(len(events) - 2, -1, -1): + # Step 2: Resolve confirmation targets using extracted helper. + confirmation_fc_ids = set(confirmations_by_fc_id.keys()) + tools_to_resume_with_confirmation, tools_to_resume_with_args = ( + _resolve_confirmation_targets( + events, confirmation_fc_ids, confirmations_by_fc_id + ) + ) + + if not tools_to_resume_with_confirmation: + return + + # Step 3: Remove tools that have already been confirmed (dedup). 
+ for i in range(len(events) - 1, confirmation_event_index, -1): event = events[i] - # Find the system generated FunctionCall event requesting the tool - # confirmation - function_calls = event.get_function_calls() - if not function_calls: + fr_list = event.get_function_responses() + if not fr_list: continue - tools_to_resume_with_confirmation = ( - dict() - ) # {Function call id, tool confirmation} - tools_to_resume_with_args = dict() # {Function call id, function calls} - - for function_call in function_calls: - if ( - function_call.id - not in request_confirmation_function_responses.keys() - ): - continue - - args = function_call.args - if 'originalFunctionCall' not in args: - continue - original_function_call = types.FunctionCall( - **args['originalFunctionCall'] - ) - tools_to_resume_with_confirmation[original_function_call.id] = ( - request_confirmation_function_responses[function_call.id] - ) - tools_to_resume_with_args[original_function_call.id] = ( - original_function_call - ) + for function_response in fr_list: + if function_response.id in tools_to_resume_with_confirmation: + tools_to_resume_with_confirmation.pop(function_response.id) + tools_to_resume_with_args.pop(function_response.id) if not tools_to_resume_with_confirmation: - continue + break - # Remove the tools that have already been confirmed. 
- for i in range(len(events) - 1, confirmation_event_index, -1): - event = events[i] - function_response = event.get_function_responses() - if not function_response: - continue - - for function_response in event.get_function_responses(): - if function_response.id in tools_to_resume_with_confirmation: - tools_to_resume_with_confirmation.pop(function_response.id) - tools_to_resume_with_args.pop(function_response.id) - if not tools_to_resume_with_confirmation: - break - - if not tools_to_resume_with_confirmation: - continue - - if function_response_event := await functions.handle_function_call_list_async( - invocation_context, - tools_to_resume_with_args.values(), - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # There could be parallel function calls that require input - # response would be a dict keyed by function call id - tools_to_resume_with_confirmation.keys(), - tools_to_resume_with_confirmation, - ): - yield function_response_event + if not tools_to_resume_with_confirmation: return + # Step 4: Re-execute the confirmed tools. 
+ if function_response_event := await functions.handle_function_call_list_async( + invocation_context, + tools_to_resume_with_args.values(), + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume_with_confirmation.keys(), + tools_to_resume_with_confirmation, + ): + yield function_response_event + return + request_processor = _RequestConfirmationLlmRequestProcessor() diff --git a/tests/unittests/auth/test_auth_preprocessor.py b/tests/unittests/auth/test_auth_preprocessor.py index 04a64fc5..fb45cc34 100644 --- a/tests/unittests/auth/test_auth_preprocessor.py +++ b/tests/unittests/auth/test_auth_preprocessor.py @@ -79,7 +79,9 @@ class TestAuthLlmRequestProcessor: @pytest.fixture def mock_auth_config(self): """Create a mock AuthConfig.""" - return Mock(spec=AuthConfig) + config = Mock(spec=AuthConfig) + config.credential_key = None + return config @pytest.fixture def mock_function_response_with_auth(self, mock_auth_config): @@ -347,10 +349,12 @@ class TestAuthLlmRequestProcessor: auth_response_1, auth_response_2, ] + user_event_with_multiple_responses.get_function_calls.return_value = [] # Create system function call events system_function_call_1 = Mock() system_function_call_1.id = 'auth_id_1' + system_function_call_1.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_1.args = { 'function_call_id': 'tool_id_1', 'auth_config': mock_auth_config, @@ -358,6 +362,7 @@ class TestAuthLlmRequestProcessor: system_function_call_2 = Mock() system_function_call_2.id = 'auth_id_2' + system_function_call_2.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_2.args = { 'function_call_id': 'tool_id_2', 'auth_config': mock_auth_config, From 80c5a245557cd75870e72bff0ecfaafbd37fdbc7 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Mon, 2 Mar 2026 14:31:55 -0800 Subject: [PATCH 076/102] feat: Enhance BQ plugin with fork safety, auto views, and trace continuity - **Fork-safety (#4636):** Adds PID 
tracking to `BigQueryAgentAnalyticsPlugin` so that forked child processes detect stale gRPC channels and re-initialize instead of deadlocking. Uses the standard `os.getpid()` pattern (same as SQLAlchemy, gRPC-python). - **Auto-create views (#4639):** After ensuring the `agent_events` table exists, automatically creates 15 per-event-type BigQuery views (`v_llm_request`, `v_tool_completed`, etc.) that unnest JSON columns into typed, queryable columns. Controlled by `BigQueryLoggerConfig.create_views` (default `True`), idempotent via `CREATE OR REPLACE VIEW`. - **Trace-ID continuity & o11y alignment (#4645):** Fixes trace_id fracture between early events (USER_MESSAGE_RECEIVED, INVOCATION_STARTING) and later events (AGENT_STARTING onwards) when no ambient OTel span exists. Also aligns BQ rows with Cloud Trace span IDs when o11y is active. - **Span-ID consistency under ambient OTel (#4640 review):** Fixes `*_STARTING` / `*_COMPLETED` events producing mismatched span IDs when an ambient OTel span is active. Completion callbacks now check for ambient spans and defer to `_resolve_ids` Layer 2 instead of overriding with plugin-synthetic IDs. - **Stack leak safety (#4640 review):** Adds `TraceManager.clear_stack()` and makes `ensure_invocation_span()` clear stale records from *different* invocations, preventing span stack leaks across invocations. Uses `_active_invocation_id_ctx` to distinguish stale leak vs same-invocation re-entry. - **Root agent name staleness fix:** `init_trace()` now refreshes `_root_agent_name_ctx` unconditionally on each invocation (previously set-once-on-None). `after_run_callback` resets it alongside other invocation cleanup. - **Exception-safe cleanup:** `after_run_callback` uses `try/finally` to guarantee invocation state (`clear_stack`, `_active_invocation_id_ctx`, `_root_agent_name_ctx`) is always reset, even if `_log_event` raises. 
- **`on_tool_error_callback` span fix:** Previously discarded the span_id from `pop_span()`, causing TOOL_ERROR events to get the wrong span. Now captures and uses the popped span_id. Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 877580395 --- .../bigquery_agent_analytics_plugin.py | 436 ++++- .../test_bigquery_agent_analytics_plugin.py | 1520 ++++++++++++++++- 2 files changed, 1852 insertions(+), 104 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 70a17f40..0f43de6b 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -27,6 +27,7 @@ import functools import json import logging import mimetypes +import os import random import time from types import MappingProxyType @@ -498,16 +499,35 @@ class BigQueryLoggerConfig: # dropped or altered). Safe to leave enabled; a version label on the # table ensures the diff runs at most once per schema version. auto_schema_upgrade: bool = True + # Automatically create per-event-type BigQuery views that unnest + # JSON columns into typed, queryable columns. + create_views: bool = True # ============================================================================== # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== +# NOTE: These contextvars are module-global, not plugin-instance-scoped. +# This is safe in practice for two reasons: +# 1. PluginManager enforces name-uniqueness, preventing two BQ plugin +# instances on the same Runner. +# 2. Concurrent asyncio tasks (e.g. two Runners in asyncio.gather) each +# get an isolated contextvar copy, so they don't interfere. +# The only problematic case would be two plugin instances interleaved +# within the *same* asyncio task without task boundaries — which the +# framework's PluginManager already prevents. 
_root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) +# Tracks the invocation_id that owns the current span stack so that +# ensure_invocation_span() can distinguish "same invocation re-entry" +# (idempotent) from "stale records from a previous invocation" (clear). +_active_invocation_id_ctx: contextvars.ContextVar[Optional[str]] = ( + contextvars.ContextVar("_bq_analytics_active_invocation_id", default=None) +) + @dataclass class _SpanRecord: @@ -553,12 +573,13 @@ class TraceManager: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _root_agent_name_ctx.get() is None: - try: - root_agent = callback_context._invocation_context.agent.root_agent - _root_agent_name_ctx.set(root_agent.name) - except (AttributeError, ValueError): - pass + # Always refresh root_agent_name — it can change between + # invocations (e.g. different root agents in the same task). + try: + root_agent = callback_context._invocation_context.agent.root_agent + _root_agent_name_ctx.set(root_agent.name) + except (AttributeError, ValueError): + pass # Ensure records stack is initialized TraceManager._get_records() @@ -600,7 +621,16 @@ class TraceManager: # Create the span without attaching it to the ambient context. # This avoids re-parenting framework spans like ``call_llm`` # or ``execute_tool``. See #4561. - span = tracer.start_span(span_name) + # + # If the internal stack already has a span, create the new span + # as a child so it shares the same trace_id. Without this, each + # ``start_span`` would be an independent root with its own + # trace_id — causing trace_id fracture (see #4645). 
+ records = TraceManager._get_records() + parent_ctx = None + if records and records[-1].span.get_span_context().is_valid: + parent_ctx = trace.set_span_in_context(records[-1].span) + span = tracer.start_span(span_name, context=parent_ctx) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -614,7 +644,6 @@ class TraceManager: start_time_ns=time.time_ns(), ) - records = TraceManager._get_records() new_records = list(records) + [record] _span_records_ctx.set(new_records) @@ -651,6 +680,49 @@ class TraceManager: return span_id_str + @staticmethod + def ensure_invocation_span( + callback_context: CallbackContext, + ) -> None: + """Ensures a root span exists on the plugin stack for this invocation. + + Must be called before any events are logged so that every event in + the invocation shares the same trace_id. + + * If the stack has entries for the *current* invocation → no-op + (idempotent within the same invocation). + * If the stack has entries from a *different* invocation → clear + stale records and re-initialise (safety net for abnormal exit). + * If the ambient OTel span is valid → ``attach_current_span`` + (reuse the runner's span without owning it). + * Otherwise → ``push_span("invocation")`` (create a new root + span that will be popped in ``after_run_callback``). + """ + current_inv = callback_context.invocation_id + active_inv = _active_invocation_id_ctx.get() + + records = _span_records_ctx.get() + if records: + if active_inv == current_inv: + return # Already initialised for this invocation. + # Stale records from a previous invocation that wasn't cleaned + # up (e.g. exception skipped after_run_callback). Clear and + # re-init. + logger.debug( + "Clearing %d stale span records from previous invocation.", + len(records), + ) + TraceManager.clear_stack() + + _active_invocation_id_ctx.set(current_inv) + + # Check for a valid ambient span (e.g. the Runner's invocation span). 
+ ambient = trace.get_current_span() + if ambient.get_span_context().is_valid: + TraceManager.attach_current_span(callback_context) + else: + TraceManager.push_span(callback_context, "invocation") + @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: """Ends the current span and pops it from the stack. @@ -679,6 +751,17 @@ class TraceManager: return record.span_id, duration_ms + @staticmethod + def clear_stack() -> None: + """Clears all span records. Safety net for cross-invocation cleanup.""" + records = _span_records_ctx.get() + if records: + # End any owned spans to avoid OTel resource leaks. + for record in reversed(records): + if record.owns_span: + record.span.end() + _span_records_ctx.set([]) + @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: """Gets current span_id and parent span_id.""" @@ -1581,6 +1664,115 @@ def _get_events_schema() -> list[bigquery.SchemaField]: ] +# ============================================================================== +# ANALYTICS VIEW DEFINITIONS +# ============================================================================== + +# Columns included in every per-event-type view. +_VIEW_COMMON_COLUMNS = ( + "timestamp", + "event_type", + "agent", + "session_id", + "invocation_id", + "user_id", + "trace_id", + "span_id", + "parent_span_id", + "status", + "error_message", + "is_truncated", +) + +# Per-event-type column extractions. Each value is a list of +# ``"SQL_EXPR AS alias"`` strings that will be appended after the +# common columns in the view SELECT. 
+_EVENT_VIEW_DEFS: dict[str, list[str]] = { + "USER_MESSAGE_RECEIVED": [], + "LLM_REQUEST": [ + "JSON_VALUE(attributes, '$.model') AS model", + "content AS request_content", + "JSON_QUERY(attributes, '$.llm_config') AS llm_config", + "JSON_QUERY(attributes, '$.tools') AS tools", + ], + "LLM_RESPONSE": [ + "JSON_QUERY(content, '$.response') AS response", + ( + "CAST(JSON_VALUE(content, '$.usage.prompt')" + " AS INT64) AS usage_prompt_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.completion')" + " AS INT64) AS usage_completion_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.total')" + " AS INT64) AS usage_total_tokens" + ), + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ( + "CAST(JSON_VALUE(latency_ms," + " '$.time_to_first_token_ms') AS INT64) AS ttft_ms" + ), + "JSON_VALUE(attributes, '$.model_version') AS model_version", + "JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata", + ], + "LLM_ERROR": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_STARTING": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + ], + "TOOL_COMPLETED": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.result') AS tool_result", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_ERROR": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "AGENT_STARTING": [ + "JSON_VALUE(content, '$.text_summary') AS agent_instruction", + ], + "AGENT_COMPLETED": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "INVOCATION_STARTING": [], + "INVOCATION_COMPLETED": [], + "STATE_DELTA": [ + 
"JSON_QUERY(attributes, '$.state_delta') AS state_delta", + ], + "HITL_CREDENTIAL_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_CONFIRMATION_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_INPUT_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], +} + +_VIEW_SQL_TEMPLATE = """\ +CREATE OR REPLACE VIEW `{project}.{dataset}.{view_name}` AS +SELECT + {columns} +FROM + `{project}.{dataset}.{table}` +WHERE + event_type = '{event_type}' +""" + + # ============================================================================== # MAIN PLUGIN # ============================================================================== @@ -1592,7 +1784,7 @@ class _LoopState: batch_processor: BatchProcessor -@dataclass +@dataclass(kw_only=True) class EventData: """Typed container for structured fields passed to _log_event.""" @@ -1606,6 +1798,7 @@ class EventData: status: str = "OK" error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) + trace_id_override: Optional[str] = None class BigQueryAgentAnalyticsPlugin(BasePlugin): @@ -1650,6 +1843,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self.location = location self._started = False + self._startup_error: Optional[Exception] = None self._is_shutting_down = False self._setup_lock = None self.client = None @@ -1660,6 +1854,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self.parser: Optional[HybridContentParser] = None self._schema = None self.arrow_schema = None + self._init_pid = os.getpid() def _cleanup_stale_loop_states(self) -> None: """Removes entries for event loops that have been closed.""" @@ -1912,6 +2107,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): existing_table = self.client.get_table(self.full_table_id) if self.config.auto_schema_upgrade: 
self._maybe_upgrade_schema(existing_table) + if self.config.create_views: + self._create_analytics_views() except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) @@ -1921,10 +2118,13 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) tbl.clustering_fields = self.config.clustering_fields tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} + table_ready = False try: self.client.create_table(tbl) + table_ready = True except cloud_exceptions.Conflict: - pass + # Another process created it concurrently — still usable. + table_ready = True except Exception as e: logger.error( "Could not create table %s: %s", @@ -1932,6 +2132,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): e, exc_info=True, ) + if table_ready and self.config.create_views: + self._create_analytics_views() except Exception as e: logger.error( "Error checking for table %s: %s", @@ -1980,6 +2182,50 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): exc_info=True, ) + def _create_analytics_views(self) -> None: + """Creates per-event-type BigQuery views (idempotent). + + Each view filters the events table by ``event_type`` and + extracts JSON columns into typed, queryable columns. Uses + ``CREATE OR REPLACE VIEW`` so it is safe to call repeatedly. + Errors are logged but never raised. 
+ """ + for event_type, extra_cols in _EVENT_VIEW_DEFS.items(): + view_name = "v_" + event_type.lower() + columns = ",\n ".join(list(_VIEW_COMMON_COLUMNS) + extra_cols) + sql = _VIEW_SQL_TEMPLATE.format( + project=self.project_id, + dataset=self.dataset_id, + view_name=view_name, + columns=columns, + table=self.table_id, + event_type=event_type, + ) + try: + self.client.query(sql).result() + except Exception as e: + logger.error( + "Failed to create view %s: %s", + view_name, + e, + exc_info=True, + ) + + async def create_analytics_views(self) -> None: + """Public async helper to (re-)create all analytics views. + + Useful when views need to be refreshed explicitly, for example + after a schema upgrade. Ensures the plugin is initialized + before attempting view creation. + """ + await self._ensure_started() + if not self._started: + raise RuntimeError( + "Plugin initialization failed; cannot create analytics views." + ) from self._startup_error + loop = asyncio.get_running_loop() + await loop.run_in_executor(self._executor, self._create_analytics_views) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -2031,13 +2277,39 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): state["offloader"] = None state["parser"] = None state["_started"] = False + state["_startup_error"] = None state["_is_shutting_down"] = False + state["_init_pid"] = 0 return state def __setstate__(self, state): """Custom unpickling to restore state.""" + # Backfill keys that may be absent in pickled state from older + # code versions so _ensure_started does not raise AttributeError. + state.setdefault("_init_pid", 0) self.__dict__.update(state) + def _reset_runtime_state(self) -> None: + """Resets all runtime state after a fork. + + gRPC channels and asyncio locks are not safe to use after + ``os.fork()``. This method clears them so the next call to + ``_ensure_started()`` re-initializes everything in the child + process. 
Pure-data fields like ``_schema`` and + ``arrow_schema`` are kept because they are safe across fork. + """ + self._setup_lock = None + self.client = None + self._loop_state_by_loop = {} + self._write_stream_name = None + self._executor = None + self.offloader = None + self.parser = None + self._started = False + self._startup_error = None + self._is_shutting_down = False + self._init_pid = os.getpid() + async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin: await self._ensure_started() return self @@ -2047,6 +2319,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def _ensure_started(self, **kwargs) -> None: """Ensures that the plugin is started and initialized.""" + if os.getpid() != self._init_pid: + self._reset_runtime_state() if not self._started: # Kept original lock name as it was not explicitly changed. if self._setup_lock is None: @@ -2056,31 +2330,59 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): try: await self._lazy_setup(**kwargs) self._started = True + self._startup_error = None except Exception as e: + self._startup_error = e logger.error("Failed to initialize BigQuery Plugin: %s", e) @staticmethod - def _resolve_span_ids( + def _resolve_ids( event_data: EventData, - ) -> tuple[str, str]: - """Reads span/parent overrides from EventData, falling back to TraceManager. + callback_context: CallbackContext, + ) -> tuple[Optional[str], Optional[str], Optional[str]]: + """Resolves trace_id, span_id, and parent_span_id for a log row. + + Priority order (highest first): + 1. Explicit ``EventData`` overrides (needed for post-pop callbacks). + 2. Ambient OTel span (the framework's ``start_as_current_span``). + When present this aligns BQ rows with Cloud Trace / o11y. + 3. Plugin's internal span stack (``TraceManager``). + 4. ``invocation_id`` fallback for trace_id. 
Returns: - (span_id, parent_span_id) + (trace_id, span_id, parent_span_id) """ - current_span_id, current_parent_span_id = ( + # --- Layer 3: plugin stack baseline --- + trace_id = TraceManager.get_trace_id(callback_context) + plugin_span_id, plugin_parent_span_id = ( TraceManager.get_current_span_and_parent() ) + span_id = plugin_span_id + parent_span_id = plugin_parent_span_id - span_id = current_span_id + # --- Layer 2: ambient OTel span --- + ambient = trace.get_current_span() + ambient_ctx = ambient.get_span_context() + if ambient_ctx.is_valid: + trace_id = format(ambient_ctx.trace_id, "032x") + span_id = format(ambient_ctx.span_id, "016x") + # Reset parent — stale plugin-stack parent must not leak through + # when the ambient span is a root (no parent). + parent_span_id = None + # SDK spans expose .parent; non-recording spans do not. + parent_ctx = getattr(ambient, "parent", None) + if parent_ctx is not None and parent_ctx.span_id: + parent_span_id = format(parent_ctx.span_id, "016x") + + # --- Layer 1: explicit EventData overrides --- + if event_data.trace_id_override is not None: + trace_id = event_data.trace_id_override if event_data.span_id_override is not None: span_id = event_data.span_id_override - - parent_span_id = current_parent_span_id if event_data.parent_span_id_override is not None: parent_span_id = event_data.parent_span_id_override - return span_id, parent_span_id + return trace_id, span_id, parent_span_id @staticmethod def _extract_latency( @@ -2193,8 +2495,9 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): except Exception as e: logger.warning("Content formatter failed: %s", e) - trace_id = TraceManager.get_trace_id(callback_context) - span_id, parent_span_id = self._resolve_span_ids(event_data) + trace_id, span_id, parent_span_id = self._resolve_ids( + event_data, callback_context + ) if not self.parser: logger.warning("Parser not initialized; skipping event %s.", event_type) @@ -2261,6 +2564,7 @@ class 
BigQueryAgentAnalyticsPlugin(BasePlugin): user_message: The message content received from the user. """ callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "USER_MESSAGE_RECEIVED", callback_ctx, @@ -2395,9 +2699,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): invocation_context: The context of the current invocation. """ await self._ensure_started() + callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "INVOCATION_STARTING", - CallbackContext(invocation_context), + callback_ctx, ) @_safe_callback @@ -2409,12 +2715,40 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): Args: invocation_context: The context of the current invocation. """ - await self._log_event( - "INVOCATION_COMPLETED", - CallbackContext(invocation_context), - ) - # Ensure all logs are flushed before the agent returns - await self.flush() + try: + # Capture trace_id BEFORE popping the invocation-root span so + # that INVOCATION_COMPLETED shares the same trace_id as all + # earlier events in this invocation (fixes #4645). + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) + + # Pop the invocation-root span pushed by ensure_invocation_span. + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() + + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. 
+ has_ambient = trace.get_current_span().get_span_context().is_valid + + await self._log_event( + "INVOCATION_COMPLETED", + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, + ), + ) + finally: + # Cleanup must run even if _log_event raises, otherwise + # stale invocation metadata leaks into the next invocation. + TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) + # Ensure all logs are flushed before the agent returns. + await self.flush() @_safe_callback async def before_agent_callback( @@ -2445,18 +2779,20 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): callback_context: The callback context. """ span_id, duration = TraceManager.pop_span() - # When popping, the current stack now points to parent. - # The event we are logging ("AGENT_COMPLETED") belongs to the span we just popped. - # So we must override span_id to be the popped span, and parent to be current top of stack. parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "AGENT_COMPLETED", callback_context, event_data=EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) @@ -2606,6 +2942,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): # Otherwise log_event will fetch current stack (which is parent). span_id = popped_span_id or span_id + # Only override span IDs when no ambient OTel span exists. 
+ # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping LLM_REQUEST/LLM_RESPONSE pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + use_override = is_popped and not has_ambient + await self._log_event( "LLM_RESPONSE", callback_context, @@ -2616,8 +2958,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): time_to_first_token_ms=tfft, model_version=llm_response.model_version, usage_metadata=llm_response.usage_metadata, - span_id_override=span_id if is_popped else None, - parent_span_id_override=(parent_span_id if is_popped else None), + span_id_override=span_id if use_override else None, + parent_span_id_override=parent_span_id if use_override else None, ), ) @@ -2638,14 +2980,18 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "LLM_ERROR", callback_context, event_data=EventData( error_message=str(error), latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) @@ -2710,10 +3056,13 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. 
+ has_ambient = trace.get_current_span().get_span_context().is_valid + event_data = EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ) await self._log_event( "TOOL_COMPLETED", @@ -2749,7 +3098,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "args": args_truncated, "tool_origin": tool_origin, } - _, duration = TraceManager.pop_span() + span_id, duration = TraceManager.pop_span() + parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "TOOL_ERROR", tool_context, @@ -2758,5 +3112,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): event_data=EventData( error_message=str(error), latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 549263fb..5d87a17c 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -20,8 +20,8 @@ import json from unittest import mock from google.adk.agents import base_agent -from google.adk.agents import callback_context as callback_context_lib -from google.adk.agents import invocation_context as invocation_context_lib +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.events import event as event_lib from google.adk.events import event_actions as event_actions_lib from google.adk.models import llm_request as llm_request_lib @@ -83,7 +83,7 @@ def invocation_context(mock_agent, 
mock_session): mock_plugin_manager = mock.create_autospec( plugin_manager_lib.PluginManager, instance=True, spec_set=True ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_agent, session=mock_session, invocation_id="inv-789", @@ -94,9 +94,7 @@ def invocation_context(mock_agent, mock_session): @pytest.fixture def callback_context(invocation_context): - return callback_context_lib.CallbackContext( - invocation_context=invocation_context - ) + return CallbackContext(invocation_context=invocation_context) @pytest.fixture @@ -2152,7 +2150,7 @@ class TestBigQueryAgentAnalyticsPlugin: span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - mock_tracer.start_span.assert_called_with("test_span") + mock_tracer.start_span.assert_called_with("test_span", context=None) assert span_id == format(span_id_int, "016x") # Test get_trace_id # We need to mock trace.get_current_span() to return our mock span @@ -3018,81 +3016,221 @@ class TestDuplicateLabels: assert "labels" not in attributes -class TestResolveSpanIds: - """Tests for the _resolve_span_ids static helper.""" +class TestResolveIds: + """Tests for the _resolve_ids static helper.""" - def test_uses_trace_manager_defaults(self): - """Should use TraceManager values when no overrides provided.""" + def _resolve(self, ed, callback_context): + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_ids( + ed, callback_context + ) + + def test_uses_trace_manager_defaults(self, callback_context): + """Should use TraceManager values when no overrides and no ambient.""" ed = bigquery_agent_analytics_plugin.EventData( extra_attributes={"some_key": "value"} ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + 
return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + assert trace_id == "trace-1" assert span_id == "span-1" assert parent_id == "parent-1" - def test_span_id_override(self): + def test_span_id_override(self, callback_context): """Should use span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override="custom-span" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "custom-span" assert parent_id == "parent-1" - def test_parent_span_id_override(self): + def test_parent_span_id_override(self, callback_context): """Should use parent_span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( parent_span_id_override="custom-parent" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + 
), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "custom-parent" - def test_none_override_keeps_default(self): + def test_none_override_keeps_default(self, callback_context): """None overrides should keep the TraceManager defaults.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override=None, parent_span_id_override=None ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "parent-1" + def test_ambient_otel_span_takes_priority(self, callback_context): + """When an ambient OTel span is valid, its IDs take priority.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData() + + with real_tracer.start_as_current_span("invocation") as parent_span: + with real_tracer.start_as_current_span("agent") as agent_span: + ambient_ctx = agent_span.get_span_context() + expected_trace = format(ambient_ctx.trace_id, "032x") 
+ expected_span = format(ambient_ctx.span_id, "016x") + expected_parent = format(parent_span.get_span_context().span_id, "016x") + + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == expected_trace + assert span_id == expected_span + assert parent_id == expected_parent + provider.shutdown() + + def test_override_beats_ambient(self, callback_context): + """EventData overrides take priority over ambient OTel span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData( + trace_id_override="forced-trace", + span_id_override="forced-span", + parent_span_id_override="forced-parent", + ) + + with real_tracer.start_as_current_span("invocation"): + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == "forced-trace" + assert span_id == "forced-span" + assert parent_id == "forced-parent" + provider.shutdown() + + def test_ambient_root_span_no_self_parent(self, callback_context): + """Ambient root span (no parent) must not produce self-parent.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + # Seed the plugin stack with a span so there's a stale parent. 
+ bigquery_agent_analytics_plugin._span_records_ctx.set(None) + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "plugin-child" + ) + + ed = bigquery_agent_analytics_plugin.EventData() + + # Single root ambient span — no parent. + with real_tracer.start_as_current_span("root_invocation") as root: + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + root_span_id = format(root.get_span_context().span_id, "016x") + + # span_id should be the ambient root's span_id + assert span_id == root_span_id + # parent must be None — not the stale plugin parent, not self + assert parent_id is None + assert span_id != parent_id + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + provider.shutdown() + + def test_ambient_span_used_for_completed_event(self, callback_context): + """Completed event with overrides should use ambient when present. + + When an ambient OTel span is valid, passing None overrides lets + _resolve_ids Layer 2 pick the ambient span — matching the + STARTING event's span_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with real_tracer.start_as_current_span("invoke_agent") as agent_span: + expected_span = format(agent_span.get_span_context().span_id, "016x") + + # Simulate STARTING: no overrides → ambient Layer 2 wins. + ed_starting = bigquery_agent_analytics_plugin.EventData() + _, span_starting, _ = self._resolve(ed_starting, callback_context) + + # Simulate COMPLETED: None overrides (ambient check passed). 
+ ed_completed = bigquery_agent_analytics_plugin.EventData( + span_id_override=None, + parent_span_id_override=None, + latency_ms=42, + ) + _, span_completed, _ = self._resolve(ed_completed, callback_context) + + assert span_starting == expected_span + assert span_completed == expected_span + assert span_starting == span_completed + + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" @@ -3282,7 +3420,7 @@ class TestMultiSubagentToolLogging: instance=True, spec_set=True, ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_a, session=session, invocation_id=invocation_id, @@ -3488,7 +3626,7 @@ class TestMultiSubagentToolLogging: """ session = self._make_session() inv_ctx = self._make_invocation_context("schema_explorer", session) - cb_ctx = callback_context_lib.CallbackContext(invocation_context=inv_ctx) + cb_ctx = CallbackContext(invocation_context=inv_ctx) tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) mock_agent = inv_ctx.agent tool = self._make_tool("get_table_info") @@ -3766,9 +3904,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t1_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t1" ) - cb_ctx_t1_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_orch - ) + cb_ctx_t1_orch = CallbackContext(invocation_context=inv_ctx_t1_orch) # Orchestrator agent_starting await plugin.before_agent_callback( @@ -3781,9 +3917,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t1_sub = self._make_invocation_context( "schema_explorer", session, invocation_id="inv-t1" ) - cb_ctx_t1_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_sub - ) + cb_ctx_t1_sub = CallbackContext(invocation_context=inv_ctx_t1_sub) tool_ctx_t1 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t1_sub ) @@ -3831,9 +3965,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t2_orch = self._make_invocation_context( 
"orchestrator", session, invocation_id="inv-t2" ) - cb_ctx_t2_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_orch - ) + cb_ctx_t2_orch = CallbackContext(invocation_context=inv_ctx_t2_orch) await plugin.before_agent_callback( agent=inv_ctx_t2_orch.agent, @@ -3845,9 +3977,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t2_sub = self._make_invocation_context( "image_describer", session, invocation_id="inv-t2" ) - cb_ctx_t2_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_sub - ) + cb_ctx_t2_sub = CallbackContext(invocation_context=inv_ctx_t2_sub) tool_ctx_t2 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t2_sub ) @@ -4665,3 +4795,1265 @@ class TestHITLTracingEndToEnd: ), f"Expected no HITL events for regular tool, got {hitl_events}" await bq_plugin.shutdown() + + +# ============================================================================== +# Fork-Safety Tests +# ============================================================================== +class TestForkSafety: + """Tests for fork-safety via PID tracking.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + return plugin + + @pytest.mark.asyncio + async def test_pid_change_triggers_reinit( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Simulating a fork by changing _init_pid forces re-init.""" + plugin = self._make_plugin() + await plugin._ensure_started() + assert plugin._started is True + + # Simulate a fork: set _init_pid to a stale value + plugin._init_pid = -1 + assert plugin._started is True # still True before check + + # _ensure_started should detect PID mismatch and reset + await plugin._ensure_started() + # After reset + re-init, _init_pid should match current + import os + + assert plugin._init_pid == 
os.getpid() + assert plugin._started is True + await plugin.shutdown() + + @pytest.mark.asyncio + async def test_pid_unchanged_skips_reset( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Same PID should not trigger a reset.""" + plugin = self._make_plugin() + await plugin._ensure_started() + + # Save references to verify they are not recreated + original_client = plugin.client + original_parser = plugin.parser + + await plugin._ensure_started() + assert plugin.client is original_client + assert plugin.parser is original_parser + await plugin.shutdown() + + def test_reset_runtime_state_clears_fields(self): + """_reset_runtime_state clears all runtime fields.""" + plugin = self._make_plugin() + # Fake some runtime state + plugin._started = True + plugin._is_shutting_down = True + plugin.client = mock.MagicMock() + plugin._loop_state_by_loop = {"fake": "state"} + plugin._write_stream_name = "some/stream" + plugin._executor = mock.MagicMock() + plugin.offloader = mock.MagicMock() + plugin.parser = mock.MagicMock() + plugin._setup_lock = mock.MagicMock() + # Keep pure-data fields + plugin._schema = ["kept"] + plugin.arrow_schema = "kept_arrow" + + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._is_shutting_down is False + assert plugin.client is None + assert plugin._loop_state_by_loop == {} + assert plugin._write_stream_name is None + assert plugin._executor is None + assert plugin.offloader is None + assert plugin.parser is None + assert plugin._setup_lock is None + # Pure-data fields are preserved + assert plugin._schema == ["kept"] + assert plugin.arrow_schema == "kept_arrow" + + import os + + assert plugin._init_pid == os.getpid() + + def test_getstate_resets_pid(self): + """Pickle state should have _init_pid = 0 to force re-init.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + assert state["_init_pid"] == 0 + assert state["_started"] is False + + @pytest.mark.asyncio + async def 
test_unpickle_legacy_state_missing_init_pid( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Unpickling state from older code without _init_pid should not crash.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + # Simulate legacy pickle state that lacks _init_pid entirely + del state["_init_pid"] + + new_plugin = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin.__new__( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin + ) + ) + new_plugin.__setstate__(state) + + # _init_pid should be backfilled to 0, triggering re-init + assert new_plugin._init_pid == 0 + # _ensure_started should not raise AttributeError + await new_plugin._ensure_started() + assert new_plugin._started is True + await new_plugin.shutdown() + + +# ============================================================================== +# Analytics Views Tests +# ============================================================================== +class TestAnalyticsViews: + """Tests for auto-created per-event-type BigQuery views.""" + + def _make_plugin(self, create_views=True): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=create_views, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_views_created_on_new_table(self): + """NotFound path creates all views.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert 
plugin.client.query.call_count == expected_count + + def test_views_created_for_existing_table(self): + """Existing table path also creates views.""" + plugin = self._make_plugin(create_views=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_not_created_when_disabled(self): + """create_views=False skips view creation.""" + plugin = self._make_plugin(create_views=False) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + + plugin._ensure_schema_exists() + + plugin.client.query.assert_not_called() + + def test_view_creation_error_logged_not_raised(self): + """Errors during view creation don't crash the plugin.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.query.side_effect = Exception("BQ error") + + # Should not raise + plugin._ensure_schema_exists() + + # Verify it tried to create views (and failed gracefully) + assert plugin.client.query.call_count > 0 + + def test_view_sql_contains_correct_event_filter(self): + """Each SQL has correct WHERE clause and view name.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + calls = plugin.client.query.call_args_list + for call in calls: + sql = call[0][0] + # Each SQL should have CREATE OR REPLACE 
VIEW + assert "CREATE OR REPLACE VIEW" in sql + # Each SQL should filter by event_type + assert "WHERE" in sql + assert "event_type = " in sql + # View name should start with v_ + assert ".v_" in sql + + # Verify specific views exist + all_sql = " ".join(c[0][0] for c in calls) + for event_type in bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS: + view_name = "v_" + event_type.lower() + assert view_name in all_sql, f"View {view_name} not found in SQL" + + def test_config_create_views_default_true(self): + """Config create_views defaults to True.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.create_views is True + + @pytest.mark.asyncio + async def test_create_analytics_views_ensures_started( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Public create_analytics_views() initializes plugin first.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + assert plugin._started is False + + await plugin.create_analytics_views() + + # Plugin should be started after the call + assert plugin._started is True + # Views should have been created (query called) + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + # _ensure_schema_exists also creates views, so total calls + # = schema-creation views + explicit views + assert mock_bq_client.query.call_count >= expected_count + await plugin.shutdown() + + def test_views_not_created_after_table_creation_failure(self): + """View creation is skipped when create_table raises a non-Conflict error.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = RuntimeError("BQ down") + + plugin._ensure_schema_exists() + + # Views should NOT be attempted since table creation failed + plugin.client.query.assert_not_called() + + @pytest.mark.asyncio + 
async def test_create_analytics_views_raises_on_startup_failure( + self, mock_auth_default, mock_write_client + ): + """create_analytics_views() raises if plugin init fails.""" + # Make the BQ Client constructor raise so _lazy_setup fails + # before _started is set to True. + with mock.patch.object( + bigquery, "Client", side_effect=Exception("client boom") + ): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + with pytest.raises( + RuntimeError, match="Plugin initialization failed" + ) as exc_info: + await plugin.create_analytics_views() + # Root cause should be chained for debuggability + assert exc_info.value.__cause__ is not None + assert "client boom" in str(exc_info.value.__cause__) + + +# ============================================================================== +# Trace-ID Continuity Tests (Issue #4645) +# ============================================================================== +class TestTraceIdContinuity: + """Tests for trace_id continuity across all events in an invocation. + + Regression tests for https://github.com/google/adk-python/issues/4645. + + When there is no ambient OTel span (e.g. Agent Engine, custom runners), + early events (USER_MESSAGE_RECEIVED, INVOCATION_STARTING) used to fall + back to ``invocation_id`` while AGENT_STARTING got a new OTel hex + trace_id from ``push_span()``. The ``ensure_invocation_span()`` fix + guarantees a root span is always on the stack before any events fire. + """ + + @pytest.mark.asyncio + async def test_trace_id_continuity_no_ambient_span(self, callback_context): + """All events share one trace_id when no ambient OTel span exists. + + Simulates the #4645 scenario: OTel IS configured (real TracerProvider) + but the Runner's ambient span is NOT present (e.g. Agent Engine, + custom runners). 
+ """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Create a real TracerProvider and patch the plugin's module-level + # tracer so push_span creates valid spans with proper trace_ids. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span — we do NOT start_as_current_span. + ambient = trace.get_current_span() + assert not ambient.get_span_context().is_valid + + # ensure_invocation_span should push a new span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early is not None + # Should NOT fall back to invocation_id — it should be + # a 32-char hex OTel trace_id. + assert trace_id_early != callback_context.invocation_id + assert len(trace_id_early) == 32 + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + + # Both trace_ids must be identical. + assert trace_id_early == trace_id_agent + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_completed_trace_continuity_no_ambient( + self, callback_context + ): + """INVOCATION_COMPLETED must share trace_id with earlier events. 
+ + Reproduces the completion-event fracture: after_run_callback pops + the invocation span, then _log_event would resolve trace_id via + the fallback to invocation_id. The trace_id_override ensures the + completion event keeps the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset for a clean invocation; no ambient span. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + assert not trace.get_current_span().get_span_context().is_valid + + # --- Simulate the full callback lifecycle --- + # 1. before_run / on_user_message: ensure invocation span + TM.ensure_invocation_span(callback_context) + trace_id_start = TM.get_trace_id(callback_context) + + # 2. before_agent: push agent span + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_start + + # 3. after_agent: pop agent span + TM.pop_span() + + # 4. 
after_run: capture trace_id THEN pop invocation span + trace_id_before_pop = TM.get_trace_id(callback_context) + assert trace_id_before_pop == trace_id_start + + TM.pop_span() + + # After popping, get_trace_id falls back to invocation_id + trace_id_after_pop = TM.get_trace_id(callback_context) + assert trace_id_after_pop == callback_context.invocation_id + + # The trace_id_override preserves continuity + assert trace_id_before_pop == trace_id_start + assert trace_id_before_pop != trace_id_after_pop + + provider.shutdown() + + @pytest.mark.asyncio + async def test_callbacks_emit_same_trace_id_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Full callback path: all emitted rows share one trace_id. + + Exercises the real before_run → before_agent → after_agent → + after_run callback chain via the plugin instance, then checks + every emitted BQ row has the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset span records for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span — simulates Agent Engine / custom runner. + assert not trace.get_current_span().get_span_context().is_valid + + # Run the full callback lifecycle. 
+ await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + # Collect all emitted rows. + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "INVOCATION_STARTING" in event_types + assert "INVOCATION_COMPLETED" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + # Should be a 32-char hex OTel trace, not the invocation_id. + sole_trace_id = trace_ids.pop() + assert sole_trace_id != invocation_context.invocation_id + assert len(sole_trace_id) == 32 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_trace_id_continuity_with_ambient_span(self, callback_context): + """All events share one trace_id when an ambient OTel span exists.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Set up a real OTel tracer. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar. 
+ bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + with real_tracer.start_as_current_span("runner_invocation"): + ambient = trace.get_current_span() + assert ambient.get_span_context().is_valid + ambient_trace_id = format(ambient.get_span_context().trace_id, "032x") + + # ensure_invocation_span should attach the ambient span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early == ambient_trace_id + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + assert trace_id_agent == ambient_trace_id + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation (attached, not owned) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_root_span_isolated_across_turns( + self, callback_context + ): + """Each invocation gets its own root span; turns don't leak.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Turn 1 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn1 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn1 + TM.pop_span() # agent + TM.pop_span() # invocation + + # After popping, the stack should be empty. 
+ records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert not records + + # --- Turn 2 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn2 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn2 + TM.pop_span() # agent + TM.pop_span() # invocation + + # The two turns must have DIFFERENT trace_ids (different + # root spans). + assert trace_id_turn1 != trace_id_turn2 + + provider.shutdown() + + +class TestSpanIdConsistency: + """Tests that STARTING/COMPLETED event pairs share span IDs. + + Span-ID resolution contract: + - When OTel is active: BQ rows use the same trace/span/parent IDs as + Cloud Trace (ambient framework spans). STARTING and COMPLETED events + in the same lifecycle share the same span_id. + - When OTel is not active: BQ rows use the plugin's internal span + stack. STARTING gets the current top-of-stack; COMPLETED gets the + popped span. + """ + + @pytest.mark.asyncio + async def test_starting_completed_same_span_with_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """With ambient OTel, STARTING and COMPLETED get the same span_id.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # Simulate the framework's ambient spans. 
+ with real_tracer.start_as_current_span("invocation"): + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + with real_tracer.start_as_current_span("invoke_agent"): + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # Both events must share the same span_id (the ambient + # invoke_agent span) — no plugin-synthetic override. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + assert ( + agent_starting[0]["parent_span_id"] + == agent_completed[0]["parent_span_id"] + ) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_starting_completed_use_plugin_span_without_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Without ambient OTel, COMPLETED gets the popped plugin span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span. 
+ assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # AGENT_STARTING gets the top-of-stack span; AGENT_COMPLETED + # gets the popped span via override — they should match. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_tool_error_captures_span_id( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + dummy_arrow_schema, + ): + """on_tool_error_callback uses the popped span_id (bonus fix).""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_tool = mock.create_autospec(base_tool_lib.BaseTool, instance=True) + type(mock_tool).name = mock.PropertyMock(return_value="my_tool") + tool_ctx = tool_context_lib.ToolContext( + invocation_context=invocation_context + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + 
bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel — plugin span stack provides IDs. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push tool span via before_tool_callback + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + ) + # Error callback should pop the tool span and use its ID + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + error=RuntimeError("boom"), + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + tool_starting = [r for r in rows if r["event_type"] == "TOOL_STARTING"] + tool_error = [r for r in rows if r["event_type"] == "TOOL_ERROR"] + + assert len(tool_starting) == 1 + assert len(tool_error) == 1 + + # The TOOL_ERROR event must have the same span_id as + # TOOL_STARTING (both correspond to the same tool span). + assert tool_starting[0]["span_id"] == tool_error[0]["span_id"] + assert tool_error[0]["span_id"] is not None + + provider.shutdown() + + +class TestStackLeakSafety: + """Tests for stack leak safety (P2). + + Ensures the plugin's internal span stack doesn't leak records + across invocations when after_run_callback is skipped. 
+ """ + + def test_ensure_invocation_span_clears_stale_records(self, callback_context): + """Pre-populated stack from a different invocation is cleared.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Simulate stale records from incomplete previous invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + # Mark the stale records as belonging to a different invocation. + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set( + "old-inv-stale" + ) + TM.push_span(callback_context, "stale-invocation") + TM.push_span(callback_context, "stale-agent") + + stale_records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale_records) == 2 + + # ensure_invocation_span with the *current* invocation_id should + # detect the mismatch, clear stale records, and re-init. + TM.ensure_invocation_span(callback_context) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh entry (the new invocation span). + assert len(records) == 1 + # The fresh span should NOT be one of the stale ones. 
+ assert records[0].span_id != stale_records[0].span_id + assert records[0].span_id != stale_records[1].span_id + + provider.shutdown() + + def test_clear_stack_ends_owned_spans(self, callback_context): + """clear_stack() ends all owned spans.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + exporter = InMemorySpanExporter() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.push_span(callback_context, "span-a") + TM.push_span(callback_context, "span-b") + + records = list(bigquery_agent_analytics_plugin._span_records_ctx.get()) + assert all(r.owns_span for r in records) + + TM.clear_stack() + + # Stack must be empty after clear. + result = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert result == [] + + # Both owned spans should have been ended (exported). 
+ exported = exporter.get_finished_spans() + assert len(exported) == 2 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_after_run_callback_clears_remaining_stack( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """after_run_callback clears any leftover stack entries.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push an agent span but DON'T pop it (simulate missing + # after_agent_callback due to exception). + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Stack now has [invocation, agent]. + + # after_run_callback should pop invocation + clear remaining. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # Stack must be empty. 
+ records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_next_invocation_clean_after_incomplete_previous( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + mock_session, + ): + """Next invocation starts clean even if previous was incomplete.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # --- Incomplete invocation 1: no after_run_callback --- + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Skip after_agent and after_run — simulates exception. 
+ + stale = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale) >= 2 # invocation + agent + + # --- Invocation 2 with a different invocation_id --- + mock_write_client.append_rows.reset_mock() + inv_ctx_2 = InvocationContext( + agent=mock_agent, + session=mock_session, + invocation_id="inv-NEW-002", + session_service=invocation_context.session_service, + plugin_manager=invocation_context.plugin_manager, + ) + await bq_plugin_inst.before_run_callback(invocation_context=inv_ctx_2) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh invocation span. + assert len(records) == 1 + + # Cleanup + await bq_plugin_inst.after_run_callback(invocation_context=inv_ctx_2) + + provider.shutdown() + + def test_ensure_invocation_span_idempotent_same_invocation( + self, callback_context + ): + """Calling ensure_invocation_span twice in the same invocation is a no-op.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # First call: creates invocation span. + TM.ensure_invocation_span(callback_context) + records_after_first = list( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_first) == 1 + first_span_id = records_after_first[0].span_id + + # Second call (same invocation): must be a no-op. 
+ TM.ensure_invocation_span(callback_context) + records_after_second = ( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_second) == 1 + assert records_after_second[0].span_id == first_span_id + + # Cleanup + TM.pop_span() + + provider.shutdown() + + @pytest.mark.asyncio + async def test_user_message_then_before_run_same_trace_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Regression: on_user_message → before_run must share one trace_id. + + Without the invocation-ID guard, the second ensure_invocation_span() + call would clear the stack and create a new root span with a + different trace_id, fracturing USER_MESSAGE_RECEIVED from + INVOCATION_STARTING. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # No ambient span. 
+ assert not trace.get_current_span().get_span_context().is_valid + + user_msg = types.Content(parts=[types.Part(text="hello")], role="user") + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=user_msg, + ) + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "USER_MESSAGE_RECEIVED" in event_types + assert "INVOCATION_STARTING" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + + provider.shutdown() + + +class TestRootAgentNameAcrossInvocations: + """Regression: root_agent_name must refresh across invocations.""" + + @pytest.mark.asyncio + async def test_root_agent_name_updates_between_invocations( + self, + bq_plugin_inst, + mock_write_client, + mock_session, + dummy_arrow_schema, + ): + """Two invocations with different root agents must log correct names. + + Previously init_trace() only set _root_agent_name_ctx when it was + None, so the second invocation would inherit the first's root agent. 
+ """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ) + + def _make_inv_ctx(agent_name, inv_id): + agent = mock.create_autospec( + base_agent.BaseAgent, instance=True, spec_set=True + ) + type(agent).name = mock.PropertyMock(return_value=agent_name) + type(agent).instruction = mock.PropertyMock(return_value="") + # root_agent returns itself (no parent). + agent.root_agent = agent + return InvocationContext( + agent=agent, + session=mock_session, + invocation_id=inv_id, + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Invocation 1: root agent = "RootA" --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + inv1 = _make_inv_ctx("RootA", "inv-001") + cb1 = CallbackContext(inv1) + await bq_plugin_inst.before_run_callback(invocation_context=inv1) + await bq_plugin_inst.before_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv1) + await asyncio.sleep(0.01) + + rows_inv1 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # --- Invocation 
2: root agent = "RootB" --- + mock_write_client.append_rows.reset_mock() + + inv2 = _make_inv_ctx("RootB", "inv-002") + cb2 = CallbackContext(inv2) + await bq_plugin_inst.before_run_callback(invocation_context=inv2) + await bq_plugin_inst.before_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv2) + await asyncio.sleep(0.01) + + rows_inv2 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # Parse root_agent_name from the attributes JSON column. + def _get_root_names(rows): + names = set() + for r in rows: + attrs = r.get("attributes") + if attrs: + parsed = json.loads(attrs) if isinstance(attrs, str) else attrs + if "root_agent_name" in parsed: + names.add(parsed["root_agent_name"]) + return names + + names_inv1 = _get_root_names(rows_inv1) + names_inv2 = _get_root_names(rows_inv2) + + # Invocation 1 should only have "RootA". + assert names_inv1 == {"RootA"}, f"Expected {{'RootA'}}, got {names_inv1}" + # Invocation 2 must have "RootB", NOT stale "RootA". 
+ assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}" + + provider.shutdown() + + +class TestAfterRunCleanupExceptionSafety: + """after_run_callback cleanup must execute even if _log_event fails.""" + + @pytest.mark.asyncio + async def test_cleanup_runs_when_log_event_raises( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + ): + """Stale state is cleared even when _log_event raises.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + # Run a normal before_run to initialise state. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + + # Verify state is populated. + assert bigquery_agent_analytics_plugin._span_records_ctx.get() + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is not None + ) + + # Make _log_event raise inside after_run_callback. + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("boom"), + ): + # _safe_callback swallows the exception, but cleanup in + # the finally block must still execute. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # All invocation state must be cleaned up despite the error. 
+ records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] or records is None + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is None + ) + assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None + + provider.shutdown() From 5770cd3776c8805086ece34d747e589e36916a34 Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 2 Mar 2026 15:47:43 -0800 Subject: [PATCH 077/102] feat: Add streaming support for Anthropic models Refactor ToolResultBlockParam content handling to use json.dumps for dict/list results. Implement _generate_content_streaming to handle Anthropic's streaming API Close #3250 Co-authored-by: George Weale PiperOrigin-RevId: 877613612 --- src/google/adk/models/anthropic_llm.py | 128 ++++++- tests/unittests/models/test_anthropic_llm.py | 355 +++++++++++++++++++ 2 files changed, 474 insertions(+), 9 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 97992096..1f7f37b0 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -17,7 +17,9 @@ from __future__ import annotations import base64 +import dataclasses from functools import cached_property +import json import logging import os from typing import Any @@ -31,6 +33,7 @@ from typing import Union from anthropic import AsyncAnthropic from anthropic import AsyncAnthropicVertex from anthropic import NOT_GIVEN +from anthropic import NotGiven from anthropic import types as anthropic_types from google.genai import types from pydantic import BaseModel @@ -48,6 +51,15 @@ __all__ = ["AnthropicLlm", "Claude"] logger = logging.getLogger("google_adk." 
+ __name__) +@dataclasses.dataclass +class _ToolUseAccumulator: + """Accumulates streamed tool_use content block data.""" + + id: str + name: str + args_json: str + + class ClaudeRequest(BaseModel): system_instruction: str messages: Iterable[anthropic_types.MessageParam] @@ -115,12 +127,15 @@ def part_to_message_block( else: content_items.append(str(item)) content = "\n".join(content_items) if content_items else "" - # Handle traditional result format - elif "result" in response_data and response_data["result"]: - # Transformation is required because the content is a list of dict. - # ToolResultBlockParam content doesn't support list of dict. Converting - # to str to prevent anthropic.BadRequestError from being thrown. - content = str(response_data["result"]) + # We serialize to str here + # SDK ref: anthropic.types.tool_result_block_param + # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_result_block_param.py + elif "result" in response_data and response_data["result"] is not None: + result = response_data["result"] + if isinstance(result, (dict, list)): + content = json.dumps(result) + else: + content = str(result) return anthropic_types.ToolResultBlockParam( tool_use_id=part.function_response.id or "", @@ -305,16 +320,111 @@ class AnthropicLlm(BaseLlm): if llm_request.tools_dict else NOT_GIVEN ) - # TODO(b/421255973): Enable streaming for anthropic models. 
- message = await self._anthropic_client.messages.create( + + if not stream: + message = await self._anthropic_client.messages.create( + model=llm_request.model, + system=llm_request.config.system_instruction, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_tokens=self.max_tokens, + ) + yield message_to_generate_content_response(message) + else: + async for response in self._generate_content_streaming( + llm_request, messages, tools, tool_choice + ): + yield response + + async def _generate_content_streaming( + self, + llm_request: LlmRequest, + messages: list[anthropic_types.MessageParam], + tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven], + tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven], + ) -> AsyncGenerator[LlmResponse, None]: + """Handles streaming responses from Anthropic models. + + Yields partial LlmResponse objects as content arrives, followed by + a final aggregated LlmResponse with all content. + """ + raw_stream = await self._anthropic_client.messages.create( model=llm_request.model, system=llm_request.config.system_instruction, messages=messages, tools=tools, tool_choice=tool_choice, max_tokens=self.max_tokens, + stream=True, + ) + + # Track content blocks being built during streaming. + # Each entry maps a block index to its accumulated state. 
+ text_blocks: dict[int, str] = {} + tool_use_blocks: dict[int, _ToolUseAccumulator] = {} + input_tokens = 0 + output_tokens = 0 + + async for event in raw_stream: + if event.type == "message_start": + input_tokens = event.message.usage.input_tokens + output_tokens = event.message.usage.output_tokens + + elif event.type == "content_block_start": + block = event.content_block + if isinstance(block, anthropic_types.TextBlock): + text_blocks[event.index] = block.text + elif isinstance(block, anthropic_types.ToolUseBlock): + tool_use_blocks[event.index] = _ToolUseAccumulator( + id=block.id, + name=block.name, + args_json="", + ) + + elif event.type == "content_block_delta": + delta = event.delta + if isinstance(delta, anthropic_types.TextDelta): + text_blocks.setdefault(event.index, "") + text_blocks[event.index] += delta.text + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text=delta.text)], + ), + partial=True, + ) + elif isinstance(delta, anthropic_types.InputJSONDelta): + if event.index in tool_use_blocks: + tool_use_blocks[event.index].args_json += delta.partial_json + + elif event.type == "message_delta": + output_tokens = event.usage.output_tokens + + # Build the final aggregated response with all content. 
+ all_parts: list[types.Part] = [] + all_indices = sorted( + set(list(text_blocks.keys()) + list(tool_use_blocks.keys())) + ) + for idx in all_indices: + if idx in text_blocks: + all_parts.append(types.Part.from_text(text=text_blocks[idx])) + if idx in tool_use_blocks: + acc = tool_use_blocks[idx] + args = json.loads(acc.args_json) if acc.args_json else {} + part = types.Part.from_function_call(name=acc.name, args=args) + part.function_call.id = acc.id + all_parts.append(part) + + yield LlmResponse( + content=types.Content(role="model", parts=all_parts), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=input_tokens, + candidates_token_count=output_tokens, + total_token_count=input_tokens + output_tokens, + ), + partial=False, ) - yield message_to_generate_content_response(message) @cached_property def _anthropic_client(self) -> AsyncAnthropic: diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index fac5f462..50759659 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os import sys from unittest import mock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from anthropic import types as anthropic_types from google.adk import version as adk_version @@ -23,6 +26,7 @@ from google.adk.models.anthropic_llm import AnthropicLlm from google.adk.models.anthropic_llm import Claude from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param +from google.adk.models.anthropic_llm import part_to_message_block from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -598,3 +602,354 @@ def test_content_to_message_param_with_images( ) else: mock_logger.warning.assert_not_called() + + +# --- Tests for Bug #2: json.dumps for dict/list function results --- + + +def test_part_to_message_block_dict_result_serialized_as_json(): + """Dict results should be serialized with json.dumps, not str().""" + response_part = types.Part.from_function_response( + name="get_topic", + response={"result": {"topic": "travel", "active": True, "count": None}}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + content = result["content"] + + # Must be valid JSON (json.dumps produces "true"/"null", not "True"/"None") + parsed = json.loads(content) + assert parsed["topic"] == "travel" + assert parsed["active"] is True + assert parsed["count"] is None + + +def test_part_to_message_block_list_result_serialized_as_json(): + """List results should be serialized with json.dumps.""" + response_part = types.Part.from_function_response( + name="get_items", + response={"result": ["item1", "item2", "item3"]}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + content = result["content"] + + parsed = json.loads(content) + assert parsed == ["item1", "item2", 
"item3"] + + +def test_part_to_message_block_empty_dict_result_not_dropped(): + """Empty dict results should produce '{}', not empty string.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"result": {}}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "{}" + + +def test_part_to_message_block_empty_list_result_not_dropped(): + """Empty list results should produce '[]', not empty string.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"result": []}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "[]" + + +def test_part_to_message_block_string_result_unchanged(): + """String results should still work as before (backward compat).""" + response_part = types.Part.from_function_response( + name="simple_tool", + response={"result": "plain text result"}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "plain text result" + + +def test_part_to_message_block_nested_dict_result(): + """Nested dict with arrays should produce valid JSON.""" + response_part = types.Part.from_function_response( + name="search", + response={ + "result": { + "results": [ + {"id": 1, "tags": ["a", "b"]}, + {"id": 2, "meta": {"key": "val"}}, + ], + "has_more": False, + } + }, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + parsed = json.loads(result["content"]) + assert parsed["has_more"] is False + assert parsed["results"][0]["tags"] == ["a", "b"] + + +# --- Tests for Bug #1: Streaming support --- + + +def _make_mock_stream_events(events): + """Helper to create an async iterable from a list of events.""" + + async def _stream(): + for event in events: + yield event + + return _stream() + + +@pytest.mark.asyncio 
+async def test_streaming_text_yields_partial_and_final(): + """Streaming text should yield partial chunks then a final response.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=10, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Hello ", type="text_delta"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="world!", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="end_turn"), + usage=MagicMock(output_tokens=5), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="You are helpful", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=True) + ] + + # 2 partial text chunks + 1 final aggregated + assert len(responses) == 3 + assert responses[0].partial is True + assert responses[0].content.parts[0].text == "Hello " + assert responses[1].partial is True + assert responses[1].content.parts[0].text == "world!" + assert responses[2].partial is False + assert responses[2].content.parts[0].text == "Hello world!" 
+ assert responses[2].usage_metadata.prompt_token_count == 10 + assert responses[2].usage_metadata.candidates_token_count == 5 + + +@pytest.mark.asyncio +async def test_streaming_tool_use_yields_function_call(): + """Streaming tool_use should accumulate args and yield in final.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=20, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Checking.", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="content_block_start", + index=1, + content_block=anthropic_types.ToolUseBlock( + id="toolu_abc", + name="get_weather", + input={}, + type="tool_use", + ), + ), + MagicMock( + type="content_block_delta", + index=1, + delta=anthropic_types.InputJSONDelta( + partial_json='{"city": "Paris"}', + type="input_json_delta", + ), + ), + MagicMock(type="content_block_stop", index=1), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="tool_use"), + usage=MagicMock(output_tokens=12), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[ + Content( + role="user", + parts=[Part.from_text(text="Weather?")], + ) + ], + config=types.GenerateContentConfig( + system_instruction="You are helpful", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=True) + ] + + # 1 text partial + 1 final + assert len(responses) == 2 + + final = responses[-1] + assert final.partial is False + assert 
len(final.content.parts) == 2 + assert final.content.parts[0].text == "Checking." + assert final.content.parts[1].function_call.name == "get_weather" + assert final.content.parts[1].function_call.args == {"city": "Paris"} + assert final.content.parts[1].function_call.id == "toolu_abc" + + +@pytest.mark.asyncio +async def test_streaming_passes_stream_true_to_create(): + """When stream=True, messages.create should be called with stream=True.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Hi", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="end_turn"), + usage=MagicMock(output_tokens=1), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="Test", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + _ = [r async for r in llm.generate_content_async(llm_request, stream=True)] + + mock_client.messages.create.assert_called_once() + _, kwargs = mock_client.messages.create.call_args + assert kwargs["stream"] is True + + +@pytest.mark.asyncio +async def test_non_streaming_does_not_pass_stream_param(): + """When stream=False, messages.create should NOT get stream param.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + mock_message = anthropic_types.Message( + id="msg_test", + content=[ + 
anthropic_types.TextBlock(text="Hello!", type="text", citations=None) + ], + model="claude-sonnet-4-20250514", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=anthropic_types.Usage( + input_tokens=5, + output_tokens=2, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + server_tool_use=None, + service_tier=None, + ), + ) + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock(return_value=mock_message) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="Test", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + mock_client.messages.create.assert_called_once() + _, kwargs = mock_client.messages.create.call_args + assert "stream" not in kwargs From dd0851ac74d358bc030def5adf242d875ab18265 Mon Sep 17 00:00:00 2001 From: Drew Afromsky Date: Mon, 2 Mar 2026 17:10:12 -0800 Subject: [PATCH 078/102] fix: Update expected UsageMetadataChunk in LiteLLM tests Close #4680 Co-authored-by: George Weale PiperOrigin-RevId: 877646178 --- tests/unittests/models/test_litellm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8e353efb..e87fec43 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2819,7 +2819,7 @@ def test_to_litellm_role(): "content": "this is a test", } } - ] + ], ), [TextChunk(text="this is a test")], UsageMetadataChunk( @@ -2877,7 +2877,9 @@ def test_to_litellm_role(): (None, "stop"), ), ( - ModelResponse(choices=[{"finish_reason": "tool_calls"}]), + ModelResponse( + choices=[{"finish_reason": "tool_calls"}], + ), [None], UsageMetadataChunk( 
prompt_tokens=0, completion_tokens=0, total_tokens=0 @@ -2962,7 +2964,8 @@ def test_to_litellm_role(): finish_reason=None, delta=Delta(role="assistant", content="Hello"), ) - ] + ], + usage=None, ), [TextChunk(text="Hello")], None, @@ -2977,7 +2980,8 @@ def test_to_litellm_role(): role="assistant", reasoning_content="thinking..." ), ) - ] + ], + usage=None, ), [ ReasoningChunk( From 245b2b9874246b678774572988f53c6b7da7d4e2 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 3 Mar 2026 08:12:38 -0800 Subject: [PATCH 079/102] fix: Add usage field to ModelResponse in LiteLLM tests Co-authored-by: George Weale PiperOrigin-RevId: 877954087 --- tests/unittests/models/test_litellm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index e87fec43..aa19bfa8 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2820,6 +2820,11 @@ def test_to_litellm_role(): } } ], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, ), [TextChunk(text="this is a test")], UsageMetadataChunk( @@ -2879,6 +2884,11 @@ def test_to_litellm_role(): ( ModelResponse( choices=[{"finish_reason": "tool_calls"}], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, ), [None], UsageMetadataChunk( @@ -2887,7 +2897,14 @@ def test_to_litellm_role(): "tool_calls", ), ( - ModelResponse(choices=[{}]), + ModelResponse( + choices=[{}], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + ), [None], UsageMetadataChunk( prompt_tokens=0, completion_tokens=0, total_tokens=0 From 4e3e2cb58858e08a79bc6119ad49b6c049dbc0d0 Mon Sep 17 00:00:00 2001 From: Keyur Joshi Date: Tue, 3 Mar 2026 10:16:06 -0800 Subject: [PATCH 080/102] feat: Add GEPA root agent prompt optimizer details: * Uses GEPA (https://gepa-ai.github.io/gepa/) to optimize the instructions for the root 
agent. Can be extended to sub-agents and other components in the future. * GEPA package is imported dynamically; you do not need to install it along with ADK unless you plan to use this optimizer. Co-authored-by: Keyur Joshi PiperOrigin-RevId: 878009649 --- pyproject.toml | 1 + .../gepa_root_agent_prompt_optimizer.py | 323 ++++++++++++++++++ .../gepa_root_agent_prompt_optimizer_test.py | 264 ++++++++++++++ 3 files changed, 588 insertions(+) create mode 100644 src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py create mode 100644 tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py diff --git a/pyproject.toml b/pyproject.toml index 0441c72d..d0f3cd94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ community = [ eval = [ # go/keep-sorted start "Jinja2>=3.1.4,<4.0.0", # For eval template rendering + "gepa>=0.1.0", "google-cloud-aiplatform[evaluation]>=1.100.0", "pandas>=2.2.3", "rouge-score>=0.1.2", diff --git a/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py b/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py new file mode 100644 index 00000000..0627aced --- /dev/null +++ b/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py @@ -0,0 +1,323 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from typing import Optional + +from google.genai import types as genai_types +from pydantic import BaseModel +from pydantic import Field + +from ..agents.llm_agent import Agent +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from ..models.registry import LLMRegistry +from ..utils.context_utils import Aclosing +from ..utils.feature_decorator import experimental +from .agent_optimizer import AgentOptimizer +from .data_types import BaseAgentWithScores +from .data_types import OptimizerResult +from .data_types import UnstructuredSamplingResult +from .sampler import Sampler + +_logger = logging.getLogger("google_adk." + __name__) + +_AGENT_PROMPT_NAME = "agent_prompt" + + +class GEPARootAgentPromptOptimizerConfig(BaseModel): + """Contains configuration options required by the GEPARootAgentPromptOptimizer.""" + + optimizer_model: str = Field( + default="gemini-2.5-flash", + description=( + "The model used to analyze the eval results and optimize the agent." + ), + ) + + model_configuration: genai_types.GenerateContentConfig = Field( + default_factory=lambda: genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=10240, + ) + ), + description="The configuration for the optimizer model.", + ) + + max_metric_calls: int = Field( + default=100, + description="The maximum number of metric calls (evaluations) to make.", + ) + + reflection_minibatch_size: int = Field( + default=3, + description="The number of examples to use for reflection.", + ) + + run_dir: Optional[str] = Field( + default=None, + description=( + "The directory to save the intermediate/final optimization results." 
+ ), + ) + + +class GEPARootAgentPromptOptimizerResult(OptimizerResult[BaseAgentWithScores]): + """The final result of the GEPARootAgentPromptOptimizer.""" + + gepa_result: Optional[dict[str, Any]] = Field( + default=None, + description="The raw result dictionary from the GEPA optimizer.", + ) + + +def _create_agent_gepa_adapter_class(): + """Creates the _AgentGEPAAdapter class dynamically to avoid top-level gepa imports.""" + from gepa.core.adapter import EvaluationBatch + from gepa.core.adapter import GEPAAdapter + + class _AgentGEPAAdapter(GEPAAdapter[str, dict[str, Any], dict[str, Any]]): + """A GEPA adapter for ADK agents.""" + + def __init__( + self, + initial_agent: Agent, + sampler: Sampler[UnstructuredSamplingResult], + main_loop: asyncio.AbstractEventLoop, + ): + self._initial_agent = initial_agent + self._sampler = sampler + self._main_loop = main_loop + + self._train_example_ids = set(sampler.get_train_example_ids()) + self._validation_example_ids = set(sampler.get_validation_example_ids()) + + def evaluate( + self, + batch: list[str], + candidate: dict[str, str], + capture_traces: bool = False, + ) -> EvaluationBatch[dict[str, Any], dict[str, Any]]: + prompt = candidate[_AGENT_PROMPT_NAME] + _logger.info( + "Evaluating agent on batch:\n%s\nwith prompt:\n%s", batch, prompt + ) + # Clone the agent and update the instruction + new_agent = self._initial_agent.clone(update={"instruction": prompt}) + + if set(batch) <= self._train_example_ids: + example_set = "train" + elif set(batch) <= self._validation_example_ids: + example_set = "validation" + else: + raise ValueError(f"Invalid batch composition: {batch}") + + # Run the evaluation in the main loop + future = asyncio.run_coroutine_threadsafe( + self._sampler.sample_and_score( + new_agent, + example_set=example_set, + batch=batch, + capture_full_eval_data=capture_traces, + ), + self._main_loop, + ) + result: UnstructuredSamplingResult = future.result() + + scores = [] + outputs = [] + trajectories = [] + + 
for example_id in batch: + score = result.scores[example_id] + scores.append(score) + + eval_data = result.data.get(example_id, {}) if result.data else {} + outputs.append(eval_data) + trajectories.append(eval_data) + + return EvaluationBatch( + outputs=outputs, scores=scores, trajectories=trajectories + ) + + def make_reflective_dataset( + self, + candidate: dict[str, str], + eval_batch: EvaluationBatch[dict[str, Any], dict[str, Any]], + components_to_update: list[str], + ) -> dict[str, list[dict[str, Any]]]: + dataset: list[dict[str, Any]] = [] + trace_instances: list[tuple[float, dict[str, Any]]] = list( + zip( + eval_batch.scores, + eval_batch.trajectories, + strict=True, + ) + ) + for trace_instance in trace_instances: + score, eval_data = trace_instance + + dataset.append({ + _AGENT_PROMPT_NAME: candidate[_AGENT_PROMPT_NAME], + "score": score, + "eval_data": eval_data, + }) + + # same data for all components (should be only one) + result = {comp: dataset for comp in components_to_update} + + return result + + return _AgentGEPAAdapter + + +@experimental +class GEPARootAgentPromptOptimizer( + AgentOptimizer[UnstructuredSamplingResult, BaseAgentWithScores] +): + """An optimizer that improves the root agent prompt using the GEPA framework.""" + + def __init__( + self, + config: GEPARootAgentPromptOptimizerConfig, + ): + self._config = config + llm_registry = LLMRegistry() + self._llm_class = llm_registry.resolve(self._config.optimizer_model) + + async def optimize( + self, + initial_agent: Agent, + sampler: Sampler[UnstructuredSamplingResult], + ) -> GEPARootAgentPromptOptimizerResult: + """Runs the GEPARootAgentPromptOptimizer. + + Args: + initial_agent: The initial agent whose prompt is to be optimized. Only the + root agent prompt will be optimized. + sampler: The interface used to get training and validation example UIDs, + request agent evaluations, and get useful data for optimizing the agent. 
+ + Returns: + The final result of the optimization process, containing the optimized + agent instance, its scores on the validation examples, and other metrics. + """ + if initial_agent.sub_agents: + _logger.warning( + "The GEPARootAgentPromptOptimizer will not optimize prompts for" + " sub-agents." + ) + + _logger.info("Setting up the GEPA optimizer...") + + try: + import gepa # lazy import as gepa is not in core ADK package + + _AgentGEPAAdapter = _create_agent_gepa_adapter_class() + except ImportError as e: + raise ImportError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e + + loop = asyncio.get_running_loop() + + adapter = _AgentGEPAAdapter( + initial_agent=initial_agent, + sampler=sampler, + main_loop=loop, + ) + + llm = self._llm_class(model=self._config.optimizer_model) + + def reflection_lm(prompt: str) -> str: + llm_request = LlmRequest( + model=self._config.optimizer_model, + config=self._config.model_configuration, + contents=[ + genai_types.Content( + parts=[genai_types.Part(text=prompt)], + role="user", + ) + ], + ) + + async def _generate(): + response_text = "" + async with Aclosing(llm.generate_content_async(llm_request)) as agen: + async for llm_response in agen: + llm_response: LlmResponse + generated_content: genai_types.Content = llm_response.content + if not generated_content.parts: + continue + response_text = "".join( + part.text + for part in generated_content.parts + if part.text and not part.thought + ) + return response_text + + future = asyncio.run_coroutine_threadsafe(_generate(), loop) + return future.result() + + train_ids = sampler.get_train_example_ids() + val_ids = sampler.get_validation_example_ids() + + if set(train_ids).intersection(val_ids): + _logger.warning( + "The training and validation example UIDs overlap. This WILL cause" + " aliasing issues unless each common UID refers to the same example" + " in both sets." 
+ ) + + def run_gepa(): + return gepa.optimize( + seed_candidate={_AGENT_PROMPT_NAME: initial_agent.instruction}, + trainset=train_ids, + valset=val_ids, + adapter=adapter, + max_metric_calls=self._config.max_metric_calls, + reflection_lm=reflection_lm, + reflection_minibatch_size=self._config.reflection_minibatch_size, + run_dir=self._config.run_dir, + ) + + _logger.info("Running the GEPA optimizer...") + + gepa_results = await loop.run_in_executor(None, run_gepa) + + _logger.info("GEPA optimization finished. Preparing final results...") + + optimized_prompts = [ + candidate[_AGENT_PROMPT_NAME] for candidate in gepa_results.candidates + ] + scores = gepa_results.val_aggregate_scores + + optimized_agents = [ + BaseAgentWithScores( + optimized_agent=initial_agent.clone( + update={"instruction": optimized_prompt}, + ), + overall_score=score, + ) + for optimized_prompt, score in zip(optimized_prompts, scores) + ] + + return GEPARootAgentPromptOptimizerResult( + optimized_agents=optimized_agents, + gepa_result=gepa_results.to_dict(), + ) diff --git a/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py b/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py new file mode 100644 index 00000000..bd5da524 --- /dev/null +++ b/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py @@ -0,0 +1,264 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import sys + +from google.adk.agents.llm_agent import Agent +from google.adk.optimization.data_types import UnstructuredSamplingResult +from google.adk.optimization.gepa_root_agent_prompt_optimizer import _create_agent_gepa_adapter_class +from google.adk.optimization.gepa_root_agent_prompt_optimizer import GEPARootAgentPromptOptimizer +from google.adk.optimization.gepa_root_agent_prompt_optimizer import GEPARootAgentPromptOptimizerConfig +from google.adk.optimization.sampler import Sampler +import pytest + + +class MockEvaluationBatch: + + def __init__(self, outputs, scores, trajectories): + self.outputs = outputs + self.scores = scores + self.trajectories = trajectories + + +class MockGEPAAdapter: + """Mock that supports generic type hints.""" + + def __class_getitem__(cls, item): + return cls + + +@pytest.fixture(name="mock_gepa") +def fixture_mock_gepa(mocker): + # mock gepa before it gets imported by the optimizer module + mock_gepa_module = mocker.MagicMock() + mock_gepa_adapter = mocker.MagicMock() + + mock_gepa_adapter.EvaluationBatch = MockEvaluationBatch + mock_gepa_adapter.GEPAAdapter = MockGEPAAdapter + + mock_gepa_module.core = mocker.MagicMock() + mock_gepa_module.core.adapter = mock_gepa_adapter + + mocker.patch.dict( + sys.modules, + { + "gepa": mock_gepa_module, + "gepa.core": mock_gepa_module.core, + "gepa.core.adapter": mock_gepa_adapter, + }, + ) + return mock_gepa_module + + +@pytest.fixture +def mock_sampler(mocker): + sampler = mocker.MagicMock(spec=Sampler) + sampler.get_train_example_ids.return_value = ["train1", "train2"] + sampler.get_validation_example_ids.return_value = ["val1", "val2"] + return sampler + + +@pytest.fixture +def mock_agent(mocker): + agent = mocker.MagicMock(spec=Agent) + agent.instruction = "Initial instruction" + agent.sub_agents = {} + agent.clone.return_value = agent + return agent + + +def test_adapter_init(mock_gepa, mock_sampler, mock_agent): + del mock_gepa 
# only needed to mock gepa in background + loop = asyncio.new_event_loop() + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + assert adapter._initial_agent == mock_agent + assert adapter._sampler == mock_sampler + assert adapter._main_loop == loop + assert adapter._train_example_ids == {"train1", "train2"} + assert adapter._validation_example_ids == {"val1", "val2"} + loop.close() + + +def test_adapter_evaluate_train(mocker, mock_gepa, mock_sampler, mock_agent): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "New prompt"} + batch = ["train1"] + + # mock the future returned by run_coroutine_threadsafe + mock_future = mocker.MagicMock() + expected_result = UnstructuredSamplingResult( + scores={"train1": 0.8}, + data={"train1": {"output": "result"}}, + ) + mock_future.result.return_value = expected_result + + mock_rct = mocker.patch( + "asyncio.run_coroutine_threadsafe", return_value=mock_future + ) + eval_batch = adapter.evaluate(batch, candidate, capture_traces=True) + + mock_rct.assert_called_once() + mock_sampler.sample_and_score.assert_called_once_with( + mocker.ANY, + example_set="train", + batch=batch, + capture_full_eval_data=True, + ) + + mock_agent.clone.assert_called_once_with(update={"instruction": "New prompt"}) + + assert isinstance(eval_batch, MockEvaluationBatch) + assert eval_batch.scores == [0.8] + assert eval_batch.outputs == [{"output": "result"}] + assert eval_batch.trajectories == [{"output": "result"}] + + +def test_adapter_evaluate_validation( + mocker, mock_gepa, mock_sampler, mock_agent +): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = 
_AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "New prompt"} + batch = ["val1"] + + mock_future = mocker.MagicMock() + expected_result = UnstructuredSamplingResult(scores={"val1": 0.5}, data={}) + mock_future.result.return_value = expected_result + + mocker.patch("asyncio.run_coroutine_threadsafe", return_value=mock_future) + adapter.evaluate(batch, candidate) + + mock_sampler.sample_and_score.assert_called_once_with( + mocker.ANY, + example_set="validation", + batch=batch, + capture_full_eval_data=False, + ) + + +def test_adapter_make_reflective_dataset( + mocker, mock_gepa, mock_sampler, mock_agent +): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "Prompt"} + eval_batch = MockEvaluationBatch( + outputs=[{"o": 1}, {"o": 2}], + scores=[0.9, 0.1], + trajectories=[{"t": 1}, {"t": 2}], + ) + components = ["component1"] + + dataset = adapter.make_reflective_dataset(candidate, eval_batch, components) + + assert "component1" in dataset + assert len(dataset["component1"]) == 2 + assert dataset["component1"][0] == { + "agent_prompt": "Prompt", + "score": 0.9, + "eval_data": {"t": 1}, + } + assert dataset["component1"][1] == { + "agent_prompt": "Prompt", + "score": 0.1, + "eval_data": {"t": 2}, + } + + +@pytest.mark.asyncio +async def test_optimize(mocker, mock_gepa, mock_sampler, mock_agent): + config = GEPARootAgentPromptOptimizerConfig() + optimizer = GEPARootAgentPromptOptimizer(config) + + # mock LLM + mock_llm_class = mocker.MagicMock() + mock_llm = mocker.MagicMock() + mock_llm_class.return_value = mock_llm + optimizer._llm_class = mock_llm_class + + # mock gepa.optimize return value + mock_gepa_result = mocker.MagicMock() + mock_gepa_result.candidates = [{"agent_prompt": "Optimized instruction"}] + 
mock_gepa_result.val_aggregate_scores = [0.95] + mock_gepa_result.to_dict.return_value = {"full": "result"} + mock_gepa.optimize.return_value = mock_gepa_result + + result = await optimizer.optimize(mock_agent, mock_sampler) + + mock_gepa.optimize.assert_called_once() + call_kwargs = mock_gepa.optimize.call_args[1] + + assert call_kwargs["seed_candidate"] == { + "agent_prompt": "Initial instruction" + } + assert call_kwargs["trainset"] == ["train1", "train2"] + assert call_kwargs["valset"] == ["val1", "val2"] + + assert len(result.optimized_agents) == 1 + assert result.optimized_agents[0].overall_score == 0.95 + mock_agent.clone.assert_called_with( + update={"instruction": "Optimized instruction"} + ) + assert result.gepa_result == {"full": "result"} + + +@pytest.mark.asyncio +async def test_optimize_logs_warning_on_overlapping_ids( + mocker, mock_gepa, mock_sampler, mock_agent +): + # Setup overlapping IDs + mock_sampler.get_train_example_ids.return_value = ["id1", "id2"] + mock_sampler.get_validation_example_ids.return_value = ["id2", "id3"] + + config = GEPARootAgentPromptOptimizerConfig() + optimizer = GEPARootAgentPromptOptimizer(config) + + # Mock LLM class + mock_llm_class = mocker.MagicMock() + optimizer._llm_class = mock_llm_class + + # Mock gepa.optimize return value + mock_gepa_result = mocker.MagicMock() + mock_gepa_result.candidates = [] + mock_gepa_result.val_aggregate_scores = [] + mock_gepa_result.to_dict.return_value = {} + mock_gepa.optimize.return_value = mock_gepa_result + + mock_logger = mocker.patch( + "google.adk.optimization.gepa_root_agent_prompt_optimizer._logger" + ) + + # Run optimization + await optimizer.optimize(mock_agent, mock_sampler) + + # Verify warning + mock_logger.warning.assert_called_with( + "The training and validation example UIDs overlap. This WILL cause" + " aliasing issues unless each common UID refers to the same example" + " in both sets." 
+ ) From d61846f6c6dd5e357abb0e30eaf61fe27896ae6a Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 3 Mar 2026 13:45:22 -0800 Subject: [PATCH 081/102] fix: Optimize row-level locking in append_event Only acquire FOR UPDATE locks on app and user state rows when the event's state_delta contains changes for those specific scopes. This avoids unnecessary locking on state rows that are not being modified, improving concurrency. Close #4655 Co-authored-by: George Weale PiperOrigin-RevId: 878108562 --- .../adk/sessions/database_session_service.py | 42 +++++---- .../sessions/test_session_service.py | 89 +++++++++++++++++++ 2 files changed, 113 insertions(+), 18 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 24f525ba..6b19464e 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -531,6 +531,16 @@ class DatabaseSessionService(BaseSessionService): schema = self._get_schema_classes() is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT use_row_level_locking = self._supports_row_level_locking() + + state_delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + state_deltas = _session_util.extract_state_delta(state_delta) + has_app_delta = bool(state_deltas["app"]) + has_user_delta = bool(state_deltas["user"]) + async with self._with_session_lock( app_name=session.app_name, user_id=session.user_id, @@ -554,7 +564,7 @@ class DatabaseSessionService(BaseSessionService): sql_session=sql_session, state_model=schema.StorageAppState, predicates=(schema.StorageAppState.app_name == session.app_name,), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_app_delta, missing_message=( "App state missing for app_name=" f"{session.app_name!r}. 
Session state tables should be " @@ -568,7 +578,7 @@ class DatabaseSessionService(BaseSessionService): schema.StorageUserState.app_name == session.app_name, schema.StorageUserState.user_id == session.user_id, ), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_user_delta, missing_message=( "User state missing for app_name=" f"{session.app_name!r}, user_id={session.user_id!r}. " @@ -599,23 +609,19 @@ class DatabaseSessionService(BaseSessionService): storage_events = [e async for e in result] session.events = [e.to_event() for e in storage_events] - # Extract state delta - if event.actions and event.actions.state_delta: - state_deltas = _session_util.extract_state_delta( - event.actions.state_delta + # Merge pre-extracted state deltas into storage. + if has_app_delta: + storage_app_state.state = ( + storage_app_state.state | state_deltas["app"] + ) + if has_user_delta: + storage_user_state.state = ( + storage_user_state.state | state_deltas["user"] + ) + if state_deltas["session"]: + storage_session.state = ( + storage_session.state | state_deltas["session"] ) - app_state_delta = state_deltas["app"] - user_state_delta = state_deltas["user"] - session_state_delta = state_deltas["session"] - # Merge state and update storage - if app_state_delta: - storage_app_state.state = storage_app_state.state | app_state_delta - if user_state_delta: - storage_user_state.state = ( - storage_user_state.state | user_state_delta - ) - if session_state_delta: - storage_session.state = storage_session.state | session_state_delta if is_sqlite: update_time = datetime.fromtimestamp( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 25530bed..4e277195 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1153,3 +1153,92 @@ async def test_prepare_tables_idempotent_after_creation(): assert session.id == 's1' 
finally: await service.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'state_delta, expect_app_lock, expect_user_lock', + [ + pytest.param( + None, + False, + False, + id='no_state_delta', + ), + pytest.param( + {'session_key': 'v'}, + False, + False, + id='session_only_delta', + ), + pytest.param( + {'app:key': 'v'}, + True, + False, + id='app_delta_only', + ), + pytest.param( + {'user:key': 'v'}, + False, + True, + id='user_delta_only', + ), + pytest.param( + {'app:a': '1', 'user:b': '2', 'sk': '3'}, + True, + True, + id='all_scopes', + ), + ], +) +async def test_append_event_locks_only_scopes_with_deltas( + state_delta, expect_app_lock, expect_user_lock +): + """FOR UPDATE should only be requested for state scopes that have deltas.""" + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + + lock_requests = [] + original_fn = database_session_service._select_required_state + + async def tracking_fn(**kwargs): + lock_requests.append({ + 'model': kwargs['state_model'].__tablename__, + 'use_row_level_locking': kwargs['use_row_level_locking'], + }) + return await original_fn(**kwargs) + + try: + session = await service.create_session( + app_name='app', user_id='user', session_id='s1' + ) + + database_session_service._select_required_state = tracking_fn + lock_requests.clear() + + event_kwargs = {'invocation_id': 'inv', 'author': 'user'} + if state_delta is not None: + event_kwargs['actions'] = EventActions(state_delta=state_delta) + event = Event(**event_kwargs) + await service.append_event(session, event) + + app_req = next( + (r for r in lock_requests if r['model'] == 'app_states'), None + ) + user_req = next( + (r for r in lock_requests if r['model'] == 'user_states'), None + ) + + # SQLite doesn't support row-level locking so use_row_level_locking is + # always False. The important check is that locking is not requested + # when there is no delta (it must never be True without a delta). 
+ if not expect_app_lock: + assert ( + app_req is None or not app_req['use_row_level_locking'] + ), 'app_states should not be locked without an app: delta' + if not expect_user_lock: + assert ( + user_req is None or not user_req['use_row_level_locking'] + ), 'user_states should not be locked without a user: delta' + finally: + database_session_service._select_required_state = original_fn + await service.close() From 2e434ca7be765d45426fde9d52b131921bd9fa30 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 3 Mar 2026 14:33:02 -0800 Subject: [PATCH 082/102] fix: Store and retrieve EventCompaction via custom_metadata in Vertex AISessionService This change enables round-tripping of EventCompaction data by storing it within the event's custom_metadata under the key "_compaction" when appending events. When retrieving events, the "_compaction" data is extracted from custom_metadata and used to populate the EventActions.compaction field. This is a temporary measure until the Vertex AI SDK's SessionEvent model supports a dedicated compaction field. Close #3465 Co-authored-by: George Weale PiperOrigin-RevId: 878128265 --- .../adk/sessions/vertex_ai_session_service.py | 55 +++++++++--- .../test_vertex_ai_session_service.py | 85 +++++++++++++++++++ 2 files changed, 129 insertions(+), 11 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 1837a907..8cb7109e 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from . 
import _session_util from ..events.event import Event from ..events.event_actions import EventActions +from ..events.event_actions import EventCompaction from ..utils.vertex_ai_utils import get_express_mode_api_key from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig @@ -267,8 +268,9 @@ class VertexAiSessionService(BaseSessionService): k: json.loads(v.model_dump_json(exclude_none=True, by_alias=True)) for k, v in event.actions.requested_auth_configs.items() }, - # TODO: add requested_tool_confirmations, compaction, agent_state once + # TODO: add requested_tool_confirmations, agent_state once # they are available in the API. + # Note: compaction is stored via event_metadata.custom_metadata. } if event.error_code: config['error_code'] = event.error_code @@ -291,6 +293,19 @@ class VertexAiSessionService(BaseSessionService): metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump( exclude_none=True, mode='json' ) + # Store compaction data in custom_metadata since the Vertex AI service + # does not yet support the compaction field. + # TODO: Stop writing to custom_metadata once the Vertex AI service + # supports the compaction field natively in EventActions. 
+ if event.actions and event.actions.compaction: + compaction_dict = event.actions.compaction.model_dump( + exclude_none=True, mode='json' + ) + existing_custom = metadata_dict.get('custom_metadata') or {} + metadata_dict['custom_metadata'] = { + **existing_custom, + '_compaction': compaction_dict, + } config['event_metadata'] = metadata_dict async with self._get_api_client() as api_client: @@ -347,16 +362,6 @@ class VertexAiSessionService(BaseSessionService): def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: """Converts an API event object to an Event object.""" actions = getattr(api_event_obj, 'actions', None) - if actions: - actions_dict = actions.model_dump(exclude_none=True, mode='python') - rename_map = {'transfer_agent': 'transfer_to_agent'} - renamed_actions_dict = { - rename_map.get(k, k): v for k, v in actions_dict.items() - } - event_actions = EventActions.model_validate(renamed_actions_dict) - else: - event_actions = EventActions() - event_metadata = getattr(api_event_obj, 'event_metadata', None) if event_metadata: long_running_tool_ids_list = getattr( @@ -370,6 +375,16 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: interrupted = getattr(event_metadata, 'interrupted', None) branch = getattr(event_metadata, 'branch', None) custom_metadata = getattr(event_metadata, 'custom_metadata', None) + # Extract compaction data stored in custom_metadata. + # NOTE: This read path must be kept permanently because sessions + # written before native compaction support store compaction data + # in custom_metadata under the '_compaction' key. 
+ compaction_data = None + if custom_metadata and '_compaction' in custom_metadata: + custom_metadata = dict(custom_metadata) # avoid mutating the API response + compaction_data = custom_metadata.pop('_compaction') + if not custom_metadata: + custom_metadata = None grounding_metadata = _session_util.decode_model( getattr(event_metadata, 'grounding_metadata', None), types.GroundingMetadata, @@ -381,8 +396,26 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: interrupted = None branch = None custom_metadata = None + compaction_data = None grounding_metadata = None + if actions: + actions_dict = actions.model_dump(exclude_none=True, mode='python') + rename_map = {'transfer_agent': 'transfer_to_agent'} + renamed_actions_dict = { + rename_map.get(k, k): v for k, v in actions_dict.items() + } + if compaction_data: + renamed_actions_dict['compaction'] = compaction_data + event_actions = EventActions.model_validate(renamed_actions_dict) + else: + if compaction_data: + event_actions = EventActions( + compaction=EventCompaction.model_validate(compaction_data) + ) + else: + event_actions = EventActions() + return Event( id=api_event_obj.name.split('/')[-1], invocation_id=api_event_obj.invocation_id, diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 8c77f194..c095ddd9 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -27,6 +27,7 @@ from google.adk.auth import auth_schemes from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.events.event_actions import EventCompaction from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService @@ 
-826,3 +827,87 @@ async def test_append_event(): assert len(retrieved_session.events) == 2 event_to_append.id = retrieved_session.events[1].id assert retrieved_session.events[1] == event_to_append + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_compaction(): + """Compaction data round-trips through append_event and get_session.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session is not None + + compaction = EventCompaction( + start_timestamp=1000.0, + end_timestamp=2000.0, + compacted_content=genai_types.Content( + parts=[genai_types.Part(text='compacted summary')] + ), + ) + event_to_append = Event( + invocation_id='compaction_invocation', + author='model', + timestamp=1734005534.0, + actions=EventActions(compaction=compaction), + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert retrieved_session is not None + + appended_event = retrieved_session.events[-1] + assert appended_event.actions.compaction is not None + assert appended_event.actions.compaction.start_timestamp == 1000.0 + assert appended_event.actions.compaction.end_timestamp == 2000.0 + assert appended_event.actions.compaction.compacted_content.parts[0].text == ( + 'compacted summary' + ) + # custom_metadata should remain None when only compaction was stored + assert appended_event.custom_metadata is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_compaction_and_custom_metadata(): + """Both compaction and user custom_metadata survive the round-trip.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session is 
not None + + compaction = EventCompaction( + start_timestamp=100.0, + end_timestamp=200.0, + compacted_content=genai_types.Content( + parts=[genai_types.Part(text='summary')] + ), + ) + event_to_append = Event( + invocation_id='compaction_and_meta_invocation', + author='model', + timestamp=1734005535.0, + actions=EventActions(compaction=compaction), + custom_metadata={'user_key': 'user_value'}, + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert retrieved_session is not None + + appended_event = retrieved_session.events[-1] + # Compaction is restored + assert appended_event.actions.compaction is not None + assert appended_event.actions.compaction.start_timestamp == 100.0 + assert appended_event.actions.compaction.end_timestamp == 200.0 + # User custom_metadata is preserved without the internal _compaction key + assert appended_event.custom_metadata == {'user_key': 'user_value'} + assert '_compaction' not in (appended_event.custom_metadata or {}) From b004da50270475adc9e1d7afe4064ca1d10c560a Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 3 Mar 2026 14:41:34 -0800 Subject: [PATCH 083/102] fix: Allow artifact services to accept dictionary representations of types.Part This change introduces an `ensure_part` helper function that normalizes input to `types.Part`. 
This allows `save_artifact` methods in `FileArtifactService`, `GcsArtifactService`, and `InMemoryArtifactService` to accept dictionaries, including those with camelCase keys as used by Agentspace, and convert them into proper `types.Part` instances before saving Close #2886 Co-authored-by: George Weale PiperOrigin-RevId: 878131948 --- .../adk/artifacts/base_artifact_service.py | 36 ++++- .../adk/artifacts/file_artifact_service.py | 7 +- .../adk/artifacts/gcs_artifact_service.py | 7 +- .../artifacts/in_memory_artifact_service.py | 5 +- .../artifacts/test_artifact_service.py | 130 ++++++++++++++++++ 5 files changed, 175 insertions(+), 10 deletions(-) diff --git a/src/google/adk/artifacts/base_artifact_service.py b/src/google/adk/artifacts/base_artifact_service.py index 1a265f8a..23f5e44f 100644 --- a/src/google/adk/artifacts/base_artifact_service.py +++ b/src/google/adk/artifacts/base_artifact_service.py @@ -16,8 +16,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from datetime import datetime +import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from pydantic import alias_generators @@ -25,6 +27,8 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +logger = logging.getLogger("google_adk." + __name__) + class ArtifactVersion(BaseModel): """Metadata describing a specific version of an artifact.""" @@ -60,6 +64,26 @@ class ArtifactVersion(BaseModel): ) +def ensure_part(artifact: Union[types.Part, dict[str, Any]]) -> types.Part: + """Normalizes an artifact to a ``types.Part`` instance. + + External callers may provide artifacts as + plain dictionaries with camelCase keys (``inlineData``) instead of properly + deserialized ``types.Part`` objects. ``model_validate`` handles both + camelCase and snake_case dictionaries transparently via Pydantic aliases. 
+ + Args: + artifact: A ``types.Part`` instance or a dictionary representation. + + Returns: + A validated ``types.Part`` instance. + """ + if isinstance(artifact, dict): + logger.debug("Normalizing artifact dict to types.Part: %s", list(artifact)) + return types.Part.model_validate(artifact) + return artifact + + class BaseArtifactService(ABC): """Abstract base class for artifact services.""" @@ -70,7 +94,7 @@ class BaseArtifactService(ABC): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -84,10 +108,12 @@ class BaseArtifactService(ABC): app_name: The app name. user_id: The user ID. filename: The filename of the artifact. - artifact: The artifact to save. If the artifact consists of `file_data`, - the artifact service assumes its content has been uploaded separately, - and this method will associate the `file_data` with the artifact if - necessary. + artifact: The artifact to save. Accepts a ``types.Part`` instance or a + plain dictionary (camelCase or snake_case keys) which will be + normalized via ``ensure_part``. If the artifact consists of + ``file_data``, the artifact service assumes its content has been + uploaded separately, and this method will associate the ``file_data`` + with the artifact if necessary. session_id: The session ID. If `None`, the artifact is user-scoped. custom_metadata: custom metadata to associate with the artifact. 
diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index be5adb48..b0078e27 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -22,6 +22,7 @@ from pathlib import PureWindowsPath import shutil from typing import Any from typing import Optional +from typing import Union from urllib.parse import unquote from urllib.parse import urlparse @@ -35,6 +36,7 @@ from typing_extensions import override from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." + __name__) @@ -314,7 +316,7 @@ class FileArtifactService(BaseArtifactService): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -339,11 +341,12 @@ class FileArtifactService(BaseArtifactService): self, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str], custom_metadata: Optional[dict[str, Any]], ) -> int: """Saves an artifact to disk and returns its version.""" + artifact = ensure_part(artifact) artifact_dir = self._artifact_dir( user_id=user_id, session_id=session_id, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index 4108cfb0..f8706ded 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -27,6 +27,7 @@ import asyncio import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from typing_extensions import override @@ -34,6 +35,7 @@ from typing_extensions import 
override from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." + __name__) @@ -61,7 +63,7 @@ class GcsArtifactService(BaseArtifactService): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -198,9 +200,10 @@ class GcsArtifactService(BaseArtifactService): user_id: str, session_id: Optional[str], filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + artifact = ensure_part(artifact) versions = self._list_versions( app_name=app_name, user_id=user_id, diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 45552b14..48e7afca 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -17,6 +17,7 @@ import dataclasses import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from pydantic import BaseModel @@ -27,6 +28,7 @@ from . import artifact_util from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." 
+ __name__) @@ -99,10 +101,11 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + artifact = ensure_part(artifact) path = self._artifact_path(app_name, user_id, filename, session_id) if path not in self.artifacts: self.artifacts[path] = [] diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index ec74f8ab..f3e7380b 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -29,6 +29,7 @@ from urllib.parse import unquote from urllib.parse import urlparse from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk.artifacts.base_artifact_service import ensure_part from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -766,3 +767,132 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): filename=str(absolute_in_scope), artifact=part, ) + + +class TestEnsurePart: + """Tests for the ensure_part normalization helper.""" + + def test_returns_part_unchanged(self): + """A types.Part instance passes through without modification.""" + part = types.Part.from_bytes(data=b"hello", mime_type="text/plain") + result = ensure_part(part) + assert result is part + + def test_converts_camel_case_dict(self): + """A camelCase dict (Agentspace format) is converted to types.Part.""" + raw = {"inlineData": {"mimeType": "image/png", "data": "dGVzdA=="}} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.inline_data is not None + assert result.inline_data.mime_type == 
"image/png" + + def test_converts_snake_case_dict(self): + """A snake_case dict is converted to types.Part.""" + raw = {"inline_data": {"mime_type": "text/plain", "data": "aGVsbG8="}} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.inline_data is not None + assert result.inline_data.mime_type == "text/plain" + + def test_converts_text_dict(self): + """A dict with 'text' key is converted to types.Part.""" + raw = {"text": "hello world"} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.text == "hello world" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_artifact_with_camel_case_dict( + service_type, artifact_service_factory +): + """Artifact services accept camelCase dicts (Agentspace format). + + Regression test for https://github.com/google/adk-python/issues/2886 + """ + artifact_service = artifact_service_factory(service_type) + app_name = "app0" + user_id = "user0" + session_id = "sess0" + filename = "uploaded.png" + + # Simulate what Agentspace sends: a plain dict with camelCase keys. 
+ raw_artifact = { + "inlineData": { + "mimeType": "image/png", + "data": "dGVzdF9pbWFnZV9kYXRh", + } + } + + version = await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=raw_artifact, + ) + assert version == 0 + + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.mime_type == "image/png" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_artifact_with_snake_case_dict( + service_type, artifact_service_factory +): + """Artifact services accept snake_case dicts.""" + artifact_service = artifact_service_factory(service_type) + app_name = "app0" + user_id = "user0" + session_id = "sess0" + filename = "uploaded.txt" + + raw_artifact = { + "inline_data": { + "mime_type": "text/plain", + "data": "aGVsbG8=", + } + } + + version = await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=raw_artifact, + ) + assert version == 0 + + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.mime_type == "text/plain" From 63f450e0231f237ee1af37f17420d37b15426d48 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Tue, 3 Mar 2026 17:18:21 -0800 Subject: [PATCH 084/102] feat: Support all `types.SchemaUnion` as output_schema in LLM Agent Co-authored-by: Xuan Yang PiperOrigin-RevId: 878194677 --- .../samples/fields_output_schema/agent.py | 16 +- src/google/adk/agents/llm_agent.py | 27 +- .../llm_flows/_output_schema_processor.py | 8 +- 
src/google/adk/models/llm_request.py | 22 +- src/google/adk/tools/agent_tool.py | 8 +- .../adk/tools/google_search_agent_tool.py | 5 +- .../adk/tools/set_model_response_tool.py | 93 +++++-- src/google/adk/utils/_schema_utils.py | 119 +++++++++ .../llm_flows/test_output_schema_processor.py | 51 +++- .../tools/test_set_model_response_tool.py | 238 ++++++++++++++++-- tests/unittests/utils/test_schema_utils.py | 146 +++++++++++ 11 files changed, 653 insertions(+), 80 deletions(-) create mode 100644 src/google/adk/utils/_schema_utils.py create mode 100644 tests/unittests/utils/test_schema_utils.py diff --git a/contributing/samples/fields_output_schema/agent.py b/contributing/samples/fields_output_schema/agent.py index de40774d..f948668a 100644 --- a/contributing/samples/fields_output_schema/agent.py +++ b/contributing/samples/fields_output_schema/agent.py @@ -22,9 +22,20 @@ class WeatherData(BaseModel): wind_speed: str +def get_current_year() -> str: + """Get the current year. + + Returns: + The current year as a string + """ + from datetime import datetime + + return str(datetime.now().year) + + root_agent = Agent( name='root_agent', - model='gemini-2.0-flash', + model='gemini-2.5-flash', instruction="""\ Answer user's questions based on the data you have. 
@@ -43,6 +54,7 @@ Here are the data you have for Cupertino * wind_speed: 13 mph """, - output_schema=WeatherData, + output_schema=list[WeatherData], output_key='weather_data', + tools=[get_current_year], ) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 8b555b74..0f7cc2b7 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -56,6 +56,8 @@ from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool from ..tools.tool_configs import ToolConfig from ..tools.tool_context import ToolContext +from ..utils._schema_utils import SchemaType +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing from .base_agent import BaseAgent from .base_agent import BaseAgentState @@ -318,9 +320,16 @@ class LlmAgent(BaseAgent): # Controlled input/output configurations - Start input_schema: Optional[type[BaseModel]] = None """The input schema when agent is used as a tool.""" - output_schema: Optional[type[BaseModel]] = None + output_schema: Optional[SchemaType] = None """The output schema when agent replies. + Supports all schema types that the underlying Google GenAI API supports: + - type[BaseModel]: e.g., MySchema + - list[type[BaseModel]]: e.g., list[MySchema] + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type + NOTE: When this is set, agent can ONLY reply and CANNOT use any tools, such as function tools, RAGs, agent transfer, etc. @@ -820,12 +829,12 @@ class LlmAgent(BaseAgent): event.author, ) return - if ( - self.output_key - and event.is_final_response() - and event.content - and event.content.parts - ): + + if not self.output_key: + return + + # Handle text responses + if event.is_final_response() and event.content and event.content.parts: result = ''.join( part.text @@ -838,9 +847,7 @@ class LlmAgent(BaseAgent): # Do not attempt to parse it as JSON. 
if not result.strip(): return - result = self.output_schema.model_validate_json(result).model_dump( - exclude_none=True - ) + result = validate_schema(self.output_schema, result) event.actions.state_delta[self.output_key] = result @model_validator(mode='after') diff --git a/src/google/adk/flows/llm_flows/_output_schema_processor.py b/src/google/adk/flows/llm_flows/_output_schema_processor.py index 36fa8d56..284cc213 100644 --- a/src/google/adk/flows/llm_flows/_output_schema_processor.py +++ b/src/google/adk/flows/llm_flows/_output_schema_processor.py @@ -110,8 +110,12 @@ def get_structured_model_response(function_response_event: Event) -> str | None: for func_response in function_response_event.get_function_responses(): if func_response.name == 'set_model_response': - # Convert dict to JSON string - return json.dumps(func_response.response, ensure_ascii=False) + # Extract the actual result from the wrapped response. + # Tool results are wrapped as {'result': ...} when not already a dict. 
+ response = func_response.response + if isinstance(response, dict) and 'result' in response: + response = response['result'] + return json.dumps(response, ensure_ascii=False) return None diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py index 08d6b861..37f1852b 100644 --- a/src/google/adk/models/llm_request.py +++ b/src/google/adk/models/llm_request.py @@ -25,6 +25,7 @@ from pydantic import Field from ..agents.context_cache_config import ContextCacheConfig from ..tools.base_tool import BaseTool +from ..utils._schema_utils import SchemaType from .cache_metadata import CacheMetadata @@ -273,12 +274,27 @@ class LlmRequest(BaseModel): # No existing tool with function_declarations, create new one self.config.tools.append(types.Tool(function_declarations=declarations)) - def set_output_schema(self, base_model: type[BaseModel]) -> None: + def set_output_schema( + self, + output_schema: Optional[SchemaType] = None, + *, + base_model: Optional[SchemaType] = None, + ) -> None: """Sets the output schema for the request. Args: - base_model: The pydantic base model to set the output schema to. + output_schema: The output schema to set. Supports all types from + SchemaUnion: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - list[type[BaseModel]]: A generic list type (e.g., list[MySchema]) + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type + base_model: Deprecated alias for output_schema. Use output_schema instead. 
""" + schema = output_schema or base_model + if schema is None: + raise ValueError("Either output_schema or base_model must be provided.") - self.config.response_schema = base_model + self.config.response_schema = schema self.config.response_mime_type = "application/json" diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 91135dce..f53c18df 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -28,6 +28,8 @@ from ..agents.common_configs import AgentRefConfig from ..features import FeatureName from ..features import is_feature_enabled from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..utils._schema_utils import SchemaType +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool @@ -64,7 +66,7 @@ def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: return None -def _get_output_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: +def _get_output_schema(agent: BaseAgent) -> Optional[SchemaType]: """Extracts the output_schema from an agent. For LlmAgent, returns its output_schema directly. 
@@ -268,9 +270,7 @@ class AgentTool(BaseTool): ) output_schema = _get_output_schema(self.agent) if output_schema: - tool_result = output_schema.model_validate_json(merged_text).model_dump( - exclude_none=True - ) + tool_result = validate_schema(output_schema, merged_text) else: tool_result = merged_text return tool_result diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index 56da204e..7ed09c79 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -23,6 +23,7 @@ from typing_extensions import override from ..agents.llm_agent import LlmAgent from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.base_llm import BaseLlm +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .agent_tool import AgentTool @@ -127,9 +128,7 @@ class GoogleSearchAgentTool(AgentTool): return '' merged_text = '\n'.join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - tool_result = self.agent.output_schema.model_validate_json( - merged_text - ).model_dump(exclude_none=True) + tool_result = validate_schema(self.agent.output_schema, merged_text) else: tool_result = merged_text diff --git a/src/google/adk/tools/set_model_response_tool.py b/src/google/adk/tools/set_model_response_tool.py index 7a69ca1f..d1dc6ed5 100644 --- a/src/google/adk/tools/set_model_response_tool.py +++ b/src/google/adk/tools/set_model_response_tool.py @@ -16,19 +16,22 @@ from __future__ import annotations +import inspect from typing import Any from typing import Optional from google.genai import types -from pydantic import BaseModel +from pydantic import TypeAdapter from typing_extensions import override +from ..utils._schema_utils import get_list_inner_type +from ..utils._schema_utils import 
is_basemodel_schema +from ..utils._schema_utils import is_list_of_basemodel +from ..utils._schema_utils import SchemaType from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool from .tool_context import ToolContext -MODEL_JSON_RESPONSE_KEY = 'temp:__adk_model_response__' - class SetModelResponseTool(BaseTool): """Internal tool used for output schema workaround. @@ -38,14 +41,20 @@ class SetModelResponseTool(BaseTool): provide its final structured response instead of outputting text directly. """ - def __init__(self, output_schema: type[BaseModel]): + def __init__(self, output_schema: SchemaType): """Initialize the tool with the expected output schema. Args: - output_schema: The pydantic model class defining the expected output - structure. + output_schema: The output schema. Supports all types from SchemaUnion: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - list[type[BaseModel]]: A generic list type (e.g., list[MySchema]) + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type """ self.output_schema = output_schema + self._is_basemodel = is_basemodel_schema(output_schema) + self._is_list_of_basemodel = is_list_of_basemodel(output_schema) # Create a function that matches the output schema def set_model_response() -> str: @@ -57,17 +66,37 @@ class SetModelResponseTool(BaseTool): return 'Response set successfully.' 
# Add the schema fields as parameters to the function dynamically - import inspect - - schema_fields = output_schema.model_fields - params = [] - for field_name, field_info in schema_fields.items(): - param = inspect.Parameter( - field_name, - inspect.Parameter.KEYWORD_ONLY, - annotation=field_info.annotation, - ) - params.append(param) + if self._is_basemodel: + # For regular BaseModel, use the model's fields + schema_fields = output_schema.model_fields + params = [] + for field_name, field_info in schema_fields.items(): + param = inspect.Parameter( + field_name, + inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.annotation, + ) + params.append(param) + elif self._is_list_of_basemodel: + # For list[BaseModel], create a single 'items' parameter + inner_type = get_list_inner_type(output_schema) + params = [ + inspect.Parameter( + 'items', + inspect.Parameter.KEYWORD_ONLY, + annotation=list[inner_type], + ) + ] + else: + # For other schema types (list[str], dict, etc.), + # create a single parameter with the actual schema type + params = [ + inspect.Parameter( + 'response', + inspect.Parameter.KEYWORD_ONLY, + annotation=output_schema, + ) + ] # Create new signature with schema parameters new_sig = inspect.Signature(parameters=params) @@ -94,19 +123,31 @@ class SetModelResponseTool(BaseTool): @override async def run_async( - self, *, args: dict[str, Any], tool_context: ToolContext # pylint: disable=unused-argument - ) -> dict[str, Any]: - """Process the model's response and return the validated dict. + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Process the model's response and return the validated data. Args: args: The structured response data matching the output schema. tool_context: Tool execution context. Returns: - The validated response as dict. + The validated response. Type depends on the output_schema: + - dict for BaseModel + - list of dicts for list[BaseModel] + - raw value for other schema types (list[str], dict, etc.) 
""" - # Validate the input matches the expected schema - validated_response = self.output_schema.model_validate(args) - - # Return the validated dict directly - return validated_response.model_dump() + if self._is_basemodel: + # For regular BaseModel, validate directly + validated_response = self.output_schema.model_validate(args) + return validated_response.model_dump(exclude_none=True) + elif self._is_list_of_basemodel: + # For list[BaseModel], extract and validate the 'items' field + items = args.get('items', []) + type_adapter = TypeAdapter(self.output_schema) + validated_response = type_adapter.validate_python(items) + return [item.model_dump(exclude_none=True) for item in validated_response] + else: + # For other schema types (list[str], dict, etc.), + # return the value directly without pydantic validation + return args.get('response') diff --git a/src/google/adk/utils/_schema_utils.py b/src/google/adk/utils/_schema_utils.py new file mode 100644 index 00000000..3bb74df9 --- /dev/null +++ b/src/google/adk/utils/_schema_utils.py @@ -0,0 +1,119 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General schema utilities. + +This module is for ADK internal use only. +Please do not rely on the implementation details. 
+""" + +from __future__ import annotations + +import json +from typing import Any +from typing import get_args +from typing import get_origin +from typing import Optional + +from google.genai import types +from pydantic import BaseModel +from pydantic import TypeAdapter + +# Use SchemaUnion from google.genai.types to support all schema types +# that the underlying API supports. +SchemaType = types.SchemaUnion +"""Type for schema fields (e.g., output_schema, input_schema). + +Supports all schema types that the underlying Google GenAI API supports: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - GenericAlias: Generic types like list[str], list[MySchema], dict[str, int] + - dict: Raw dict schemas + - Schema: Google's Schema type +""" + + +def is_basemodel_schema(schema: SchemaType) -> bool: + """Check if the schema is a BaseModel type (not a generic alias). + + Args: + schema: The schema to check. + + Returns: + True if schema is a BaseModel class, False otherwise. + """ + return isinstance(schema, type) and issubclass(schema, BaseModel) + + +def is_list_of_basemodel(schema: SchemaType) -> bool: + """Check if the schema is a list of BaseModel type. + + Args: + schema: The schema to check. + + Returns: + True if schema is list[SomeBaseModel], False otherwise. + """ + origin = get_origin(schema) + if origin is not list: + return False + + args = get_args(schema) + if not args: + return False + + inner_type = args[0] + return isinstance(inner_type, type) and issubclass(inner_type, BaseModel) + + +def get_list_inner_type(schema: SchemaType) -> Optional[type[BaseModel]]: + """Get the inner BaseModel type from a list[BaseModel] schema. + + Args: + schema: The schema (expected to be list[SomeBaseModel]). + + Returns: + The inner BaseModel type, or None if not a list of BaseModel. 
+ """ + if not is_list_of_basemodel(schema): + return None + + args = get_args(schema) + return args[0] + + +def validate_schema(schema: SchemaType, json_text: str) -> Any: + """Validate JSON text against a schema and return the result. + + Args: + schema: The schema to validate against. + json_text: The JSON text to validate. + + Returns: + The validated result. Type depends on the schema: + - dict for BaseModel + - list of dicts for list[BaseModel] + - raw value for other schema types (list[str], dict, etc.) + """ + if is_basemodel_schema(schema): + # For regular BaseModel, use model_validate_json + return schema.model_validate_json(json_text).model_dump(exclude_none=True) + elif is_list_of_basemodel(schema): + # For list[BaseModel], use TypeAdapter to validate + type_adapter = TypeAdapter(schema) + validated = type_adapter.validate_json(json_text) + return [item.model_dump(exclude_none=True) for item in validated] + else: + # For other schema types (list[str], dict, Schema, etc.), + # just parse JSON without pydantic validation + return json.loads(json_text) diff --git a/tests/unittests/flows/llm_flows/test_output_schema_processor.py b/tests/unittests/flows/llm_flows/test_output_schema_processor.py index 3e95dea1..23c741bc 100644 --- a/tests/unittests/flows/llm_flows/test_output_schema_processor.py +++ b/tests/unittests/flows/llm_flows/test_output_schema_processor.py @@ -199,7 +199,6 @@ async def test_output_schema_request_processor( @pytest.mark.asyncio async def test_set_model_response_tool(): """Test the set_model_response tool functionality.""" - from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY from google.adk.tools.set_model_response_tool import SetModelResponseTool from google.adk.tools.tool_context import ToolContext @@ -215,18 +214,12 @@ async def test_set_model_response_tool(): tool_context=tool_context, ) - # Verify the tool now returns dict directly + # Verify the tool returns dict directly assert result is not None assert 
result['name'] == 'John Doe' assert result['age'] == 30 assert result['city'] == 'New York' - # Check that the response is no longer stored in session state - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_output_schema_helper_functions(): @@ -328,6 +321,48 @@ async def test_get_structured_model_response_with_non_ascii(): assert extracted_json == expected_json +@pytest.mark.asyncio +async def test_get_structured_model_response_with_wrapped_result(): + """Test get_structured_model_response with wrapped list result. + + When a tool returns a non-dict (e.g., list), it gets wrapped as + {'result': [...]}. This test ensures we correctly unwrap the result. + """ + from google.adk.events.event import Event + from google.adk.flows.llm_flows._output_schema_processor import get_structured_model_response + from google.genai import types + + # Simulate a list result wrapped by ADK's functions.py + wrapped_response = { + 'result': [ + {'name': 'Alice', 'age': 30}, + {'name': 'Bob', 'age': 25}, + ] + } + expected_json = '[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]' + + # Create a function response event with wrapped result + function_response_event = Event( + author='test_agent', + content=types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='set_model_response', response=wrapped_response + ) + ) + ], + ), + ) + + # Get the structured response + extracted_json = get_structured_model_response(function_response_event) + + # Should extract the unwrapped list, not the wrapped dict + assert extracted_json == expected_json + + @pytest.mark.asyncio async def test_end_to_end_integration(): """Test the complete output schema with tools integration.""" diff --git a/tests/unittests/tools/test_set_model_response_tool.py b/tests/unittests/tools/test_set_model_response_tool.py index 75fd40e9..89da394a 100644 --- 
a/tests/unittests/tools/test_set_model_response_tool.py +++ b/tests/unittests/tools/test_set_model_response_tool.py @@ -14,11 +14,12 @@ """Tests for SetModelResponseTool.""" +import inspect + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.sessions.in_memory_session_service import InMemorySessionService -from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY from google.adk.tools.set_model_response_tool import SetModelResponseTool from google.adk.tools.tool_context import ToolContext from pydantic import BaseModel @@ -83,8 +84,6 @@ def test_function_signature_generation(): """Test that function signature is correctly generated from schema.""" tool = SetModelResponseTool(PersonSchema) - import inspect - sig = inspect.signature(tool.func) # Check that parameters match schema fields @@ -129,12 +128,6 @@ async def test_run_async_valid_data(): assert result['age'] == 25 assert result['city'] == 'Seattle' - # Verify data is no longer stored in session state (old behavior) - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_run_async_complex_schema(): @@ -165,12 +158,6 @@ async def test_run_async_complex_schema(): assert result['metadata'] == {'key': 'value'} assert result['is_active'] is False - # Verify data is no longer stored in session state (old behavior) - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_run_async_validation_error(): @@ -220,15 +207,12 @@ async def test_session_state_storage_key(): tool_context=tool_context, ) - # Verify response is returned directly, not stored in session state + # Verify response is returned directly assert result is not None assert result['name'] == 'Diana' assert 
result['age'] == 35 assert result['city'] == 'Miami' - # Verify session state is no longer used - assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state - @pytest.mark.asyncio async def test_multiple_executions_return_latest(): @@ -260,9 +244,6 @@ async def test_multiple_executions_return_latest(): assert result2['age'] == 30 assert result2['city'] == 'City2' - # Verify session state is not used - assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state - def test_function_return_value_consistency(): """Test that function return value matches run_async return value.""" @@ -273,3 +254,216 @@ def test_function_return_value_consistency(): # Both should return the same value assert direct_result == 'Response set successfully.' + + +# Tests for list[BaseModel] schema support + + +class ItemSchema(BaseModel): + """Simple item schema for list testing.""" + + id: int = Field(description='Item ID') + name: str = Field(description='Item name') + + +def test_tool_initialization_list_schema(): + """Test tool initialization with a list schema.""" + tool = SetModelResponseTool(list[ItemSchema]) + + assert tool.output_schema == list[ItemSchema] + assert tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert 'Set your final response' in tool.description + assert tool.func is not None + + +def test_function_signature_generation_list_schema(): + """Test that function signature is correctly generated for list schema.""" + tool = SetModelResponseTool(list[ItemSchema]) + + sig = inspect.signature(tool.func) + + # Should have a single 'items' parameter + assert 'items' in sig.parameters + assert len(sig.parameters) == 1 + + # Parameter should be keyword-only with correct annotation + assert sig.parameters['items'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['items'].annotation == list[ItemSchema] + + +def test_get_declaration_list_schema(): + """Test that tool declaration is properly generated for list schema.""" + 
tool = SetModelResponseTool(list[ItemSchema]) + + declaration = tool._get_declaration() + + assert declaration is not None + assert declaration.name == 'set_model_response' + assert declaration.description is not None + + +@pytest.mark.asyncio +async def test_run_async_list_schema_valid_data(): + """Test tool execution with valid list data.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with valid list data + result = await tool.run_async( + args={ + 'items': [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + }, + tool_context=tool_context, + ) + + # Verify the tool returns list of dicts + assert result is not None + assert isinstance(result, list) + assert len(result) == 3 + assert result[0]['id'] == 1 + assert result[0]['name'] == 'Item 1' + assert result[1]['id'] == 2 + assert result[2]['id'] == 3 + + +@pytest.mark.asyncio +async def test_run_async_list_schema_empty_list(): + """Test tool execution with empty list.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with empty list + result = await tool.run_async( + args={'items': []}, + tool_context=tool_context, + ) + + # Verify the tool returns empty list + assert result is not None + assert isinstance(result, list) + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_run_async_list_schema_validation_error(): + """Test tool execution with invalid list data raises validation error.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + 
tool_context = ToolContext(invocation_context) + + # Execute with invalid data (wrong type for id) + with pytest.raises(ValidationError): + await tool.run_async( + args={ + 'items': [ + {'id': 'not_a_number', 'name': 'Item 1'}, + ] + }, + tool_context=tool_context, + ) + + +# Tests for other schema types (list[str], dict, etc.) + + +def test_tool_initialization_list_str_schema(): + """Test tool initialization with list[str] schema.""" + tool = SetModelResponseTool(list[str]) + + assert tool.output_schema == list[str] + assert not tool._is_basemodel + assert not tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert tool.func is not None + + +def test_function_signature_generation_list_str_schema(): + """Test that function signature is correctly generated for list[str] schema.""" + tool = SetModelResponseTool(list[str]) + + sig = inspect.signature(tool.func) + + # Should have a single 'response' parameter with list[str] annotation + assert 'response' in sig.parameters + assert len(sig.parameters) == 1 + assert sig.parameters['response'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['response'].annotation == list[str] + + +@pytest.mark.asyncio +async def test_run_async_list_str_schema(): + """Test tool execution with list[str] data.""" + tool = SetModelResponseTool(list[str]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with list of strings + result = await tool.run_async( + args={'response': ['apple', 'banana', 'cherry']}, + tool_context=tool_context, + ) + + # Verify the tool returns the list directly + assert result is not None + assert isinstance(result, list) + assert result == ['apple', 'banana', 'cherry'] + + +def test_tool_initialization_dict_schema(): + """Test tool initialization with dict schema.""" + tool = SetModelResponseTool(dict[str, int]) + + assert 
tool.output_schema == dict[str, int] + assert not tool._is_basemodel + assert not tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert tool.func is not None + + +def test_function_signature_generation_dict_schema(): + """Test that function signature is correctly generated for dict schema.""" + tool = SetModelResponseTool(dict[str, int]) + + sig = inspect.signature(tool.func) + + # Should have a single 'response' parameter with dict[str, int] annotation + assert 'response' in sig.parameters + assert len(sig.parameters) == 1 + assert sig.parameters['response'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['response'].annotation == dict[str, int] + + +@pytest.mark.asyncio +async def test_run_async_dict_schema(): + """Test tool execution with dict data.""" + tool = SetModelResponseTool(dict[str, int]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with dict data + result = await tool.run_async( + args={'response': {'a': 1, 'b': 2, 'c': 3}}, + tool_context=tool_context, + ) + + # Verify the tool returns the dict directly + assert result is not None + assert isinstance(result, dict) + assert result == {'a': 1, 'b': 2, 'c': 3} diff --git a/tests/unittests/utils/test_schema_utils.py b/tests/unittests/utils/test_schema_utils.py new file mode 100644 index 00000000..8f68ecdb --- /dev/null +++ b/tests/unittests/utils/test_schema_utils.py @@ -0,0 +1,146 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _schema_utils module.""" + +from google.adk.utils._schema_utils import get_list_inner_type +from google.adk.utils._schema_utils import is_basemodel_schema +from google.adk.utils._schema_utils import is_list_of_basemodel +from google.adk.utils._schema_utils import validate_schema +from pydantic import BaseModel + + +class SampleModel(BaseModel): + """Sample model for testing.""" + + name: str + value: int + + +class TestIsBasemodelSchema: + """Tests for is_basemodel_schema function.""" + + def test_basemodel_class_returns_true(self): + """Test that a BaseModel class returns True.""" + assert is_basemodel_schema(SampleModel) + + def test_list_of_basemodel_returns_false(self): + """Test that list[BaseModel] returns False.""" + assert not is_basemodel_schema(list[SampleModel]) + + def test_list_of_str_returns_false(self): + """Test that list[str] returns False.""" + assert not is_basemodel_schema(list[str]) + + def test_dict_returns_false(self): + """Test that dict types return False.""" + assert not is_basemodel_schema(dict[str, int]) + + def test_plain_str_returns_false(self): + """Test that plain str returns False.""" + assert not is_basemodel_schema(str) + + def test_plain_int_returns_false(self): + """Test that plain int returns False.""" + assert not is_basemodel_schema(int) + + +class TestIsListOfBasemodel: + """Tests for is_list_of_basemodel function.""" + + def test_list_of_basemodel_returns_true(self): + """Test that list[BaseModel] returns True.""" + assert is_list_of_basemodel(list[SampleModel]) + + def 
test_basemodel_class_returns_false(self): + """Test that a plain BaseModel class returns False.""" + assert not is_list_of_basemodel(SampleModel) + + def test_list_of_str_returns_false(self): + """Test that list[str] returns False.""" + assert not is_list_of_basemodel(list[str]) + + def test_list_of_int_returns_false(self): + """Test that list[int] returns False.""" + assert not is_list_of_basemodel(list[int]) + + def test_dict_returns_false(self): + """Test that dict types return False.""" + assert not is_list_of_basemodel(dict[str, int]) + + def test_plain_list_returns_false(self): + """Test that plain list (no type arg) returns False.""" + assert not is_list_of_basemodel(list) + + +class TestGetListInnerType: + """Tests for get_list_inner_type function.""" + + def test_list_of_basemodel_returns_inner_type(self): + """Test that list[BaseModel] returns the inner type.""" + assert get_list_inner_type(list[SampleModel]) is SampleModel + + def test_basemodel_class_returns_none(self): + """Test that a plain BaseModel class returns None.""" + assert get_list_inner_type(SampleModel) is None + + def test_list_of_str_returns_none(self): + """Test that list[str] returns None.""" + assert get_list_inner_type(list[str]) is None + + def test_dict_returns_none(self): + """Test that dict types return None.""" + assert get_list_inner_type(dict[str, int]) is None + + +class TestValidateSchema: + """Tests for validate_schema function.""" + + def test_basemodel_schema(self): + """Test validation with a BaseModel schema.""" + json_text = '{"name": "test", "value": 42}' + result = validate_schema(SampleModel, json_text) + assert result == {'name': 'test', 'value': 42} + + def test_basemodel_schema_excludes_none(self): + """Test that None values are excluded from the result.""" + + class ModelWithOptional(BaseModel): + name: str + optional_field: str | None = None + + json_text = '{"name": "test", "optional_field": null}' + result = validate_schema(ModelWithOptional, json_text) + 
assert result == {'name': 'test'} + + def test_list_of_basemodel_schema(self): + """Test validation with a list[BaseModel] schema.""" + json_text = '[{"name": "item1", "value": 1}, {"name": "item2", "value": 2}]' + result = validate_schema(list[SampleModel], json_text) + assert result == [ + {'name': 'item1', 'value': 1}, + {'name': 'item2', 'value': 2}, + ] + + def test_list_of_str_schema(self): + """Test validation with a list[str] schema.""" + json_text = '["a", "b", "c"]' + result = validate_schema(list[str], json_text) + assert result == ['a', 'b', 'c'] + + def test_dict_schema(self): + """Test validation with a dict schema.""" + json_text = '{"key1": 1, "key2": 2}' + result = validate_schema(dict[str, int], json_text) + assert result == {'key1': 1, 'key2': 2} From 2b8ccd4a0015fe83dd35d3040d286680d3d8ce57 Mon Sep 17 00:00:00 2001 From: Bastien Jacot-Guillarmod Date: Wed, 4 Mar 2026 00:43:18 -0800 Subject: [PATCH 085/102] chore: Exclude `BaseAgent.parent_agent` from serialization Otherwise, serialization fails due to circular references. PiperOrigin-RevId: 878339838 --- src/google/adk/agents/base_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 3d0a14d4..dec85690 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -122,7 +122,9 @@ class BaseAgent(BaseModel): One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + parent_agent: Optional[BaseAgent] = Field( + default=None, init=False, exclude=True + ) """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. 
From 87ffc55640dea1185cf67e6f9b78f70b30867bcc Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 03:25:26 -0800 Subject: [PATCH 086/102] feat: New implementation of A2aAgentExecutor and A2A-ADK conversion This change introduces new implementation files for the A2aAgentExecutor and event converters. The existing A2aAgentExecutor now acts as a wrapper, allowing a switch between the legacy and new implementations. The new implementation includes support for execution interceptors and a dedicated executor context. Main Changes: `a2a_agent_executor_impl.py`: the new implementation of the AgentExecutor differs from the legacy one (`a2a_agent_executor.py`) in the removal of the TaskResultAggregator and the explicit `InvocationContext` creation. Instead, it uses `ExecutorContext` and delegates event conversion to the new logic that supports streaming. It maintains an `agents_artifact` state map to handle partial updates and emits TaskArtifactUpdateEvents for content. The `long_running_functions.py` is used to keep track of the LongRunning FunctionCalls and their respective FunctionResponses, to emit them at the end of the generation loop in a `TaskStateUpdateEvent(input-required/auth-required)`. `from_adk_event.py`: this file replaces the conversion functions in the `event_converter.py` used to convert the ADK events into A2A events, extracting them into a dedicated file. The main changes in the methods are the introduction of TaskArtifactUpdateEvent to handle content parts, allowing for true artifact streaming and chunking. It utilizes an `agents_artifacts` dictionary to track artifact IDs across partial events to correctly handle append operations.
PiperOrigin-RevId: 878399140 --- .../adk/a2a/converters/from_adk_event.py | 288 +++++++ .../a2a/converters/long_running_functions.py | 215 +++++ .../adk/a2a/executor/a2a_agent_executor.py | 59 +- .../a2a/executor/a2a_agent_executor_impl.py | 310 +++++++ src/google/adk/a2a/executor/config.py | 38 + .../unittests/a2a/converters/test_from_adk.py | 108 +++ .../executor/test_a2a_agent_executor_impl.py | 808 ++++++++++++++++++ 7 files changed, 1793 insertions(+), 33 deletions(-) create mode 100644 src/google/adk/a2a/converters/from_adk_event.py create mode 100644 src/google/adk/a2a/converters/long_running_functions.py create mode 100644 src/google/adk/a2a/executor/a2a_agent_executor_impl.py create mode 100644 tests/unittests/a2a/converters/test_from_adk.py create mode 100644 tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py diff --git a/src/google/adk/a2a/converters/from_adk_event.py b/src/google/adk/a2a/converters/from_adk_event.py new file mode 100644 index 00000000..05bf16d1 --- /dev/null +++ b/src/google/adk/a2a/converters/from_adk_event.py @@ -0,0 +1,288 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime +from datetime import timezone +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import uuid + +from a2a.server.events import Event as A2AEvent +from a2a.types import Artifact +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from ..experimental import a2a_experimental +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_genai_part_to_a2a_part +from .part_converter import GenAIPartToA2APartConverter +from .utils import _get_adk_metadata_key + +# Constants +DEFAULT_ERROR_MESSAGE = "An error occurred during processing" + +# Logger +logger = logging.getLogger("google_adk." + __name__) + +A2AUpdateEvent = Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent] + +AdkEventToA2AEventsConverter = Callable[ + [ + Event, + Optional[Dict[str, str]], + Optional[str], + Optional[str], + GenAIPartToA2APartConverter, + ], + List[A2AUpdateEvent], +] +"""A callable that converts an ADK Event into a list of A2A events. + +This interface allows for custom logic to map ADK's event structure to the +event structure expected by the A2A server. + +Args: + event: The source ADK Event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. 
+ task_id: The ID of the A2A task being processed. + context_id: The context ID from the A2A request. + part_converter: A function to convert GenAI content parts to A2A + parts. + +Returns: + A list of A2A events. +""" + + +def _convert_adk_parts_to_a2a_parts( + event: Event, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> Optional[List[A2APart]]: + """Converts an ADK event to an A2A parts list. + + Args: + event: The ADK event to convert. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A parts representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + + if not event.content or not event.content.parts: + return [] + + try: + output_parts = [] + for part in event.content.parts: + a2a_parts = part_converter(part) + if not isinstance(a2a_parts, list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + output_parts.append(a2a_part) + + return output_parts + + except Exception as e: + logger.error("Failed to convert event to status message: %s", e) + raise + + +def create_error_status_event( + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for error scenarios. + + Args: + event: The ADK event containing error information. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + Returns: + A TaskStatusUpdateEvent with FAILED state. 
+ """ + error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + + error_event = TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.failed, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[A2APart(root=TextPart(text=error_message))], + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + return _add_event_metadata(event, [error_event])[0] + + +@a2a_experimental +def convert_event_to_a2a_events( + event: Event, + agents_artifacts: Dict[str, str], + task_id: Optional[str] = None, + context_id: Optional[str] = None, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> List[A2AUpdateEvent]: + """Converts a GenAI event to a list of A2A StatusUpdate and ArtifactUpdate events. + + Args: + event: The ADK event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A update events representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. 
+ """ + if not event: + raise ValueError("Event cannot be None") + if agents_artifacts is None: + raise ValueError("Agents artifacts cannot be None") + + a2a_events = [] + try: + a2a_parts = _convert_adk_parts_to_a2a_parts( + event, part_converter=part_converter + ) + # Handle artifact updates for normal parts + if a2a_parts: + agent_name = event.author + partial = event.partial or False + + artifact_id = agents_artifacts.get(agent_name) + if artifact_id: + append = partial + if not partial: + del agents_artifacts[agent_name] + else: + artifact_id = str(uuid.uuid4()) + # TODO: Clarify if new artifact id must have append=False + append = False + if partial: + agents_artifacts[agent_name] = artifact_id + + a2a_events.append( + TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + last_chunk=not partial, + append=append, + artifact=Artifact( + artifact_id=artifact_id, + parts=a2a_parts, + ), + ) + ) + + a2a_events = _add_event_metadata(event, a2a_events) + return a2a_events + + except Exception as e: + logger.error("Failed to convert event to A2A events: %s", e) + raise + + +def _serialize_value(value: Any) -> Optional[Any]: + """Serializes a value and returns it if it contains meaningful content. + + Returns None if the value is empty or missing. 
+ """ + if value is None: + return None + + # Handle Pydantic models + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump( + exclude_none=True, + exclude_unset=True, + exclude_defaults=True, + by_alias=True, + ) + return dumped if dumped else None + except Exception as e: + logger.warning("Failed to serialize Pydantic model, falling back: %s", e) + return str(value) + + return str(value) + + +# TODO: Clarify if this metadata needs to be translated back into the ADK event +def _add_event_metadata( + event: Event, a2a_events: List[A2AEvent] +) -> List[A2AEvent]: + """Gets the context metadata for the event and applies it to A2A events.""" + if not event: + raise ValueError("Event cannot be None") + + metadata_values = { + "invocation_id": event.invocation_id, + "author": event.author, + "event_id": event.id, + "branch": event.branch, + "citation_metadata": event.citation_metadata, + "grounding_metadata": event.grounding_metadata, + "custom_metadata": event.custom_metadata, + "usage_metadata": event.usage_metadata, + "error_code": event.error_code, + "actions": event.actions, + } + + metadata = {} + for field_name, field_value in metadata_values.items(): + value = _serialize_value(field_value) + if value is not None: + metadata[_get_adk_metadata_key(field_name)] = value + + for a2a_event in a2a_events: + if isinstance(a2a_event, TaskStatusUpdateEvent): + a2a_event.status.message.metadata = metadata.copy() + elif isinstance(a2a_event, TaskArtifactUpdateEvent): + a2a_event.artifact.metadata = metadata.copy() + + return a2a_events diff --git a/src/google/adk/a2a/converters/long_running_functions.py b/src/google/adk/a2a/converters/long_running_functions.py new file mode 100644 index 00000000..0bbb46da --- /dev/null +++ b/src/google/adk/a2a/converters/long_running_functions.py @@ -0,0 +1,215 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +from typing import List +from typing import Set +import uuid + +from a2a.server.agent_execution.context import RequestContext +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.genai import types as genai_types + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import A2APartToGenAIPartConverter +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _get_adk_metadata_key + + +class LongRunningFunctions: + """Keeps track of long running function calls and related responses.""" + + def __init__( + self, part_converter: A2APartToGenAIPartConverter | None = None + ) -> None: + self._parts: List[genai_types.Part] = [] + self._long_running_tool_ids: Set[str] = set() + self._part_converter = part_converter or convert_a2a_part_to_genai_part + self._task_state: TaskState = TaskState.input_required + + def has_long_running_function_calls(self) -> 
bool: + """Returns True if there are long running function calls.""" + return bool(self._long_running_tool_ids) + + def process_event(self, event: Event) -> Event: + """Processes parts to extract long running calls and responses. + + Returns a copy of the input event with processed parts removed from + event.content.parts. + + Args: + event: The ADK event containing long running tool IDs and content parts. + """ + event = event.model_copy(deep=True) + if not event.content or not event.content.parts: + return event + + kept_parts = [] + for part in event.content.parts: + should_remove = False + if part.function_call: + if part.function_call.id in event.long_running_tool_ids: + if not event.partial: + self._parts.append(part) + self._long_running_tool_ids.add(part.function_call.id) + should_remove = True + + elif part.function_response: + if part.function_response.id in self._long_running_tool_ids: + if not event.partial: + self._parts.append(part) + should_remove = True + + if not should_remove: + kept_parts.append(part) + + event.content.parts = kept_parts + return event + + def create_long_running_function_call_event( + self, + task_id: str, + context_id: str, + ) -> TaskStatusUpdateEvent: + """Creates a task status update event for the long running function calls.""" + if not self._long_running_tool_ids: + return None + + a2a_parts = self._return_long_running_parts() + if not a2a_parts: + return None + + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=self._task_state, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=a2a_parts, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + + def _return_long_running_parts(self) -> List[A2APart]: + """Converts long-running parts to A2A parts.""" + if not self._long_running_tool_ids: + return [] + + output_parts = [] + for part in self._parts: + a2a_parts = self._part_converter(part) + if not isinstance(a2a_parts, 
list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + self._mark_long_running_function_call(a2a_part) + output_parts.append(a2a_part) + + return output_parts + + def _mark_long_running_function_call(self, a2a_part: A2APart) -> None: + """Processes long-running tool metadata for an A2A part. + + Args: + a2a_part: The A2A part to potentially mark as long-running. + """ + + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True + # If the function is a request for EUC, set the task state to + # auth_required. Otherwise, set it to input_required. Save the state of + # the last function call, as it will be the state of the task. + if a2a_part.root.metadata.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME: + self._task_state = TaskState.auth_required + else: + self._task_state = TaskState.input_required + + +def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: + """Processes user input events, validating function responses.""" + + if ( + not context.current_task + or not context.current_task.status + or ( + context.current_task.status.state != TaskState.input_required + and context.current_task.status.state != TaskState.auth_required + ) + ): + return None + + # If the task is in input_required or auth_required state, we expect the user + # to provide a response for the function call. Check if the user input + # contains a function response. 
+ for a2a_part in context.message.parts: + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return None + + return TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=context.current_task.status.state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[ + A2APart( + root=TextPart( + text=( + "It was not provided a function response for the" + " function call." + ) + ) + ) + ], + ), + ), + final=True, + ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 956b1233..da28955a 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -35,22 +35,14 @@ from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart from google.adk.runners import Runner -from pydantic import BaseModel from typing_extensions import override from ...utils.context_utils import Aclosing -from ..converters.event_converter import AdkEventToA2AEventsConverter -from ..converters.event_converter import convert_event_to_a2a_events -from ..converters.part_converter import A2APartToGenAIPartConverter -from ..converters.part_converter import convert_a2a_part_to_genai_part -from ..converters.part_converter import convert_genai_part_to_a2a_part -from ..converters.part_converter import GenAIPartToA2APartConverter -from ..converters.request_converter import A2ARequestToAgentRunRequestConverter from ..converters.request_converter import AgentRunRequest -from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import 
a2a_experimental -from .config import ExecuteInterceptor +from .a2a_agent_executor_impl import _A2aAgentExecutor as ExecutorImpl +from .config import A2aAgentExecutorConfig from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator from .utils import execute_after_agent_interceptors @@ -60,29 +52,16 @@ from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' + __name__) -@a2a_experimental -class A2aAgentExecutorConfig(BaseModel): - """Configuration for the A2aAgentExecutor.""" - - a2a_part_converter: A2APartToGenAIPartConverter = ( - convert_a2a_part_to_genai_part - ) - gen_ai_part_converter: GenAIPartToA2APartConverter = ( - convert_genai_part_to_a2a_part - ) - request_converter: A2ARequestToAgentRunRequestConverter = ( - convert_a2a_request_to_agent_run_request - ) - event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events - - execute_interceptors: Optional[list[ExecuteInterceptor]] = None - - @a2a_experimental class A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and publishes updates to an event queue. + + Args: + runner: The runner to use for the agent. + config: The config to use for the executor. + use_legacy: Whether to use the legacy executor implementation. 
""" def __init__( @@ -90,10 +69,15 @@ class A2aAgentExecutor(AgentExecutor): *, runner: Runner | Callable[..., Runner | Awaitable[Runner]], config: Optional[A2aAgentExecutorConfig] = None, + use_legacy: bool = True, ): super().__init__() - self._runner = runner - self._config = config or A2aAgentExecutorConfig() + if not use_legacy: + self._executor_impl = ExecutorImpl(runner=runner, config=config) + else: + self._executor_impl = None + self._runner = runner + self._config = config or A2aAgentExecutorConfig() async def _resolve_runner(self) -> Runner: """Resolve the runner, handling cases where it's a callable that returns a Runner.""" @@ -122,6 +106,10 @@ class A2aAgentExecutor(AgentExecutor): @override async def cancel(self, context: RequestContext, event_queue: EventQueue): """Cancel the execution.""" + if self._executor_impl: + await self._executor_impl.cancel(context, event_queue) + return + # TODO: Implement proper cancellation logic if needed raise NotImplementedError('Cancellation is not supported') @@ -132,6 +120,7 @@ class A2aAgentExecutor(AgentExecutor): event_queue: EventQueue, ): """Executes an A2A request and publishes updates to the event queue + specified. 
It runs as following: * Takes the input from the A2A request * Convert the input to ADK input content, and runs the ADK agent @@ -139,6 +128,10 @@ class A2aAgentExecutor(AgentExecutor): * Converts the ADK output events into A2A task updates * Publishes the updates back to A2A server via event queue """ + if self._executor_impl: + await self._executor_impl.execute(context, event_queue) + return + if not context.message: raise ValueError('A2A request must have a message') @@ -213,7 +206,7 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) - self._executor_context = ExecutorContext( + executor_context = ExecutorContext( app_name=runner.app_name, user_id=run_request.user_id, session_id=run_request.session_id, @@ -250,7 +243,7 @@ class A2aAgentExecutor(AgentExecutor): ): a2a_event = await execute_after_event_interceptors( a2a_event, - self._executor_context, + executor_context, adk_event, self._config.execute_interceptors, ) @@ -302,7 +295,7 @@ class A2aAgentExecutor(AgentExecutor): ) final_event = await execute_after_agent_interceptors( - self._executor_context, + executor_context, final_event, self._config.execute_interceptors, ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py new file mode 100644 index 00000000..cec68f36 --- /dev/null +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -0,0 +1,310 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import inspect +import logging +from typing import Awaitable +from typing import Callable +from typing import Optional +import uuid + +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from typing_extensions import override + +from ...runners import Runner +from ...utils.context_utils import Aclosing +from ..converters.from_adk_event import create_error_status_event +from ..converters.long_running_functions import handle_user_input +from ..converters.long_running_functions import LongRunningFunctions +from ..converters.request_converter import AgentRunRequest +from ..converters.utils import _get_adk_metadata_key +from ..experimental import a2a_experimental +from .config import A2aAgentExecutorConfig +from .executor_context import ExecutorContext +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors + +logger = logging.getLogger('google_adk.' + __name__) + + +@a2a_experimental +class _A2aAgentExecutor(AgentExecutor): + """An AgentExecutor that runs an ADK Agent against an A2A request and + + publishes updates to an event queue. 
+ """ + + def __init__( + self, + *, + runner: Runner | Callable[..., Runner | Awaitable[Runner]], + config: Optional[A2aAgentExecutorConfig] = None, + ): + super().__init__() + self._runner = runner + self._config = config or A2aAgentExecutorConfig() + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cancel the execution.""" + # TODO: Implement proper cancellation logic if needed + raise NotImplementedError('Cancellation is not supported') + + @override + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ): + """Executes an A2A request and publishes updates to the event queue + + specified. It runs as following: + * Takes the input from the A2A request + * Convert the input to ADK input content, and runs the ADK agent + * Collects output events of the underlying ADK Agent + * Converts the ADK output events into A2A task updates + * Publishes the updates back to A2A server via event queue + """ + if not context.message: + raise ValueError('A2A request must have a message') + + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + + runner = await self._resolve_runner() + try: + run_request = self._config.request_converter( + context, + self._config.a2a_part_converter, + ) + await self._resolve_session(run_request, runner) + + executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + Task( + id=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + history=[context.message], + metadata=self._get_invocation_metadata(executor_context), + ) + ) + else: + # Check if the user input is responding to the agent's + # request for input. 
+ missing_user_input_event = handle_user_input(context) + if missing_user_input_event: + missing_user_input_event.metadata = self._get_invocation_metadata( + executor_context + ) + await event_queue.enqueue_event(missing_user_input_event) + return + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + metadata=self._get_invocation_metadata(executor_context), + ) + ) + + # Handle the request and publish updates to the event queue + await self._handle_request( + context, + executor_context, + event_queue, + runner, + run_request, + ) + except Exception as e: + logger.error('Error handling A2A request: %s', e, exc_info=True) + # Publish failure event + try: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=str(e))], + ), + ), + context_id=context.context_id, + final=True, + ) + ) + except Exception as enqueue_error: + logger.error( + 'Failed to publish failure event: %s', enqueue_error, exc_info=True + ) + + async def _handle_request( + self, + context: RequestContext, + executor_context: ExecutorContext, + event_queue: EventQueue, + runner: Runner, + run_request: AgentRunRequest, + ): + agents_artifact: dict[str, str] = {} + error_event = None + long_running_functions = LongRunningFunctions( + self._config.gen_ai_part_converter + ) + async with Aclosing(runner.run_async(**vars(run_request))) as agen: + async for adk_event in agen: + # Handle error scenarios + if adk_event and (adk_event.error_code or adk_event.error_message): + error_event = create_error_status_event( + adk_event, + context.task_id, + context.context_id, + ) + + # Handle long running 
function calls + adk_event = long_running_functions.process_event(adk_event) + + for a2a_event in self._config.adk_event_converter( + adk_event, + agents_artifact, + context.task_id, + context.context_id, + self._config.gen_ai_part_converter, + ): + a2a_event.metadata = self._get_invocation_metadata(executor_context) + a2a_event = await execute_after_event_interceptors( + a2a_event, + executor_context, + adk_event, + self._config.execute_interceptors, + ) + if not a2a_event: + continue + await event_queue.enqueue_event(a2a_event) + + if error_event: + final_event = error_event + elif long_running_functions.has_long_running_function_calls(): + final_event = ( + long_running_functions.create_long_running_function_call_event( + context.task_id, context.context_id + ) + ) + else: + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, + ) + + final_event.metadata = self._get_invocation_metadata(executor_context) + final_event = await execute_after_agent_interceptors( + executor_context, final_event, self._config.execute_interceptors + ) + await event_queue.enqueue_event(final_event) + + async def _resolve_runner(self) -> Runner: + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + if isinstance(self._runner, Runner): + return self._runner + if callable(self._runner): + result = self._runner() + + if inspect.iscoroutine(result): + resolved_runner = await result + else: + resolved_runner = result + + self._runner = resolved_runner + return resolved_runner + + raise TypeError( + 'Runner must be a Runner instance or a callable that returns a' + f' Runner, got {type(self._runner)}' + ) + + async def _resolve_session( + self, + run_request: AgentRunRequest, + runner: Runner, + ): + session_id = run_request.session_id + # create a new session if not exists + user_id = 
run_request.user_id + session = await runner.session_service.get_session( + app_name=runner.app_name, + user_id=user_id, + session_id=session_id, + ) + if session is None: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + session_id=session_id, + ) + # Update run_request with the new session_id + run_request.session_id = session.id + + def _get_invocation_metadata( + self, executor_context: ExecutorContext + ) -> dict[str, str]: + return { + _get_adk_metadata_key('app_name'): executor_context.app_name, + _get_adk_metadata_key('user_id'): executor_context.user_id, + _get_adk_metadata_key('session_id'): executor_context.session_id, + # TODO: Remove this metadata once the new agent executor + # is fully adopted. + _get_adk_metadata_key('agent_executor_v2'): True, + } diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py index 79e88546..c083affd 100644 --- a/src/google/adk/a2a/executor/config.py +++ b/src/google/adk/a2a/executor/config.py @@ -23,9 +23,21 @@ from typing import Union from a2a.server.agent_execution.context import RequestContext from a2a.server.events import Event as A2AEvent from a2a.types import TaskStatusUpdateEvent +from pydantic import BaseModel from ...events.event import Event +from ..converters.event_converter import AdkEventToA2AEventsConverter +from ..converters.event_converter import convert_event_to_a2a_events as legacy_convert_event_to_a2a_events +from ..converters.from_adk_event import AdkEventToA2AEventsConverter as AdkEventToA2AEventsConverterImpl +from ..converters.from_adk_event import convert_event_to_a2a_events as convert_event_to_a2a_events_impl +from ..converters.part_converter import A2APartToGenAIPartConverter +from ..converters.part_converter import convert_a2a_part_to_genai_part +from ..converters.part_converter import convert_genai_part_to_a2a_part +from ..converters.part_converter import GenAIPartToA2APartConverter +from 
..converters.request_converter import A2ARequestToAgentRunRequestConverter
+from ..converters.request_converter import convert_a2a_request_to_agent_run_request
 from ..converters.utils import _get_adk_metadata_key
+from ..experimental import a2a_experimental
 from .executor_context import ExecutorContext
@@ -67,3 +79,29 @@ class ExecuteInterceptor:
     completed or failed) before it is enqueued. Must return a valid
     `TaskStatusUpdateEvent`.
   """
+
+
+@a2a_experimental
+class A2aAgentExecutorConfig(BaseModel):
+  """Configuration for the A2aAgentExecutor."""
+
+  a2a_part_converter: A2APartToGenAIPartConverter = (
+      convert_a2a_part_to_genai_part
+  )
+  gen_ai_part_converter: GenAIPartToA2APartConverter = (
+      convert_genai_part_to_a2a_part
+  )
+  request_converter: A2ARequestToAgentRunRequestConverter = (
+      convert_a2a_request_to_agent_run_request
+  )
+  event_converter: AdkEventToA2AEventsConverter = (
+      legacy_convert_event_to_a2a_events
+  )
+  """Set up the default event converter implementation to be used by the legacy agent executor implementation."""
+
+  adk_event_converter: AdkEventToA2AEventsConverterImpl = (
+      convert_event_to_a2a_events_impl
+  )
+  """Set up the impl event converter implementation to be used by the new agent executor implementation."""
+
+  execute_interceptors: Optional[list[ExecuteInterceptor]] = None
diff --git a/tests/unittests/a2a/converters/test_from_adk.py b/tests/unittests/a2a/converters/test_from_adk.py
new file mode 100644
index 00000000..23546c58
--- /dev/null
+++ b/tests/unittests/a2a/converters/test_from_adk.py
@@ -0,0 +1,108 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock +from unittest.mock import patch +import uuid + +from a2a.types import Part as A2APart +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events +from google.adk.events.event import Event +from google.genai import types as genai_types +import pytest + + +class TestFromAdk: + """Test suite for from_adk functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_event = Mock(spec=Event) + self.mock_event.id = "test-event-id" + self.mock_event.invocation_id = "test-invocation-id" + self.mock_event.author = "test-author" + self.mock_event.branch = None + self.mock_event.content = None + self.mock_event.error_code = None + self.mock_event.error_message = None + self.mock_event.grounding_metadata = None + self.mock_event.citation_metadata = None + self.mock_event.custom_metadata = None + self.mock_event.usage_metadata = None + self.mock_event.actions = None + self.mock_event.partial = True + self.mock_event.long_running_tool_ids = None + + def test_convert_event_to_a2a_events_artifact_update(self): + """Test conversion of event to TaskArtifactUpdateEvent.""" + # Setup event with content + self.mock_event.content = genai_types.Content( + parts=[genai_types.Part(text="hello")], role="model" + ) + self.mock_event.author = "agent-1" + + agents_artifacts = {} + + # Mock part converter to return a 
standard text part + mock_a2a_part = A2APart(root=TextPart(text="hello")) + mock_a2a_part.root.metadata = {} + mock_convert_part = Mock(return_value=[mock_a2a_part]) + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + part_converter=mock_convert_part, + ) + + assert len(result) == 1 + assert isinstance(result[0], TaskArtifactUpdateEvent) + assert result[0].task_id == "task-123" + assert result[0].context_id == "context-456" + assert result[0].artifact.parts == [mock_a2a_part] + assert "agent-1" in agents_artifacts # Artifact ID should be stored + + def test_convert_event_to_a2a_events_error(self): + """Test conversion of event with error to TaskStatusUpdateEvent.""" + self.mock_event.error_code = "ERR001" + self.mock_event.error_message = "Something went wrong" + + agents_artifacts = {} + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + ) + + # Should not return any artifact events + assert len(result) == 0 + + def test_convert_event_to_a2a_events_none_event(self): + """Test convert_event_to_a2a_events with None event.""" + with pytest.raises(ValueError, match="Event cannot be None"): + convert_event_to_a2a_events(None, {}) + + def test_convert_event_to_a2a_events_none_artifacts(self): + """Test convert_event_to_a2a_events with None agents_artifacts.""" + with pytest.raises(ValueError, match="Agents artifacts cannot be None"): + convert_event_to_a2a_events(self.mock_event, None) diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py new file mode 100644 index 00000000..9acae2dc --- /dev/null +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -0,0 +1,808 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.executor.a2a_agent_executor_impl import _A2aAgentExecutor as A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor_impl import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content +import pytest + + +class TestA2aAgentExecutor: + """Test suite for A2aAgentExecutor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_runner = Mock(spec=Runner) + self.mock_runner.app_name = "test-app" + self.mock_runner.session_service = Mock() + self.mock_runner._new_invocation_context = Mock() + self.mock_runner.run_async = AsyncMock() + + self.mock_a2a_part_converter = Mock() + self.mock_gen_ai_part_converter = Mock() + self.mock_request_converter = 
Mock() + self.mock_event_converter = Mock() + self.mock_config = A2aAgentExecutorConfig( + a2a_part_converter=self.mock_a2a_part_converter, + gen_ai_part_converter=self.mock_gen_ai_part_converter, + request_converter=self.mock_request_converter, + adk_event_converter=self.mock_event_converter, + ) + self.executor = A2aAgentExecutor( + runner=self.mock_runner, config=self.mock_config + ) + + self.mock_context = Mock(spec=RequestContext) + self.mock_context.message = Mock(spec=Message) + self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.current_task = None + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_event_queue = Mock(spec=EventQueue) + + self.expected_metadata = { + _get_adk_metadata_key("app_name"): "test-app", + _get_adk_metadata_key("user_id"): "test-user", + _get_adk_metadata_key("session_id"): "test-session", + _get_adk_metadata_key("agent_executor_v2"): True, + } + + async def _create_async_generator(self, items): + """Helper to create async generator from items.""" + for item in items: + yield item + + @pytest.mark.asyncio + async def test_execute_success_new_task(self): + """Test successful execution of a new task.""" + # Setup + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = 
mock_run_async + + # Mock event converter to return a working status update + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) + + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + {}, # agents_artifact (initially empty) + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) + + # Verify task submitted event was enqueued + # call 0: submitted + # call 1: working (from converter) + # call 2: completed (final) + assert self.mock_event_queue.enqueue_event.call_count >= 3 + + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert isinstance(submitted_event, Task) + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.metadata == self.expected_metadata + + # Verify working event was enqueued + enqueued_working_event = self.mock_event_queue.enqueue_event.call_args_list[ + 1 + ][0][0] + assert isinstance(enqueued_working_event, TaskStatusUpdateEvent) + assert enqueued_working_event.status.state == TaskState.working + assert enqueued_working_event.metadata == self.expected_metadata + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[2][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event was enqueued + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final 
== True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_execute_no_message_error(self): + """Test execution fails when no message is provided.""" + self.mock_context.message = None + + with pytest.raises(ValueError, match="A2A request must have a message"): + await self.executor.execute(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_existing_task(self): + """Test execution with existing task (no submitted event).""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "existing-task-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="existing-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify submitted event was NOT enqueued for existing task + # So we check first event is working state + first_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert 
isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + assert first_event.metadata == self.expected_metadata + + # Verify manual working event is FIRST + assert isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + def test_constructor_with_callable_runner(self): + """Test constructor with callable runner.""" + callable_runner = Mock() + executor = A2aAgentExecutor(runner=callable_runner, config=self.mock_config) + + assert executor._runner == callable_runner + assert executor._config == self.mock_config + + @pytest.mark.asyncio + async def test_resolve_runner_direct_instance(self): + """Test _resolve_runner with direct Runner instance.""" + # Setup - already using direct runner instance in setup_method + runner = await self.executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_sync_callable(self): + """Test _resolve_runner with sync callable that returns Runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_callable(self): + """Test _resolve_runner with async callable that returns Runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, 
config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_invalid_type(self): + """Test _resolve_runner with invalid runner type.""" + executor = A2aAgentExecutor(runner="invalid", config=self.mock_config) + + with pytest.raises( + TypeError, match="Runner must be a Runner instance or a callable" + ): + await executor._resolve_runner() + + @pytest.mark.asyncio + async def test_handle_request_integration(self): + """Test the complete request handling flow.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [ + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + ] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return events + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Initialize executor context attributes as they would be in execute() + self.executor._invocation_metadata = {} + 
self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify events enqueued + # Should check for working events + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + # Each ADK event generates 1 working event in this mock setup + assert len(working_events) >= len(mock_events) + + # Verify final event is completed + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_cancel_with_task_id(self): + """Test cancellation with a task ID.""" + self.mock_context.task_id = "test-task-id" + + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_with_exception_handling(self): + """Test execution with exception handling.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.current_task = None + + self.mock_request_converter.side_effect = Exception("Test error") + + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True + assert "Test error" in failure_event.status.message.parts[0].root.text + + @pytest.mark.asyncio + async def 
test_handle_request_with_non_working_state(self): + """Test handle request when a non-working state is encountered.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Mock agent run event + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = "ERROR" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return a FAILED event + failed_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.failed, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [failed_event] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Initialize executor context attributes + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + # Verify final event is FAILED, not COMPLETED + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + # The last event should be the synthesized final event + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + + @pytest.mark.asyncio + async def test_handle_request_with_error_message(self): + """Test handle request when an error message is present without an error code.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # 
Mock agent run event with only error_message + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = None + mock_event.error_message = "Test Error Message" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + executor_context = Mock() + executor_context.app_name = "test-app" + executor_context.user_id = "test-user" + executor_context.session_id = "test-session" + + await self.executor._handle_request( + self.mock_context, + executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_interceptors(self): + """Test interceptors execution.""" + # Setup interceptors + before_interceptor = AsyncMock(return_value=self.mock_context) + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: a2a + after_agent_interceptor = AsyncMock() + after_agent_interceptor.side_effect = lambda ctx, event: event + + interceptor = ExecuteInterceptor( + before_agent=before_interceptor, + after_event=after_event_interceptor, + after_agent=after_agent_interceptor, + ) + + self.mock_config.execute_interceptors = [interceptor] + + # Mock run + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + 
) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Pre-setup request converter + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify interceptors called + before_interceptor.assert_called_once_with(self.mock_context) + # after_event called for each event + assert after_event_interceptor.call_count >= 1 + after_agent_interceptor.assert_called_once() + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.handle_user_input") + async def test_execute_missing_user_input(self, mock_handle_user_input): + """Test when handle_user_input returns a missing user input event.""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up handle_user_input to return an event + missing_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_handle_user_input.return_value = missing_event + + self.mock_runner.session_service.get_session = AsyncMock( + return_value=Mock(id="test-session") + ) + self.mock_request_converter.return_value = AgentRunRequest( + 
user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify that the missing_event was enqueued + self.mock_event_queue.enqueue_event.assert_called_once_with(missing_event) + + # Verify that metadata was injected + enqueued_event = self.mock_event_queue.enqueue_event.call_args[0][0] + assert enqueued_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_resolve_session_creates_new_session(self): + """Test that _resolve_session creates a new session if it doesn't exist.""" + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) + + new_session = Mock() + new_session.id = "new-session-id" + self.mock_runner.session_service.create_session = AsyncMock( + return_value=new_session + ) + + run_request = AgentRunRequest( + user_id="test-user", + session_id="old-session-id", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + await self.executor._resolve_session(run_request, self.mock_runner) + + self.mock_runner.session_service.get_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + session_id="old-session-id", + ) + self.mock_runner.session_service.create_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + state={}, + session_id="old-session-id", + ) + assert run_request.session_id == "new-session-id" + + @pytest.mark.asyncio + async def test_execute_enqueue_error_in_exception_handler(self): + """Test failure event publishing handles exception during enqueue.""" + self.mock_context.task_id = "test-task-id" + self.mock_request_converter.side_effect = Exception("Test error") + + # Make enqueue_event raise an exception + self.mock_event_queue.enqueue_event.side_effect = Exception("Enqueue error") + + # This should not raise an exception itself + await 
self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify enqueue_event was called to publish the error event + assert self.mock_event_queue.enqueue_event.call_count == 1 + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.LongRunningFunctions") + async def test_long_running_functions_final_event(self, mock_lrf_class): + """Test _handle_request when there are long running function calls.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up mock LongRunningFunctions + mock_lrf = mock_lrf_class.return_value + mock_lrf.process_event.side_effect = lambda e: e + mock_lrf.has_long_running_function_calls.return_value = True + + lrf_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_lrf.create_long_running_function_call_event.return_value = lrf_event + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify final event is the long running function 
call event + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if call[0][0] == lrf_event + ] + assert len(final_events) >= 1 + + @pytest.mark.asyncio + async def test_after_event_interceptor_returns_none(self): + """Test after_event_interceptor returning None drops the event.""" + # Setup interceptor returning None + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: None + + interceptor = ExecuteInterceptor( + after_event=after_event_interceptor, + ) + self.mock_config.execute_interceptors = [interceptor] + + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Event converter returns one event + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + self.executor._executor_context = Mock() + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Since the interceptor returns None, working_event should NOT be enqueued + # The only event enqueued by _handle_request should be the final event + assert self.mock_event_queue.enqueue_event.call_count == 1 + final_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert 
final_event.status.state == TaskState.completed From 82c2eefb27313c5b11b9e9382f626f543c53a29e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 05:19:58 -0800 Subject: [PATCH 087/102] feat: add Dataplex Catalog search tool to BigQuery ADK Previous rollback CL - cl/872951141 This change introduces a new search_catalog tool within the BigQuery toolset, enabling users to search for BigQuery assets across projects using the Dataplex Catalog API. Key changes include: - Adding google-cloud-dataplex as a dependency in pyproject.toml. - Updating BigQuery credentials to include the Dataplex scope. - Implementing get_dataplex_catalog_client in client.py to create Dataplex API clients. - Creating search_tool.py with the search_catalog function, which constructs and executes Dataplex search queries. - Adding extensive unit tests for the new Dataplex client and the search_catalog tool, covering various scenarios including query filtering and error handling. - Updating the BigQuery toolset to include the new search_catalog tool. - Updating the BigQuery samples README to mention the new tool. 
PiperOrigin-RevId: 878435463 --- contributing/samples/bigquery/README.md | 4 + pyproject.toml | 1 + .../tools/bigquery/bigquery_credentials.py | 8 +- .../adk/tools/bigquery/bigquery_toolset.py | 2 + src/google/adk/tools/bigquery/client.py | 45 +- src/google/adk/tools/bigquery/search_tool.py | 179 +++++++ .../tools/bigquery/test_bigquery_client.py | 75 +++ .../bigquery/test_bigquery_credentials.py | 16 +- .../bigquery/test_bigquery_search_tool.py | 448 ++++++++++++++++++ .../tools/bigquery/test_bigquery_toolset.py | 3 +- 10 files changed, 768 insertions(+), 13 deletions(-) create mode 100644 src/google/adk/tools/bigquery/search_tool.py create mode 100644 tests/unittests/tools/bigquery/test_bigquery_search_tool.py diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 3ed97432..99481390 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. +11. `search_catalog` + Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. + ## How to use Set up environment variables in your `.env` file for using @@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf. * which tables exist in the ml_datasets dataset? * show more details about the penguins table * compute penguins population per island. +* are there any tables related to animals in project ? 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d0f3cd94..83b1a3f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery-storage>=2.0.0", "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database + "google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index fa23c74c..958ce9d7 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,6 +19,10 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" +BIGQUERY_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", +] BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_DEFAULT_SCOPE - + self.scopes = BIGQUERY_SCOPES + # Set the token cache key self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY return self diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 1a748b71..dba5f8ee 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,6 +24,7 @@ from typing_extensions import override from . import data_insights_tool from . import metadata_tool from . import query_tool +from . 
import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset): query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, + search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d57c0c80..2cb4e67c 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -14,19 +14,22 @@ from __future__ import annotations +from typing import List from typing import Optional +from typing import Union import google.api_core.client_info +from google.api_core.gapic_v1 import client_info as gapic_client_info from google.auth.credentials import Credentials from google.cloud import bigquery +from google.cloud import dataplex_v1 from ... import version -USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" - - -from typing import List -from typing import Union +USER_AGENT_BASE = f"google-adk/{version.__version__}" +BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}" +DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}" +USER_AGENT = BQ_USER_AGENT def get_bigquery_client( @@ -48,7 +51,7 @@ def get_bigquery_client( A BigQuery client. """ - user_agents = [USER_AGENT] + user_agents = [BQ_USER_AGENT] if user_agent: if isinstance(user_agent, str): user_agents.append(user_agent) @@ -67,3 +70,33 @@ def get_bigquery_client( ) return bigquery_client + + +def get_dataplex_catalog_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> dataplex_v1.CatalogServiceClient: + """Get a Dataplex CatalogServiceClient with minimal necessary arguments. + + Args: + credentials: The credentials to use for the request. + user_agent: Additional user agent string(s) to append. + + Returns: + A Dataplex Client. 
+ """ + + user_agents = [DP_USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents)) + + return dataplex_v1.CatalogServiceClient( + credentials=credentials, + client_info=client_info, + ) diff --git a/src/google/adk/tools/bigquery/search_tool.py b/src/google/adk/tools/bigquery/search_tool.py new file mode 100644 index 00000000..0bf01d5a --- /dev/null +++ b/src/google/adk/tools/bigquery/search_tool.py @@ -0,0 +1,179 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any + +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + +from . 
import client +from .config import BigQueryToolConfig + + +def _construct_search_query_helper( + predicate: str, operator: str, items: list[str] +) -> str: + """Constructs a search query part for a specific predicate and items.""" + if not items: + return "" + + clauses = [f'{predicate}{operator}"{item}"' for item in items] + return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0] + + +def search_catalog( + prompt: str, + project_id: str, + *, + credentials: Credentials, + settings: BigQueryToolConfig, + location: str | None = None, + page_size: int = 10, + project_ids_filter: list[str] | None = None, + dataset_ids_filter: list[str] | None = None, + types_filter: list[str] | None = None, +) -> dict[str, Any]: + """Searches for BigQuery assets within Dataplex. + + Args: + prompt: The base search query (natural language or keywords). + project_id: The Google Cloud project ID to scope the search. + credentials: Credentials for the request. + settings: BigQuery tool settings. + location: The Dataplex location to use. + page_size: Maximum number of results. + project_ids_filter: Specific project IDs to include in the search results. + If None, defaults to the scoping project_id. + dataset_ids_filter: BigQuery dataset IDs to filter by. + types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE, + BigQueryEntryType.DATASET). + + Returns: + Search results or error. The "results" list contains items with: + - name: The Dataplex Entry name (e.g., + "projects/p/locations/l/entryGroups/g/entries/e"). + - linked_resource: The underlying BigQuery resource name (e.g., + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t"). + - display_name, entry_type, description, location, update_time. + + Examples: + Search for tables related to customer data: + + >>> search_catalog( + ... prompt="Search for tables related to customer data", + ... project_id="my-project", + ... credentials=creds, + ... settings=settings + ... 
) + { + "status": "SUCCESS", + "results": [ + { + "name": + "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id", + "display_name": "customer_table", + "entry_type": + "projects/p/locations/l/entryTypes/bigquery-table", + "linked_resource": + "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table", + "description": "Table containing customer details.", + "location": "us", + "update_time": "2024-01-01 12:00:00+00:00" + } + ] + } + """ + + try: + if not project_id: + return { + "status": "ERROR", + "error_details": "project_id must be provided.", + } + + with client.get_dataplex_catalog_client( + credentials=credentials, + user_agent=[settings.application_name, "search_catalog"], + ) as dataplex_client: + query_parts = [] + if prompt: + query_parts.append(f"({prompt})") + + # Filter by project IDs + projects_to_filter = ( + project_ids_filter if project_ids_filter else [project_id] + ) + if projects_to_filter: + query_parts.append( + _construct_search_query_helper("projectid", "=", projects_to_filter) + ) + + # Filter by dataset IDs + if dataset_ids_filter: + dataset_resource_filters = [] + for pid in projects_to_filter: + for did in dataset_ids_filter: + dataset_resource_filters.append( + f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"' + ) + if dataset_resource_filters: + query_parts.append(f"({' OR '.join(dataset_resource_filters)})") + # Filter by entry types + if types_filter: + query_parts.append( + _construct_search_query_helper("type", "=", types_filter) + ) + + # Always scope to BigQuery system + query_parts.append("system=BIGQUERY") + + full_query = " AND ".join(filter(None, query_parts)) + + search_location = location or settings.location or "global" + search_scope = f"projects/{project_id}/locations/{search_location}" + + request = dataplex_v1.SearchEntriesRequest( + name=search_scope, + query=full_query, + page_size=page_size, + semantic_search=True, + ) + + response = 
dataplex_client.search_entries(request=request) + + results = [] + for result in response.results: + entry = result.dataplex_entry + source = entry.entry_source + results.append({ + "name": entry.name, + "display_name": source.display_name or "", + "entry_type": entry.entry_type, + "update_time": str(entry.update_time), + "linked_resource": source.resource or "", + "description": source.description or "", + "location": source.location or "", + }) + return {"status": "SUCCESS", "results": results} + + except api_exceptions.GoogleAPICallError as e: + logging.exception("search_catalog tool: API call failed") + return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"} + except Exception as e: + logging.exception("search_catalog tool: Unexpected error") + return {"status": "ERROR", "error_details": repr(e)} diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index 80a97f8f..d8d5e726 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -18,9 +18,13 @@ import os from unittest import mock import google.adk +from google.adk.tools.bigquery.client import DP_USER_AGENT from google.adk.tools.bigquery.client import get_bigquery_client +from google.adk.tools.bigquery.client import get_dataplex_catalog_client +from google.api_core.gapic_v1 import client_info as gapic_client_info import google.auth from google.auth.exceptions import DefaultCredentialsError +from google.cloud import dataplex_v1 from google.cloud.bigquery import client as bigquery_client from google.oauth2.credentials import Credentials @@ -201,3 +205,74 @@ def test_bigquery_client_location_custom(): # Verify that the client has the desired project set assert client.project == "test-gcp-project" assert client.location == "us-central1" + + +# Tests for Dataplex Catalog Client +# 
------------------------------------------------------------------------------ + + +# Mock the CatalogServiceClient class directly +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_default(mock_catalog_service_client): + """Test get_dataplex_catalog_client with default user agent.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + + client = get_dataplex_catalog_client(credentials=mock_creds) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + + assert kwargs["credentials"] == mock_creds + client_info = kwargs["client_info"] + assert isinstance(client_info, gapic_client_info.ClientInfo) + assert client_info.user_agent == DP_USER_AGENT + + # Ensure the function returns the mock instance + assert client == mock_catalog_service_client.return_value + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent string.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua = "catalog_ua/1.0" + expected_ua = f"{DP_USER_AGENT} {custom_ua}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent list.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} {' '.join(custom_ua_list)}" + + get_dataplex_catalog_client(credentials=mock_creds, 
user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list_with_none( + mock_catalog_service_client, +): + """Test get_dataplex_catalog_client with a list containing None.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", None, "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9cf8c9e4..e2066292 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -44,9 +44,11 @@ class TestBigQueryCredentials: # Verify that the credentials are properly stored and attributes are extracted assert config.credentials == auth_creds - assert config.client_id is None assert config.client_secret is None - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -86,7 +88,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == 
["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. @@ -128,7 +133,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py new file mode 100644 index 00000000..0ccdc9e1 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py @@ -0,0 +1,448 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import sys +from typing import Any +import unittest +from unittest import mock + +from absl.testing import parameterized + +# Mock google.genai and pydantic if not available, before importing google.adk modules +try: + import google.genai +except ImportError: + m = mock.MagicMock() + m.__path__ = [] + sys.modules["google.genai"] = m + sys.modules["google.genai.types"] = mock.MagicMock() + sys.modules["google.genai.errors"] = mock.MagicMock() + +try: + import pydantic +except ImportError: + m_pydantic = mock.MagicMock() + + class MockBaseModel: + pass + + m_pydantic.BaseModel = MockBaseModel + sys.modules["pydantic"] = m_pydantic + +try: + import fastapi + import fastapi.openapi.models +except ImportError: + m_fastapi = mock.MagicMock() + m_fastapi.openapi.models = mock.MagicMock() + sys.modules["fastapi"] = m_fastapi + sys.modules["fastapi.openapi"] = mock.MagicMock() + sys.modules["fastapi.openapi.models"] = mock.MagicMock() + + +from google.adk.tools.bigquery import search_tool +from google.adk.tools.bigquery.config import BigQueryToolConfig +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + + +def _mock_creds(): + return mock.create_autospec(Credentials, instance=True) + + +def _mock_settings(app_name: str | None = "test-app"): + return BigQueryToolConfig(application_name=app_name) + + +def _mock_search_entries_response(results: list[dict[str, Any]]): + mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse) + mock_results = [] + for r in results: + mock_result = mock.create_autospec( + dataplex_v1.SearchEntriesResult, instance=True + ) + # Manually attach dataplex_entry since it's not visible in dir() of the proto class + mock_entry = mock.create_autospec(dataplex_v1.Entry, instance=True) + mock_result.dataplex_entry = mock_entry + + mock_entry.name = r.get("name") + mock_entry.entry_type = r.get("entry_type") + 
mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z") + + # Manually attach entry_source since it's not visible in dir() of the proto class + mock_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True) + mock_entry.entry_source = mock_source + + mock_source.display_name = r.get("display_name") + mock_source.resource = r.get("linked_resource") + mock_source.description = r.get("description") + mock_source.location = r.get("location") + mock_results.append(mock_result) + mock_response.results = mock_results + return mock_response + + +class TestSearchCatalog(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_dataplex_client = mock.create_autospec( + dataplex_v1.CatalogServiceClient, instance=True + ) + + # Patch get_dataplex_catalog_client + self.mock_get_dataplex_client = self.enter_context( + mock.patch( + "google.adk.tools.bigquery.client.get_dataplex_catalog_client", + autospec=True, + ) + ) + self.mock_get_dataplex_client.return_value = self.mock_dataplex_client + self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client + + # Patch SearchEntriesRequest + self.mock_search_request = self.enter_context( + mock.patch( + "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True + ) + ) + + def test_search_catalog_success(self): + """Test the successful path of search_catalog.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "customer data" + project_id = "test-project" + location = "us" + + mock_api_results = [{ + "name": "entry1", + "entry_type": "TABLE", + "display_name": "Cust Table", + "linked_resource": ( + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1" + ), + "description": "Table 1", + "location": "us", + }] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + 
location=location, + ) + + with self.subTest("Test result content"): + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 1) + self.assertEqual(result["results"][0]["name"], "entry1") + self.assertEqual(result["results"][0]["display_name"], "Cust Table") + + with self.subTest("Test mock calls"): + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=["test-app", "search_catalog"] + ) + + expected_query = ( + '(customer data) AND projectid="test-project" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/us", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once_with( + request=self.mock_search_request.return_value + ) + + def test_search_catalog_no_project_id(self): + """Test search_catalog with missing project_id.""" + result = search_tool.search_catalog( + prompt="test", + project_id="", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("project_id must be provided", result["error_details"]) + self.mock_get_dataplex_client.assert_not_called() + + def test_search_catalog_api_error(self): + """Test search_catalog handling API exceptions.""" + self.mock_dataplex_client.search_entries.side_effect = ( + api_exceptions.BadRequest("Invalid query") + ) + + result = search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn( + "Dataplex API Error: 400 Invalid query", result["error_details"] + ) + + def test_search_catalog_other_exception(self): + """Test search_catalog handling unexpected exceptions.""" + self.mock_get_dataplex_client.side_effect = Exception( + "Something went wrong" + ) + + result = 
search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("Something went wrong", result["error_details"]) + + @parameterized.named_parameters( + ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'), + ( + "multi_project_filter", + "p", + ["p1", "p2"], + None, + None, + '(projectid="p1" OR projectid="p2")', + ), + ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'), + ( + "multi_type_filter", + "p", + None, + None, + ["TABLE", "DATASET"], + '(type="TABLE" OR type="DATASET")', + ), + ( + "project_and_dataset_filters", + "inventory", + ["proj1", "proj2"], + ["dsetA"], + None, + ( + '(projectid="proj1" OR projectid="proj2") AND' + ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"' + ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")' + ), + ), + ) + def test_search_catalog_query_construction( + self, prompt, project_ids, dataset_ids, types, expected_query_part + ): + """Test different query constructions based on filters.""" + search_tool.search_catalog( + prompt=prompt, + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + project_ids_filter=project_ids, + dataset_ids_filter=dataset_ids, + types_filter=types, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + query = kwargs["query"] + + if prompt: + assert f"({prompt})" in query + assert "system=BIGQUERY" in query + assert expected_query_part in query + + def test_search_catalog_no_app_name(self): + """Test search_catalog when settings.application_name is None.""" + creds = _mock_creds() + settings = _mock_settings(app_name=None) + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + location="us", + ) + + 
self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=[None, "search_catalog"] + ) + + def test_search_catalog_multi_project_filter_semantic(self): + """Test semantic search with a multi-project filter.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "What datasets store user profiles?" + project_id = "main-project" + project_filters = ["user-data-proj", "shared-infra-proj"] + location = "global" + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + project_ids_filter=project_filters, + types_filter=["DATASET"], + ) + + expected_query = ( + f"({prompt}) AND " + '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND ' + 'type="DATASET" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + def test_search_catalog_natural_language_semantic(self): + """Test natural language prompts with semantic search enabled and check output.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "Find tables about football matches" + project_id = "sports-analytics" + location = "europe-west1" + + # Mock the results that the API would return for this semantic query + mock_api_results = [ + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1" + ), + "display_name": "uk_football_premiership", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership" + ), + "description": "Stats for UK Premier League matches.", + "location": 
"europe-west1", + }, + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2" + ), + "display_name": "serie_a_matches", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a" + ), + "description": "Italian Serie A football results.", + "location": "europe-west1", + }, + ] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + ) + + with self.subTest("Query Construction"): + # Assert the request was made as expected + expected_query = ( + f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + with self.subTest("Response Processing"): + # Assert the output is processed correctly + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 2) + self.assertEqual( + result["results"][0]["display_name"], "uk_football_premiership" + ) + self.assertEqual(result["results"][1]["display_name"], "serie_a_matches") + self.assertIn("UK Premier League", result["results"][0]["description"]) + + def test_search_catalog_default_location(self): + """Test search_catalog fallback to global location when None is provided.""" + creds = _mock_creds() + settings = _mock_settings() + # settings.location is None by default + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + 
settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/global", name_arg) + + def test_search_catalog_settings_location(self): + """Test search_catalog uses settings.location when provided.""" + creds = _mock_creds() + settings = BigQueryToolConfig(location="eu") + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/eu", name_arg) diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index f1f73aa6..0eced4b1 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 10 + assert len(tools) == 11 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default(): "forecast", "analyze_contribution", "detect_anomalies", + "search_catalog", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names From 6770e419f5e200f4c7ad26587e1f769693ef4da0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 07:43:18 -0800 Subject: [PATCH 088/102] feat: New implementation of RemoteA2aAgent and A2A-ADK conversion This change introduces compatibility of the remote_a2a_agent with the new a2a_agent_executor. 
New Event Converters: `to_adk_event.py`= Defines a new set of default converters for transforming A2A types (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) into ADK Event objects. Configurable Remote Agent: The A2aRemoteAgentConfig object allows users to override the default event converters with custom ones. New AgentExecutor Compatibility: RemoteA2aAgent now checks for `agent_executor_v2` metadata in A2A responses. If detected, it delegates response handling to a new `_handle_a2a_response_impl` method, which utilizes the modular converters defined in the configuration. PiperOrigin-RevId: 878487448 --- src/google/adk/a2a/agent/config.py | 38 +- src/google/adk/a2a/converters/to_adk_event.py | 374 ++++++++++++++++++ src/google/adk/agents/remote_a2a_agent.py | 87 +++- .../a2a/converters/test_event_round_trip.py | 208 ++++++++++ tests/unittests/a2a/converters/test_to_adk.py | 195 +++++++++ .../unittests/agents/test_remote_a2a_agent.py | 238 ++++++++++- 6 files changed, 1134 insertions(+), 6 deletions(-) create mode 100644 src/google/adk/a2a/converters/to_adk_event.py create mode 100644 tests/unittests/a2a/converters/test_event_round_trip.py create mode 100644 tests/unittests/a2a/converters/test_to_adk.py diff --git a/src/google/adk/a2a/agent/config.py b/src/google/adk/a2a/agent/config.py index e8f012cf..98984362 100644 --- a/src/google/adk/a2a/agent/config.py +++ b/src/google/adk/a2a/agent/config.py @@ -25,9 +25,18 @@ from typing import Union from a2a.client.middleware import ClientCallContext from a2a.server.events import Event as A2AEvent from a2a.types import Message as A2AMessage -from a2a.types import MessageSendConfiguration from pydantic import BaseModel +from ...a2a.converters.part_converter import A2APartToGenAIPartConverter +from ...a2a.converters.part_converter import convert_a2a_part_to_genai_part +from ...a2a.converters.to_adk_event import A2AArtifactUpdateToEventConverter +from ...a2a.converters.to_adk_event import 
A2AMessageToEventConverter +from ...a2a.converters.to_adk_event import A2AStatusUpdateToEventConverter +from ...a2a.converters.to_adk_event import A2ATaskToEventConverter +from ...a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from ...a2a.converters.to_adk_event import convert_a2a_message_to_event +from ...a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from ...a2a.converters.to_adk_event import convert_a2a_task_to_event from ...agents.invocation_context import InvocationContext from ...events.event import Event @@ -71,6 +80,31 @@ class RequestInterceptor(BaseModel): class A2aRemoteAgentConfig(BaseModel): - """Configuration for the RemoteA2aAgent.""" + """Configuration for A2A remote agents.""" + + # Converts standard A2A Messages into ADK Event. + a2a_message_converter: A2AMessageToEventConverter = ( + convert_a2a_message_to_event + ) + + # Converts an A2A Task into an ADK Event. + a2a_task_converter: A2ATaskToEventConverter = convert_a2a_task_to_event + + # Converts A2A TaskStatusUpdateEvents into ADK Event. + a2a_status_update_converter: A2AStatusUpdateToEventConverter = ( + convert_a2a_status_update_to_event + ) + + # Converts A2A TaskArtifactUpdateEvents into ADK Event. + a2a_artifact_update_converter: A2AArtifactUpdateToEventConverter = ( + convert_a2a_artifact_update_to_event + ) + + # A low-level hook that converts individual A2A Message Parts + # into native ADK/GenAI Part objects. + # This is utilized internally by the other converters. 
+ a2a_part_converter: A2APartToGenAIPartConverter = ( + convert_a2a_part_to_genai_part + ) request_interceptors: Optional[list[RequestInterceptor]] = None diff --git a/src/google/adk/a2a/converters/to_adk_event.py b/src/google/adk/a2a/converters/to_adk_event.py new file mode 100644 index 00000000..66d7768e --- /dev/null +++ b/src/google/adk/a2a/converters/to_adk_event.py @@ -0,0 +1,374 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +import logging +from typing import Any +from typing import List +from typing import Optional +import uuid + +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Task +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from google.genai import types as genai_types + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from ..experimental import a2a_experimental +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2APartToGenAIPartConverter +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _get_adk_metadata_key + +# Logger +logger = logging.getLogger("google_adk." 
+ __name__) + +A2AMessageToEventConverter = Callable[ + [ + Message, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A Message to an ADK Event. + +Args: + Message: The A2A message to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2ATaskToEventConverter = Callable[ + [ + Task, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A Task to an ADK Event. + +Args: + Task: The A2A task to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2AStatusUpdateToEventConverter = Callable[ + [ + TaskStatusUpdateEvent, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A TaskStatusUpdateEvent to an ADK Event. + +Args: + TaskStatusUpdateEvent: The A2A status update event to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2AArtifactUpdateToEventConverter = Callable[ + [ + TaskArtifactUpdateEvent, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A TaskArtifactUpdateEvent to an ADK Event. + +Args: + TaskArtifactUpdateEvent: The A2A artifact update event to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. 
+ A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + + +def _convert_a2a_parts_to_adk_parts( + a2a_parts: List[A2APart], + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> tuple[List[genai_types.Part], set[str]]: + """Converts a list of A2A parts to a list of ADK parts.""" + output_parts = [] + long_running_function_ids = set() + + for a2a_part in a2a_parts: + try: + parts = part_converter(a2a_part) + if not isinstance(parts, list): + parts = [parts] if parts else [] + if not parts: + logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) + continue + + # Check for long-running functions + if ( + a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + ): + for part in parts: + if part.function_call: + long_running_function_ids.add(part.function_call.id) + + output_parts.extend(parts) + + except Exception as e: + logger.error("Failed to convert A2A part: %s, error: %s", a2a_part, e) + # Continue processing other parts instead of failing completely + continue + + if not output_parts: + logger.warning("No parts could be converted from A2A message") + + return output_parts, long_running_function_ids + + +def _create_event( + output_parts: List[genai_types.Part], + invocation_context: Optional[InvocationContext], + author: Optional[str], + long_running_function_ids: Optional[set[str]] = None, + partial: bool = False, +) -> Optional[Event]: + """Creates an ADK event from parts and metadata.""" + if not output_parts: + return None + + event = Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + long_running_tool_ids=( + long_running_function_ids if long_running_function_ids else None + ), + 
content=genai_types.Content( + role="model", + parts=output_parts, + ), + partial=partial, + ) + + return event + + +@a2a_experimental +def convert_a2a_task_to_event( + a2a_task: Task, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task to an ADK event. + + Args: + a2a_task: The A2A task to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted task. + + Raises: + ValueError: If a2a_task is None. + RuntimeError: If conversion of the underlying message fails. + """ + if a2a_task is None: + raise ValueError("A2A task cannot be None") + + try: + output_parts = [] + long_running_function_ids = set() + if a2a_task.artifacts: + artifact_parts = [ + part for artifact in a2a_task.artifacts for part in artifact.parts + ] + output_parts, _ = _convert_a2a_parts_to_adk_parts( + artifact_parts, part_converter + ) + if ( + a2a_task.status.message + and a2a_task.status.state == TaskState.input_required + ): + parts, ids = _convert_a2a_parts_to_adk_parts( + a2a_task.status.message.parts, part_converter + ) + output_parts.extend(parts) + long_running_function_ids.update(ids) + + return _create_event( + output_parts, + invocation_context, + author, + long_running_function_ids, + ) + + except Exception as e: + logger.error("Failed to convert A2A task to event: %s", e) + raise + + +@a2a_experimental +def convert_a2a_message_to_event( + a2a_message: Message, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = 
convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A message to an ADK event. + + Args: + a2a_message: The A2A message to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object with converted content and long-running function + metadata. + + Raises: + ValueError: If a2a_message is None. + RuntimeError: If conversion of message parts fails. + """ + if a2a_message is None: + raise ValueError("A2A message cannot be None") + + try: + output_parts, _ = _convert_a2a_parts_to_adk_parts( + a2a_message.parts, part_converter + ) + return _create_event(output_parts, invocation_context, author) + + except Exception as e: + logger.error("Failed to convert A2A message to event: %s", e) + raise RuntimeError(f"Failed to convert message: {e}") from e + + +@a2a_experimental +def convert_a2a_status_update_to_event( + a2a_status_update: TaskStatusUpdateEvent, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task status update to an ADK event. + + Args: + a2a_status_update: The A2A task status update to convert. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted status update. 
+ """ + if a2a_status_update is None: + raise ValueError("A2A status update cannot be None") + + try: + output_parts = [] + long_running_function_ids = set() + if a2a_status_update.status.message: + parts, ids = _convert_a2a_parts_to_adk_parts( + a2a_status_update.status.message.parts, part_converter + ) + output_parts.extend(parts) + long_running_function_ids.update(ids) + + return _create_event( + output_parts, + invocation_context, + author, + long_running_function_ids, + ) + except Exception as e: + logger.error("Failed to convert A2A status update to event: %s", e) + raise RuntimeError(f"Failed to convert status update: {e}") from e + + +# TODO: Add support for non-ADK Artifact Updates. +@a2a_experimental +def convert_a2a_artifact_update_to_event( + a2a_artifact_update: TaskArtifactUpdateEvent, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task artifact update to an ADK event. + + Args: + a2a_artifact_update: The A2A task artifact update to convert. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted artifact update. 
+ """ + if a2a_artifact_update is None: + raise ValueError("A2A artifact update cannot be None") + + try: + output_parts, _ = _convert_a2a_parts_to_adk_parts( + a2a_artifact_update.artifact.parts, part_converter + ) + return _create_event( + output_parts, + invocation_context, + author, + partial=not a2a_artifact_update.last_chunk, + ) + except Exception as e: + logger.error("Failed to convert A2A artifact update to event: %s", e) + raise RuntimeError(f"Failed to convert artifact update: {e}") from e diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 5ffd123f..9b3a7b22 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -38,6 +38,7 @@ from a2a.types import Message as A2AMessage from a2a.types import MessageSendConfiguration from a2a.types import Part as A2APart from a2a.types import Role +from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent @@ -62,6 +63,7 @@ from ..a2a.converters.part_converter import A2APartToGenAIPartConverter from ..a2a.converters.part_converter import convert_a2a_part_to_genai_part from ..a2a.converters.part_converter import convert_genai_part_to_a2a_part from ..a2a.converters.part_converter import GenAIPartToA2APartConverter +from ..a2a.converters.utils import _get_adk_metadata_key from ..a2a.experimental import a2a_experimental from ..a2a.logs.log_utils import build_a2a_request_log from ..a2a.logs.log_utils import build_a2a_response_log @@ -522,6 +524,76 @@ class RemoteA2aAgent(BaseAgent): branch=ctx.branch, ) + async def _handle_a2a_response_v2( + self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext + ) -> Optional[Event]: + """Handle A2A response and convert to Event. 
+ + Args: + a2a_response: The A2A response object + ctx: The invocation context + + Returns: + Event object representing the response, or None if no event should be + emitted. + """ + try: + if isinstance(a2a_response, tuple): + task, update = a2a_response + event = None + if update is None: + # This is the initial response for a streaming task or the complete + # response for a non-streaming task. + event = self._config.a2a_task_converter( + task, self.name, ctx, self._config.a2a_part_converter + ) + elif isinstance(update, A2ATaskStatusUpdateEvent): + # This is a streaming task status update. + event = self._config.a2a_status_update_converter( + update, self.name, ctx, self._config.a2a_part_converter + ) + elif isinstance(update, A2ATaskArtifactUpdateEvent): + # This is a streaming task artifact update. + event = self._config.a2a_artifact_update_converter( + update, self.name, ctx, self._config.a2a_part_converter + ) + if not event: + return None + event.custom_metadata = event.custom_metadata or {} + event.custom_metadata[A2A_METADATA_PREFIX + "task_id"] = task.id + if task.context_id: + event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = ( + task.context_id + ) + + # Otherwise, it's a regular A2AMessage. 
+ elif isinstance(a2a_response, A2AMessage): + event = self._config.a2a_message_converter( + a2a_response, self.name, ctx, self._config.a2a_part_converter + ) + event.custom_metadata = event.custom_metadata or {} + + if a2a_response.context_id: + event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = ( + a2a_response.context_id + ) + else: + event = Event( + author=self.name, + error_message="Unknown A2A response type", + invocation_id=ctx.invocation_id, + branch=ctx.branch, + ) + return event + except A2AClientError as e: + logger.error("Failed to handle A2A response: %s", e) + return Event( + author=self.name, + error_message=f"Failed to process A2A response: {e}", + invocation_id=ctx.invocation_id, + branch=ctx.branch, + ) + async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: @@ -589,7 +661,20 @@ class RemoteA2aAgent(BaseAgent): ): logger.debug(build_a2a_response_log(a2a_response)) - event = await self._handle_a2a_response(a2a_response, ctx) + metadata = None + if isinstance(a2a_response, tuple): + task = a2a_response[0] + if task: + metadata = task.metadata + else: + metadata = a2a_response.metadata + + if metadata and metadata.get( + _get_adk_metadata_key("agent_executor_v2") + ): + event = await self._handle_a2a_response_v2(a2a_response, ctx) + else: + event = await self._handle_a2a_response(a2a_response, ctx) if not event: continue diff --git a/tests/unittests/a2a/converters/test_event_round_trip.py b/tests/unittests/a2a/converters/test_event_round_trip.py new file mode 100644 index 00000000..00036f6a --- /dev/null +++ b/tests/unittests/a2a/converters/test_event_round_trip.py @@ -0,0 +1,208 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Round trip tests for ADK and A2A event converters.""" + +from __future__ import annotations + +from typing import Dict +from unittest.mock import Mock + +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskStatusUpdateEvent +from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events +from google.adk.a2a.converters.from_adk_event import create_error_status_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.genai import types as genai_types + + +def test_round_trip_text_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="model", + parts=[genai_types.Part.from_text(text="Hello world!")], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + assert isinstance(a2a_event, TaskArtifactUpdateEvent) + + mock_context = Mock( + spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + 
) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert restored_event.partial == original_event.partial + assert len(restored_event.content.parts) == len(original_event.content.parts) + assert ( + restored_event.content.parts[0].text + == original_event.content.parts[0].text + ) + + +def test_round_trip_error_status_event(): + original_event = Event( + invocation_id="error_inv", + author="error_agent", + branch="main", + error_message="Test Error", + ) + + a2a_event = create_error_status_event( + event=original_event, + task_id="task2", + context_id="ctx2", + ) + + assert isinstance(a2a_event, TaskStatusUpdateEvent) + + mock_context = Mock( + spec=InvocationContext, invocation_id="error_inv", branch="main" + ) + + restored_event = convert_a2a_status_update_to_event( + a2a_status_update=a2a_event, + author="error_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].text == "Test Error" + + +def test_round_trip_function_call_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="model", + parts=[ + genai_types.Part.from_function_call( + name="my_function", + args={"arg1": "value1"}, + ) + ], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + + mock_context = Mock( + 
spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].function_call.name == "my_function" + assert restored_event.content.parts[0].function_call.args == { + "arg1": "value1" + } + + +def test_round_trip_function_response_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="user", + parts=[ + genai_types.Part.from_function_response( + name="my_function", + response={"result": "success"}, + ) + ], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + + mock_context = Mock( + spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].function_response.name == "my_function" + assert restored_event.content.parts[0].function_response.response == { + "result": "success" + } diff --git 
a/tests/unittests/a2a/converters/test_to_adk.py b/tests/unittests/a2a/converters/test_to_adk.py new file mode 100644 index 00000000..90651956 --- /dev/null +++ b/tests/unittests/a2a/converters/test_to_adk.py @@ -0,0 +1,195 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock + +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Task +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_task_to_event +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.agents.invocation_context import InvocationContext +from google.genai import types as genai_types +import pytest + + +class TestToAdk: + """Test suite for to_adk functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_context = 
Mock(spec=InvocationContext) + self.mock_context.invocation_id = "test-invocation" + self.mock_context.branch = "test-branch" + + def test_convert_a2a_message_to_event_success(self): + """Test successful conversion of A2A message to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + message = Message(message_id="msg-1", role="user", parts=[a2a_part]) + + mock_genai_part = genai_types.Part.from_text(text="hello") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_message_to_event( + message, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert event.branch == "test-branch" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_message_to_event_none(self): + """Test convert_a2a_message_to_event with None.""" + with pytest.raises(ValueError, match="A2A message cannot be None"): + convert_a2a_message_to_event(None) + + def test_convert_a2a_task_to_event_success(self): + """Test successful conversion of A2A task to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + task = Task( + id="task-1", + status=TaskStatus( + state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + ), + context_id="context-1", + history=[Message(message_id="msg-1", role="agent", parts=[a2a_part])], + artifacts=[ + Artifact( + artifact_id="art-1", artifact_type="message", parts=[a2a_part] + ) + ], + ) + + mock_genai_part = genai_types.Part.from_text(text="task artifact text") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == 
"test-invocation" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_task_to_event_none(self): + """Test convert_a2a_task_to_event with None.""" + with pytest.raises(ValueError, match="A2A task cannot be None"): + convert_a2a_task_to_event(None) + + def test_convert_a2a_status_update_to_event_success(self): + """Test successful conversion of A2A status update to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = { + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True + } + update = TaskStatusUpdateEvent( + task_id="task-1", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[a2a_part], + ), + ), + context_id="context-1", + final=False, + ) + + mock_genai_part = genai_types.Part( + function_call=genai_types.FunctionCall( + name="status update text", args={"arg": "value"}, id="call-1" + ) + ) + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_status_update_to_event( + update, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_status_update_to_event_none(self): + """Test convert_a2a_status_update_to_event with None.""" + with pytest.raises(ValueError, match="A2A status update cannot be None"): + convert_a2a_status_update_to_event(None) + + def test_convert_a2a_artifact_update_to_event_success(self): + """Test successful conversion of A2A artifact update to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + update = TaskArtifactUpdateEvent( + task_id="task-1", + artifact=Artifact( + artifact_id="art-1", 
artifact_type="message", parts=[a2a_part] + ), + append=True, + context_id="context-1", + last_chunk=False, + ) + + mock_genai_part = genai_types.Part.from_text(text="artifact chunk text") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_artifact_update_to_event( + update, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert event.partial is True + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_artifact_update_to_event_none(self): + """Test convert_a2a_artifact_update_to_event with None.""" + with pytest.raises(ValueError, match="A2A artifact update cannot be None"): + convert_a2a_artifact_update_to_event(None) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index fe155d30..0f1ce896 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -1685,6 +1685,236 @@ class TestRemoteA2aAgentMessageHandlingFromFactory: assert result is None +class TestRemoteA2aAgentMessageHandlingV2: + """Test _handle_a2a_response_impl functionality.""" + + def setup_method(self): + """Setup test fixtures.""" + from google.adk.a2a.agent.config import A2aRemoteAgentConfig + + self.agent_card = create_test_agent_card() + self.mock_config = Mock(spec=A2aRemoteAgentConfig) + self.mock_config.a2a_part_converter = Mock() + self.mock_config.a2a_task_converter = Mock() + self.mock_config.a2a_status_update_converter = Mock() + self.mock_config.a2a_artifact_update_converter = Mock() + self.mock_config.a2a_message_converter = Mock() + + self.agent = RemoteA2aAgent( + name="test_agent", + agent_card=self.agent_card, + config=self.mock_config, + ) + + # Mock session and context + self.mock_session = Mock(spec=Session) + 
self.mock_session.id = "session-123" + self.mock_session.events = [] + + self.mock_context = Mock(spec=InvocationContext) + self.mock_context.session = self.mock_session + self.mock_context.invocation_id = "invocation-123" + self.mock_context.branch = "main" + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_message(self): + """Test _handle_a2a_response_impl with A2AMessage.""" + mock_a2a_message = Mock(spec=A2AMessage) + mock_a2a_message.metadata = {} + mock_a2a_message.metadata = {} + mock_a2a_message.context_id = "context-123" + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_message_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + mock_a2a_message, self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_message_converter.assert_called_once_with( + mock_a2a_message, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_and_no_update(self): + """Test _handle_a2a_response_impl with Task and no update.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = "context-123" + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_task_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, None), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_task_converter.assert_called_once_with( + mock_a2a_task, + self.agent.name, + self.mock_context, 
+ self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_status_update(self): + """Test _handle_a2a_response_impl with TaskStatusUpdateEvent.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = None + + mock_update = Mock(spec=TaskStatusUpdateEvent) + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_status_update_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_status_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" not in result.custom_metadata + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_artifact_update(self): + """Test _handle_a2a_response_impl with TaskArtifactUpdateEvent.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = "context-123" + + mock_update = Mock(spec=TaskArtifactUpdateEvent) + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + 
self.mock_config.a2a_artifact_update_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_artifact_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_update_converter_returns_none(self): + """Test _handle_a2a_response_impl when converter returns None.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + + mock_update = Mock(spec=TaskArtifactUpdateEvent) + + self.mock_config.a2a_artifact_update_converter.return_value = None + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result is None + self.mock_config.a2a_artifact_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_unknown_response_type(self): + """Test _handle_a2a_response_impl with unknown response type.""" + unknown_response = object() + + result = await self.agent._handle_a2a_response_v2( + unknown_response, self.mock_context + ) + + assert result is not None + assert result.author == self.agent.name + assert result.error_message == "Unknown A2A response type" + assert result.invocation_id == self.mock_context.invocation_id + assert result.branch == self.mock_context.branch + + @pytest.mark.asyncio + async def 
test_handle_a2a_response_impl_handles_client_error(self): + """Test _handle_a2a_response_impl catches A2AClientError.""" + mock_a2a_message = Mock(spec=A2AMessage) + mock_a2a_message.metadata = {} + mock_a2a_message.metadata = {} + + from google.adk.agents.remote_a2a_agent import A2AClientError + + self.mock_config.a2a_message_converter.side_effect = A2AClientError( + "Test client error" + ) + + result = await self.agent._handle_a2a_response_v2( + mock_a2a_message, self.mock_context + ) + + assert result is not None + assert result.author == self.agent.name + assert ( + "Failed to process A2A response: Test client error" + in result.error_message + ) + assert result.invocation_id == self.mock_context.invocation_id + assert result.branch == self.mock_context.branch + + class TestRemoteA2aAgentExecution: """Test agent execution functionality.""" @@ -1773,7 +2003,7 @@ class TestRemoteA2aAgentExecution: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -1912,7 +2142,7 @@ class TestRemoteA2aAgentExecution: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2049,7 +2279,7 @@ class TestRemoteA2aAgentExecutionFromFactory: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2295,6 +2525,7 @@ class TestRemoteA2aAgentIntegration: with 
patch.object(agent, "_a2a_client") as mock_a2a_client: mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) mock_a2a_message.context_id = "context-123" + mock_a2a_message.metadata = {} mock_response = mock_a2a_message mock_send_message = AsyncMock() @@ -2391,6 +2622,7 @@ class TestRemoteA2aAgentIntegration: with patch.object(agent, "_a2a_client") as mock_a2a_client: mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) mock_a2a_message.context_id = "context-123" + mock_a2a_message.metadata = {} mock_response = mock_a2a_message mock_send_message = AsyncMock() From 2780ae2892adfbebc7580c843d2eaad29f86c335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=8D=E5=81=9A=E4=BA=86=E7=9D=A1=E5=A4=A7=E8=A7=89?= <64798754+stakeswky@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:15:15 -0800 Subject: [PATCH 089/102] fix: temp-scoped state now visible to subsequent agents in same invocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/4618 ## Summary Fixes #4564 When using `output_key` with a `temp:` prefix (e.g. `output_key='temp:result'`) in a `SequentialAgent`, the output was silently lost. Agent-2 could never read the temp state written by agent-1. ## Root Cause Two issues in `append_event`: 1. `_trim_temp_delta_state()` removed temp keys from the event delta **before** `_update_session_state()` could apply them to the in-memory session 2. `_update_session_state()` also explicitly skipped `temp:`-prefixed keys ```python # Before (broken ordering): async def append_event(self, session, event): event = self._trim_temp_delta_state(event) # temp keys gone! 
self._update_session_state(session, event) # nothing to apply ``` ## Fix Introduce `_apply_temp_state()` which writes temp-scoped keys to the in-memory `session.state` **before** the event delta is trimmed: ```python # After: async def append_event(self, session, event): self._apply_temp_state(session, event) # temp keys → session.state event = self._trim_temp_delta_state(event) # temp keys removed from delta self._update_session_state(session, event) # non-temp keys applied ``` This ensures: - ✅ Temp state is available to subsequent agents within the same invocation - ✅ Temp state is still stripped from event deltas (not persisted to storage) - ✅ All three session services (InMemory, Database, SQLite) behave consistently ## Files Changed - `src/google/adk/sessions/base_session_service.py`: Added `_apply_temp_state()`, reordered `append_event` logic, removed temp-skip in `_update_session_state` - `src/google/adk/sessions/database_session_service.py`: Added `_apply_temp_state()` call before trim - `src/google/adk/sessions/sqlite_session_service.py`: Added `_apply_temp_state()` call before trim - `tests/unittests/sessions/test_session_service.py`: Updated existing test + added new test for sequential agent scenario ## Testing All 67 session service tests pass across InMemory, Database, and SQLite backends. 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4618 from stakeswky:fix/temp-state-output-key b9fc737e7a6dc07e06e99af3271a8fc026acae4a PiperOrigin-RevId: 878499263 --- .../adk/sessions/base_session_service.py | 26 +++++++++-- .../adk/sessions/database_session_service.py | 3 ++ .../adk/sessions/sqlite_session_service.py | 3 ++ .../sessions/test_session_service.py | 43 +++++++++++++++---- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83..eb22a83b 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,13 +106,35 @@ class BaseSessionService(abc.ABC): """Appends an event to a session object.""" if event.partial: return event + # Apply temp-scoped state to the in-memory session BEFORE trimming the + # event delta, so that subsequent agents within the same invocation can + # read temp values (e.g. output_key='temp:my_key' in SequentialAgent). + self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) self._update_session_state(session, event) session.events.append(event) return event + def _apply_temp_state(self, session: Session, event: Event) -> None: + """Applies temp-scoped state delta to the in-memory session state. + + Temp state is ephemeral: it lives in the session's in-memory state for + the duration of the current invocation but is NOT persisted to storage + (the event delta is trimmed separately by _trim_temp_delta_state). + """ + if not event.actions or not event.actions.state_delta: + return + for key, value in event.actions.state_delta.items(): + if key.startswith(State.TEMP_PREFIX): + session.state[key] = value + def _trim_temp_delta_state(self, event: Event) -> Event: - """Removes temporary state delta keys from the event.""" + """Removes temporary state delta keys from the event. 
+ + This prevents temp-scoped state from being persisted, while the + in-memory session state (updated by _apply_temp_state) retains the + values for the duration of the current invocation. + """ if not event.actions or not event.actions.state_delta: return event @@ -128,6 +150,4 @@ class BaseSessionService(abc.ABC): if not event.actions or not event.actions.state_delta: return for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue session.state.update({key: value}) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 6b19464e..321a5cc6 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -522,6 +522,9 @@ class DatabaseSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d23c8278..600f89c4 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -361,6 +361,9 @@ class SqliteSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. 
+ self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) event_timestamp = event.timestamp diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 4e277195..5c5aa83e 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service): ) await session_service.append_event(session=session, event=event) - # Refetch session and check state and event - session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' - ) - # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + # Temp state IS available in the in-memory session (same invocation) + assert session.state.get('temp:k1') == 'v1' + assert session.state.get('sk') == 'v2' + # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert 'temp:k1' not in event.actions.state_delta + assert event.actions.state_delta.get('sk') == 'v2' + + +@pytest.mark.asyncio +async def test_temp_state_visible_across_sequential_events(session_service): + """Temp state set by one event should be readable before the next event. + + This simulates a SequentialAgent where agent-1 writes output_key='temp:out' + and agent-2 needs to read it from session.state within the same invocation. 
+ """ + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s_seq' + ) + + # Agent-1 writes temp state + event1 = Event( + invocation_id='inv1', + author='agent1', + actions=EventActions(state_delta={'temp:output': 'result_from_a1'}), + ) + await session_service.append_event(session=session, event=event1) + + # Agent-2 should be able to read temp state from the same session object + assert session.state.get('temp:output') == 'result_from_a1' + + # But the event delta should NOT contain the temp key (not persisted) + assert 'temp:output' not in event1.actions.state_delta @pytest.mark.asyncio From d0825d817e39a9bd2e2fbb7ccd8690cf60593d14 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 4 Mar 2026 13:33:52 -0800 Subject: [PATCH 090/102] fix: Change Mypy workflow trigger to manual dispatch Co-authored-by: George Weale PiperOrigin-RevId: 878643857 --- .github/workflows/mypy.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index f2626209..e893ce9e 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,10 +1,7 @@ name: Mypy Type Check on: - push: - branches: [ main ] - pull_request: - branches: [ main ] + workflow_dispatch: jobs: mypy: From 34c560e66e7ad379f586bbcd45a9460dc059bee2 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 13:41:27 -0800 Subject: [PATCH 091/102] feat(bigtable): add Bigtable cluster metadata tools PiperOrigin-RevId: 878648015 --- contributing/samples/bigtable/agent.py | 4 +- .../adk/tools/bigtable/bigtable_toolset.py | 4 + .../adk/tools/bigtable/metadata_tool.py | 285 +++++++++++++-- .../bigtable/test_bigtable_metadata_tool.py | 327 +++++++++++++----- .../tools/bigtable/test_bigtable_toolset.py | 4 +- 5 files changed, 499 insertions(+), 125 deletions(-) diff --git a/contributing/samples/bigtable/agent.py 
b/contributing/samples/bigtable/agent.py index 6b4a50fc..1d52e1fe 100644 --- a/contributing/samples/bigtable/agent.py +++ b/contributing/samples/bigtable/agent.py @@ -62,8 +62,8 @@ bigtable_toolset = BigtableToolset( credentials_config=credentials_config, bigtable_tool_settings=tool_settings ) -_BIGTABLE_PROJECT_ID = "google.com:cloud-bigtable-dev" -_BIGTABLE_INSTANCE_ID = "annenguyen-bus-instance" +_BIGTABLE_PROJECT_ID = "" +_BIGTABLE_INSTANCE_ID = "" def search_hotels_by_location( diff --git a/src/google/adk/tools/bigtable/bigtable_toolset.py b/src/google/adk/tools/bigtable/bigtable_toolset.py index 8e9f430f..97fc2eb0 100644 --- a/src/google/adk/tools/bigtable/bigtable_toolset.py +++ b/src/google/adk/tools/bigtable/bigtable_toolset.py @@ -44,6 +44,8 @@ class BigtableToolset(BaseToolset): - bigtable_get_instance_info - bigtable_list_tables - bigtable_get_table_info + - bigtable_list_clusters + - bigtable_get_cluster_info - bigtable_execute_sql """ @@ -95,6 +97,8 @@ class BigtableToolset(BaseToolset): metadata_tool.get_instance_info, metadata_tool.list_tables, metadata_tool.get_table_info, + metadata_tool.list_clusters, + metadata_tool.get_cluster_info, query_tool.execute_sql, ] ] diff --git a/src/google/adk/tools/bigtable/metadata_tool.py b/src/google/adk/tools/bigtable/metadata_tool.py index 703c3447..de4fea6a 100644 --- a/src/google/adk/tools/bigtable/metadata_tool.py +++ b/src/google/adk/tools/bigtable/metadata_tool.py @@ -14,12 +14,16 @@ from __future__ import annotations +import enum import logging from google.auth.credentials import Credentials +from google.cloud.bigtable import enums from . import client +logger = logging.getLogger(f"google_adk.{__name__}") + def list_instances(project_id: str, credentials: Credentials) -> dict: """List Bigtable instance ids in a Google Cloud project. @@ -29,7 +33,22 @@ def list_instances(project_id: str, credentials: Credentials) -> dict: credentials (Credentials): The credentials to use for the request. 
Returns: - dict: Dictionary with a list of the Bigtable instance ids present in the project. + dict: Dictionary with a list of dictionaries, each representing a Bigtable instance. + + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + ], + } """ try: bt_client = client.get_bigtable_admin_client( @@ -41,12 +60,27 @@ def list_instances(project_id: str, credentials: Credentials) -> dict: "Failed to list instances from the following locations: %s", failed_locations_list, ) - instance_ids = [instance.instance_id for instance in instances_list] - return {"status": "SUCCESS", "results": instance_ids} + result = [ + { + "project_id": project_id, + "instance_id": instance.instance_id, + "display_name": instance.display_name, + "state": _enum_name_from_value( + enums.Instance.State, instance.state, "UNKNOWN_STATE" + ), + "type": _enum_name_from_value( + enums.Instance.Type, instance.type_, "UNKNOWN_TYPE" + ), + "labels": instance.labels, + } + for instance in instances_list + ] + return {"status": "SUCCESS", "results": result} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } @@ -69,26 +103,33 @@ def get_instance_info( ) instance = bt_client.instance(instance_id) instance.reload() - instance_info = { - "project_id": project_id, - "instance_id": instance.instance_id, - "display_name": instance.display_name, - "state": instance.state, - "type": instance.type_, - "labels": instance.labels, + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance.instance_id, + "display_name": instance.display_name, + "state": _enum_name_from_value( + enums.Instance.State, instance.state, "UNKNOWN_STATE" + ), + "type": _enum_name_from_value( + 
enums.Instance.Type, instance.type_, "UNKNOWN_TYPE" + ), + "labels": instance.labels, + }, } - return {"status": "SUCCESS", "results": instance_info} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } def list_tables( project_id: str, instance_id: str, credentials: Credentials ) -> dict: - """List table ids in a Bigtable instance. + """List tables and their metadata in a Bigtable instance. Args: project_id (str): The Google Cloud project id containing the instance. @@ -96,7 +137,21 @@ def list_tables( credentials (Credentials): The credentials to use for the request. Returns: - dict: Dictionary with a list of the tables ids present in the instance. + dict: A dictionary with status and results, where results is a list of + table properties. + + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "table_name": "fake-table-name", + } + ], + } """ try: bt_client = client.get_bigtable_admin_client( @@ -104,17 +159,29 @@ def list_tables( ) instance = bt_client.instance(instance_id) tables = instance.list_tables() - table_ids = [table.table_id for table in tables] - return {"status": "SUCCESS", "results": table_ids} + result = [ + { + "project_id": project_id, + "instance_id": instance_id, + "table_id": table.table_id, + "table_name": table.name, + } + for table in tables + ] + return {"status": "SUCCESS", "results": result} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } def get_table_info( - project_id: str, instance_id: str, table_id: str, credentials: Credentials + project_id: str, + instance_id: str, + table_id: str, + credentials: Credentials, ) -> dict: """Get metadata information about a Bigtable table. 
@@ -126,6 +193,17 @@ def get_table_info( Returns: dict: Dictionary representing the properties of the table. + + Example: + { + "status": "SUCCESS", + "results": { + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "column_families": ["cf1", "cf2"], + }, + } """ try: bt_client = client.get_bigtable_admin_client( @@ -134,15 +212,170 @@ def get_table_info( instance = bt_client.instance(instance_id) table = instance.table(table_id) column_families = table.list_column_families() - table_info = { - "project_id": project_id, - "instance_id": instance.instance_id, - "table_id": table.table_id, - "column_families": list(column_families.keys()), + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance.instance_id, + "table_id": table.table_id, + "column_families": list(column_families.keys()), + }, } - return {"status": "SUCCESS", "results": table_info} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), + } + + +def _enum_name_from_value( + enum_class: type[enum.Enum], value: int, prefix: str = "UNKNOWN" +) -> str: + for attr_name in dir(enum_class): + if not attr_name.startswith("_"): + if getattr(enum_class, attr_name) == value: + return attr_name + return f"{prefix}_{value}" + + +def list_clusters( + project_id: str, instance_id: str, credentials: Credentials +) -> dict: + """List clusters and their metadata in a Bigtable instance. + + Args: + project_id (str): The Google Cloud project id containing the instance. + instance_id (str): The Bigtable instance id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the cluster. 
+ + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "cluster_name": "fake-cluster-name", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + } + ], + } + """ + try: + bt_client = client.get_bigtable_admin_client( + project=project_id, credentials=credentials + ) + instance = bt_client.instance(instance_id) + instance.reload() + clusters_list, failed_locations = instance.list_clusters() + if failed_locations: + logging.warning( + "Failed to list clusters from the following locations: %s", + failed_locations, + ) + + result = [ + { + "project_id": project_id, + "instance_id": instance_id, + "cluster_id": cluster.cluster_id, + "cluster_name": cluster.name, + "state": _enum_name_from_value( + enums.Cluster.State, cluster.state, "UNKNOWN_STATE" + ), + "serve_nodes": cluster.serve_nodes, + "default_storage_type": _enum_name_from_value( + enums.StorageType, + cluster.default_storage_type, + "UNKNOWN_STORAGE_TYPE", + ), + "location_id": cluster.location_id, + } + for cluster in clusters_list + ] + return {"status": "SUCCESS", "results": result} + except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) + return { + "status": "ERROR", + "error_details": repr(ex), + } + + +def get_cluster_info( + project_id: str, + instance_id: str, + cluster_id: str, + credentials: Credentials, +) -> dict: + """Get detailed metadata information about a Bigtable cluster. + + Args: + project_id (str): The Google Cloud project id containing the instance. + instance_id (str): The Bigtable instance id containing the cluster. + cluster_id (str): The Bigtable cluster id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the cluster. 
+ + Example: + { + "status": "SUCCESS", + "results": { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + "min_serve_nodes": 1, + "max_serve_nodes": 10, + "cpu_utilization_percent": 80, + }, + } + """ + try: + bt_client = client.get_bigtable_admin_client( + project=project_id, credentials=credentials + ) + instance = bt_client.instance(instance_id) + instance.reload() + cluster = instance.cluster(cluster_id) + cluster.reload() + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance_id, + "cluster_id": cluster.cluster_id, + "state": _enum_name_from_value( + enums.Cluster.State, cluster.state, "UNKNOWN_STATE" + ), + "serve_nodes": cluster.serve_nodes, + "default_storage_type": _enum_name_from_value( + enums.StorageType, + cluster.default_storage_type, + "UNKNOWN_STORAGE_TYPE", + ), + "location_id": cluster.location_id, + "min_serve_nodes": cluster.min_serve_nodes, + "max_serve_nodes": cluster.max_serve_nodes, + "cpu_utilization_percent": cluster.cpu_utilization_percent, + }, + } + except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) + return { + "status": "ERROR", + "error_details": repr(ex), } diff --git a/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py b/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py index d7debf02..46904828 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py @@ -15,69 +15,69 @@ import logging from unittest import mock +from google.adk.tools.bigtable import client from google.adk.tools.bigtable import metadata_tool from google.auth.credentials import Credentials +from google.cloud.bigtable import enums +import pytest -def test_list_instances(): - """Test list_instances function.""" - with mock.patch( 
- "google.adk.tools.bigtable.client.get_bigtable_admin_client" +@pytest.fixture +def mock_get_client(): + with mock.patch.object( + client, "get_bigtable_admin_client" ) as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client + yield mock_get_client + + +def test_list_instances(mock_get_client): + mock_instance = mock.MagicMock() + mock_instance.instance_id = "test-instance" + mock_get_client.return_value.list_instances.return_value = ( + [mock_instance], + [], + ) + + mock_instance.display_name = "Test Instance" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION + mock_instance.labels = {"env": "test"} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_instances( + project_id="test-project", credentials=creds + ) + expected_result = { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + assert result == {"status": "SUCCESS", "results": [expected_result]} + + +def test_list_instances_failed_locations(mock_get_client): + with mock.patch.object(logging, "warning") as mock_warning: mock_instance = mock.MagicMock() mock_instance.instance_id = "test-instance" - mock_client.list_instances.return_value = ([mock_instance], []) + failed_locations = ["us-west1-a"] + mock_get_client.return_value.list_instances.return_value = ( + [mock_instance], + failed_locations, + ) - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_instances("test-project", creds) - assert result == {"status": "SUCCESS", "results": ["test-instance"]} - - -def test_list_instances_failed_locations(): - """Test list_instances function when some locations fail.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - with mock.patch.object(logging, "warning") as 
mock_warning: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_instance.instance_id = "test-instance" - failed_locations = ["us-west1-a"] - mock_client.list_instances.return_value = ( - [mock_instance], - failed_locations, - ) - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_instances("test-project", creds) - assert result == {"status": "SUCCESS", "results": ["test-instance"]} - mock_warning.assert_called_once_with( - "Failed to list instances from the following locations: %s", - failed_locations, - ) - - -def test_get_instance_info(): - """Test get_instance_info function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_instance.instance_id = "test-instance" mock_instance.display_name = "Test Instance" - mock_instance.state = "READY" - mock_instance.type_ = "PRODUCTION" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION mock_instance.labels = {"env": "test"} creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.get_instance_info( - "test-project", "test-instance", creds + result = metadata_tool.list_instances( + project_id="test-project", credentials=creds ) expected_result = { "project_id": "test-project", @@ -87,51 +87,186 @@ def test_get_instance_info(): "type": "PRODUCTION", "labels": {"env": "test"}, } - assert result == {"status": "SUCCESS", "results": expected_result} - mock_instance.reload.assert_called_once() - - -def test_list_tables(): - """Test list_tables function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - 
mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_table = mock.MagicMock() - mock_table.table_id = "test-table" - mock_instance.list_tables.return_value = [mock_table] - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_tables("test-project", "test-instance", creds) - assert result == {"status": "SUCCESS", "results": ["test-table"]} - - -def test_get_table_info(): - """Test get_table_info function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_table = mock.MagicMock() - mock_instance.table.return_value = mock_table - mock_table.table_id = "test-table" - mock_instance.instance_id = "test-instance" - mock_table.list_column_families.return_value = {"cf1": mock.MagicMock()} - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.get_table_info( - "test-project", "test-instance", "test-table", creds + assert result == {"status": "SUCCESS", "results": [expected_result]} + mock_warning.assert_called_once_with( + "Failed to list instances from the following locations: %s", + failed_locations, ) - expected_result = { - "project_id": "test-project", - "instance_id": "test-instance", - "table_id": "test-table", - "column_families": ["cf1"], - } - assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_get_instance_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_instance.instance_id = "test-instance" + mock_instance.display_name = "Test Instance" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION + mock_instance.labels = {"env": 
"test"} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_instance_info( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + assert result == {"status": "SUCCESS", "results": expected_result} + mock_instance.reload.assert_called_once() + + +def test_list_tables(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_table = mock.MagicMock() + mock_table.table_id = "test-table" + mock_table.name = ( + "projects/test-project/instances/test-instance/tables/test-table" + ) + mock_instance.list_tables.return_value = [mock_table] + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_tables( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = [{ + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "table_name": ( + "projects/test-project/instances/test-instance/tables/test-table" + ), + }] + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_get_table_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_instance.instance_id = "test-instance" + mock_get_client.return_value.instance.return_value = mock_instance + mock_table = mock.MagicMock() + mock_instance.table.return_value = mock_table + mock_table.table_id = "test-table" + mock_table.list_column_families.return_value = {"cf1": mock.MagicMock()} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + project_id="test-project", + instance_id="test-instance", + table_id="test-table", + credentials=creds, + ) + expected_result = { + "project_id": "test-project", 
+ "instance_id": "test-instance", + "table_id": "test-table", + "column_families": ["cf1"], + } + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_list_clusters(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_cluster = mock.MagicMock() + mock_cluster.cluster_id = "test-cluster" + mock_cluster.name = ( + "projects/test-project/instances/test-instance/clusters/test-cluster" + ) + mock_cluster.state = enums.Cluster.State.READY + mock_cluster.serve_nodes = 3 + mock_cluster.default_storage_type = enums.StorageType.SSD + mock_cluster.location_id = "us-central1-a" + mock_instance.list_clusters.return_value = ([mock_cluster], []) + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_clusters( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = [{ + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "cluster_name": mock_cluster.name, + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + }] + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_list_clusters_error(mock_get_client): + mock_get_client.side_effect = Exception("test-error") + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_clusters( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + assert result == { + "status": "ERROR", + "error_details": "Exception('test-error')", + } + + +def test_get_cluster_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_cluster = mock.MagicMock() + mock_instance.cluster.return_value = mock_cluster + mock_cluster.cluster_id = "test-cluster" + mock_cluster.state = enums.Cluster.State.READY + 
mock_cluster.serve_nodes = 3 + mock_cluster.default_storage_type = enums.StorageType.SSD + mock_cluster.location_id = "us-central1-a" + mock_cluster.min_serve_nodes = 3 + mock_cluster.max_serve_nodes = 10 + mock_cluster.cpu_utilization_percent = 50 + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_cluster_info( + project_id="test-project", + instance_id="test-instance", + cluster_id="test-cluster", + credentials=creds, + ) + expected_results = { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + "min_serve_nodes": 3, + "max_serve_nodes": 10, + "cpu_utilization_percent": 50, + } + assert result == {"status": "SUCCESS", "results": expected_results} + mock_cluster.reload.assert_called_once() + + +def test_get_cluster_info_error(mock_get_client): + mock_get_client.side_effect = Exception("test-error") + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_cluster_info( + project_id="test-project", + instance_id="test-instance", + cluster_id="test-cluster", + credentials=creds, + ) + assert result == { + "status": "ERROR", + "error_details": "Exception('test-error')", + } diff --git a/tests/unittests/tools/bigtable/test_bigtable_toolset.py b/tests/unittests/tools/bigtable/test_bigtable_toolset.py index 53040395..b5698cfc 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_toolset.py +++ b/tests/unittests/tools/bigtable/test_bigtable_toolset.py @@ -45,7 +45,7 @@ async def test_bigtable_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 5 + assert len(tools) == 7 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -54,6 +54,8 @@ async def test_bigtable_toolset_tools_default(): "list_tables", "get_table_info", "execute_sql", + "list_clusters", + 
"get_cluster_info", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names From 45fb53b9e2d356098c938bce0151baa41148690d Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 4 Mar 2026 14:08:49 -0800 Subject: [PATCH 092/102] chore: Move API registry to the integrations folder Added a deprecation warning in the old tools/api_registry file. Co-authored-by: Kathy Wu PiperOrigin-RevId: 878660213 --- .../samples/api_registry_agent/agent.py | 2 +- .../adk/integrations/api_registry/__init__.py | 17 +++ .../integrations/api_registry/api_registry.py | 140 ++++++++++++++++++ src/google/adk/tools/api_registry.py | 131 +--------------- .../integrations/api_registry/__init__.py | 11 ++ .../api_registry}/test_api_registry.py | 21 ++- 6 files changed, 192 insertions(+), 130 deletions(-) create mode 100644 src/google/adk/integrations/api_registry/__init__.py create mode 100644 src/google/adk/integrations/api_registry/api_registry.py create mode 100644 tests/unittests/integrations/api_registry/__init__.py rename tests/unittests/{tools => integrations/api_registry}/test_api_registry.py (96%) diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py index 9f55ef80..87faea31 100644 --- a/contributing/samples/api_registry_agent/agent.py +++ b/contributing/samples/api_registry_agent/agent.py @@ -15,7 +15,7 @@ import os from google.adk.agents.llm_agent import LlmAgent -from google.adk.tools.api_registry import ApiRegistry +from google.adk.integrations.api_registry import ApiRegistry # TODO: Fill in with your GCloud project id and MCP server name PROJECT_ID = "your-google-cloud-project-id" diff --git a/src/google/adk/integrations/api_registry/__init__.py b/src/google/adk/integrations/api_registry/__init__.py new file mode 100644 index 00000000..1179bc86 --- /dev/null +++ b/src/google/adk/integrations/api_registry/__init__.py @@ -0,0 +1,17 @@ +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .api_registry import ApiRegistry + +__all__ = [ + 'ApiRegistry', +] diff --git a/src/google/adk/integrations/api_registry/api_registry.py b/src/google/adk/integrations/api_registry/api_registry.py new file mode 100644 index 00000000..966ad68b --- /dev/null +++ b/src/google/adk/integrations/api_registry/api_registry.py @@ -0,0 +1,140 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Any +from typing import Callable + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.tools.base_toolset import ToolPredicate +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth +import google.auth.transport.requests +import httpx + +API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" + + +class ApiRegistry: + """Registry that provides McpToolsets for MCP servers registered in API Registry.""" + + def __init__( + self, + api_registry_project_id: str, + location: str = "global", + header_provider: ( + Callable[[ReadonlyContext], dict[str, str]] | None + ) = None, + ): + """Initialize the API Registry. + + Args: + api_registry_project_id: The project ID for the Google Cloud API Registry. + location: The location of the API Registry resources. + header_provider: Optional function to provide additional headers for MCP + server calls. 
+ """ + self.api_registry_project_id = api_registry_project_id + self.location = location + self._credentials, _ = google.auth.default() + self._mcp_servers: dict[str, dict[str, Any]] = {} + self._header_provider = header_provider + + url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" + + try: + headers = self._get_auth_headers() + headers["Content-Type"] = "application/json" + page_token = None + with httpx.Client() as client: + while True: + params = {} + if page_token: + params["pageToken"] = page_token + + response = client.get(url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + mcp_servers_list = data.get("mcpServers", []) + for server in mcp_servers_list: + server_name = server.get("name", "") + if server_name: + self._mcp_servers[server_name] = server + + page_token = data.get("nextPageToken") + if not page_token: + break + except (httpx.HTTPError, ValueError) as e: + # Handle error in fetching or parsing tool definitions + raise RuntimeError( + f"Error fetching MCP servers from API Registry: {e}" + ) from e + + def get_toolset( + self, + mcp_server_name: str, + tool_filter: ToolPredicate | list[str] | None = None, + tool_name_prefix: str | None = None, + ) -> McpToolset: + """Return the MCP Toolset based on the params. + + Args: + mcp_server_name: Filter to select the MCP server name to get tools from. + tool_filter: Optional filter to select specific tools. Can be a list of + tool names or a ToolPredicate function. + tool_name_prefix: Optional prefix to prepend to the names of the tools + returned by the toolset. + + Returns: + McpToolset: A toolset for the MCP server specified. + """ + server = self._mcp_servers.get(mcp_server_name) + if not server: + raise ValueError( + f"MCP server {mcp_server_name} not found in API Registry." 
+ ) + if not server.get("urls"): + raise ValueError(f"MCP server {mcp_server_name} has no URLs.") + + mcp_server_url = server["urls"][0] + headers = self._get_auth_headers() + + # Only prepend "https://" if the URL doesn't already have a scheme + if not mcp_server_url.startswith(("http://", "https://")): + mcp_server_url = "https://" + mcp_server_url + + return McpToolset( + connection_params=StreamableHTTPConnectionParams( + url=mcp_server_url, + headers=headers, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=self._header_provider, + ) + + def _get_auth_headers(self) -> dict[str, str]: + """Refreshes credentials and returns authorization headers.""" + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + # Add quota project header if available in ADC + quota_project_id = getattr(self._credentials, "quota_project_id", None) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id + return headers diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py index feaf1c7a..d3483fc2 100644 --- a/src/google/adk/tools/api_registry.py +++ b/src/google/adk/tools/api_registry.py @@ -14,128 +14,13 @@ from __future__ import annotations -from typing import Any -from typing import Callable +import warnings -from google.adk.agents.readonly_context import ReadonlyContext -import google.auth -import google.auth.transport.requests -import httpx +from google.adk.integrations.api_registry import ApiRegistry -from .base_toolset import ToolPredicate -from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -from .mcp_tool.mcp_toolset import McpToolset - -API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" - - -class ApiRegistry: - """Registry that provides McpToolsets for MCP servers registered in API Registry.""" - - def __init__( - self, - api_registry_project_id: 
str, - location: str = "global", - header_provider: ( - Callable[[ReadonlyContext], dict[str, str]] | None - ) = None, - ): - """Initialize the API Registry. - - Args: - api_registry_project_id: The project ID for the Google Cloud API Registry. - location: The location of the API Registry resources. - header_provider: Optional function to provide additional headers for MCP - server calls. - """ - self.api_registry_project_id = api_registry_project_id - self.location = location - self._credentials, _ = google.auth.default() - self._mcp_servers: dict[str, dict[str, Any]] = {} - self._header_provider = header_provider - - url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" - - try: - headers = self._get_auth_headers() - headers["Content-Type"] = "application/json" - page_token = None - with httpx.Client() as client: - while True: - params = {} - if page_token: - params["pageToken"] = page_token - - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - data = response.json() - mcp_servers_list = data.get("mcpServers", []) - for server in mcp_servers_list: - server_name = server.get("name", "") - if server_name: - self._mcp_servers[server_name] = server - - page_token = data.get("nextPageToken") - if not page_token: - break - except (httpx.HTTPError, ValueError) as e: - # Handle error in fetching or parsing tool definitions - raise RuntimeError( - f"Error fetching MCP servers from API Registry: {e}" - ) from e - - def get_toolset( - self, - mcp_server_name: str, - tool_filter: ToolPredicate | list[str] | None = None, - tool_name_prefix: str | None = None, - ) -> McpToolset: - """Return the MCP Toolset based on the params. - - Args: - mcp_server_name: Filter to select the MCP server name to get tools from. - tool_filter: Optional filter to select specific tools. Can be a list of - tool names or a ToolPredicate function. 
- tool_name_prefix: Optional prefix to prepend to the names of the tools - returned by the toolset. - - Returns: - McpToolset: A toolset for the MCP server specified. - """ - server = self._mcp_servers.get(mcp_server_name) - if not server: - raise ValueError( - f"MCP server {mcp_server_name} not found in API Registry." - ) - if not server.get("urls"): - raise ValueError(f"MCP server {mcp_server_name} has no URLs.") - - mcp_server_url = server["urls"][0] - headers = self._get_auth_headers() - - # Only prepend "https://" if the URL doesn't already have a scheme - if not mcp_server_url.startswith(("http://", "https://")): - mcp_server_url = "https://" + mcp_server_url - - return McpToolset( - connection_params=StreamableHTTPConnectionParams( - url=mcp_server_url, - headers=headers, - ), - tool_filter=tool_filter, - tool_name_prefix=tool_name_prefix, - header_provider=self._header_provider, - ) - - def _get_auth_headers(self) -> dict[str, str]: - """Refreshes credentials and returns authorization headers.""" - request = google.auth.transport.requests.Request() - self._credentials.refresh(request) - headers = { - "Authorization": f"Bearer {self._credentials.token}", - } - # Add quota project header if available in ADC - quota_project_id = getattr(self._credentials, "quota_project_id", None) - if quota_project_id: - headers["x-goog-user-project"] = quota_project_id - return headers +warnings.warn( + "google.adk.tools.api_registry is moved to" + " google.adk.integrations.api_registry", + DeprecationWarning, + stacklevel=2, +) diff --git a/tests/unittests/integrations/api_registry/__init__.py b/tests/unittests/integrations/api_registry/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/tests/unittests/integrations/api_registry/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/integrations/api_registry/test_api_registry.py similarity index 96% rename from tests/unittests/tools/test_api_registry.py rename to tests/unittests/integrations/api_registry/test_api_registry.py index 59612434..7edaee9f 100644 --- a/tests/unittests/tools/test_api_registry.py +++ b/tests/unittests/integrations/api_registry/test_api_registry.py @@ -18,8 +18,8 @@ from unittest.mock import create_autospec from unittest.mock import MagicMock from unittest.mock import patch -from google.adk.tools import api_registry -from google.adk.tools.api_registry import ApiRegistry +from google.adk.integrations import api_registry +from google.adk.integrations.api_registry import ApiRegistry from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams import httpx @@ -218,7 +218,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) mock_response.raise_for_status.assert_called_once() - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): mock_response = MagicMock() @@ -245,7 +248,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(toolset, MockMcpToolset.return_value) - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + 
autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_with_quota_project_id_success( self, MockHttpClient, MockMcpToolset @@ -277,7 +283,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(toolset, MockMcpToolset.return_value) - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_with_filter_and_prefix( self, MockHttpClient, MockMcpToolset @@ -321,7 +330,7 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): with ( patch.object(httpx, "Client", autospec=True) as MockHttpClient, patch.object( - api_registry, "McpToolset", autospec=True + api_registry.api_registry, "McpToolset", autospec=True ) as MockMcpToolset, ): mock_response = create_autospec(httpx.Response, instance=True) From c36a708058163ade061cd3d2f9957231a505a62d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 4 Mar 2026 14:14:15 -0800 Subject: [PATCH 093/102] fix: Support before_tool_callback and after_tool_callback in Live mode Close #4704 Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 878662637 --- src/google/adk/flows/llm_flows/functions.py | 98 +++++++---- .../llm_flows/test_plugin_tool_callbacks.py | 155 ++++++++++++++++++ 2 files changed, 222 insertions(+), 31 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 66274d3d..24057c37 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -730,41 +730,77 @@ async def _execute_single_function_call_live( # Make a deep copy to avoid being modified. 
function_response = None - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. if function_response is None: - function_response = await _process_function_live_helper( - tool, - tool_context, - function_call, - function_args, - invocation_context, - streaming_lock, - ) + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break + # Step 3: Otherwise, proceed calling the tool normally. 
+ if function_response is None: + try: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) + except Exception as tool_error: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, + ) + ) + + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. 
if altered_function_response is not None: function_response = altered_function_response diff --git a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py index cc375ad0..3c39e284 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py @@ -19,6 +19,7 @@ from typing import Optional from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_tool import BaseTool from google.adk.tools.function_tool import FunctionTool @@ -185,5 +186,159 @@ async def test_async_on_tool_error_fallback_to_runner( assert e == mock_error +async def invoke_tool_with_plugin_live( + mock_tool, mock_plugin +) -> Optional[Event]: + """Invokes a tool with a plugin using the live path.""" + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + # Build function call event + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + return await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + +@pytest.mark.asyncio +async def test_live_before_tool_callback(mock_tool, mock_plugin): + mock_plugin.enable_before_tool_callback = True + + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) + + assert 
result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.before_tool_response + + +@pytest.mark.asyncio +async def test_live_after_tool_callback(mock_tool, mock_plugin): + mock_plugin.enable_after_tool_callback = True + + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.after_tool_response + + +@pytest.mark.asyncio +async def test_live_on_tool_error_use_plugin_response( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = True + + result_event = await invoke_tool_with_plugin_live( + mock_error_tool, mock_plugin + ) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.on_tool_error_response + + +@pytest.mark.asyncio +async def test_live_on_tool_error_fallback_to_runner( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = False + + try: + await invoke_tool_with_plugin_live(mock_error_tool, mock_plugin) + except Exception as e: + assert e == mock_error + + +@pytest.mark.asyncio +async def test_live_plugin_before_tool_callback_takes_priority( + mock_tool, mock_plugin +): + """Plugin before_tool_callback should run before agent canonical callbacks.""" + mock_plugin.enable_before_tool_callback = True + + def agent_before_cb(tool, args, tool_context): + return {"agent": "should_not_be_called"} + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + before_tool_callback=agent_before_cb, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) 
+ event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + result_event = await handle_function_calls_live( + invocation_context, event, tools_dict + ) + + assert result_event is not None + part = result_event.content.parts[0] + # Plugin response should win, not the agent callback + assert part.function_response.response == mock_plugin.before_tool_response + + +@pytest.mark.asyncio +async def test_live_plugin_after_tool_callback_takes_priority( + mock_tool, mock_plugin +): + """Plugin after_tool_callback should run before agent canonical callbacks.""" + mock_plugin.enable_after_tool_callback = True + + def agent_after_cb(tool, args, tool_context, tool_response): + return {"agent": "should_not_be_called"} + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + after_tool_callback=agent_after_cb, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + result_event = await handle_function_calls_live( + invocation_context, event, tools_dict + ) + + assert result_event is not None + part = result_event.content.parts[0] + # Plugin response should win, not the agent callback + assert part.function_response.response == mock_plugin.after_tool_response + + if __name__ == "__main__": pytest.main([__file__]) From 94684874e436c2959cfc90ec346010a6f4fddc49 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 4 Mar 2026 14:27:26 -0800 Subject: [PATCH 094/102] fix: Expand LiteLLM reasoning extraction to include 'reasoning' field The 
`_extract_reasoning_value` function now checks for both 'reasoning_content' and 'reasoning' fields in LiteLLM messages, with 'reasoning_content' taking precedence Close #3694 Co-authored-by: George Weale PiperOrigin-RevId: 878668213 --- src/google/adk/models/lite_llm.py | 13 ++- tests/unittests/models/test_litellm.py | 134 +++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dad5543f..8c1568cc 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -388,10 +388,18 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: def _extract_reasoning_value(message: Message | Delta | None) -> Any: - """Fetches the reasoning payload from a LiteLLM message.""" + """Fetches the reasoning payload from a LiteLLM message. + + Checks for both 'reasoning_content' (LiteLLM standard, used by Azure/Foundry, + Ollama via LiteLLM) and 'reasoning' (used by LM Studio, vLLM). + Prioritizes 'reasoning_content' when both are present. 
+ """ if message is None: return None - return message.get("reasoning_content") + reasoning_content = message.get("reasoning_content") + if reasoning_content is not None: + return reasoning_content + return message.get("reasoning") class ChatCompletionFileUrlObject(TypedDict, total=False): @@ -1302,6 +1310,7 @@ def _model_response_to_chunk( or message.get("tool_calls") or message.get("function_call") or message.get("reasoning_content") + or message.get("reasoning") ) if isinstance(response, ModelResponseStream): diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index aa19bfa8..2bd5f7d2 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -27,6 +27,7 @@ import warnings from google.adk.models.lite_llm import _append_fallback_user_content_if_missing from google.adk.models.lite_llm import _content_to_message_param from google.adk.models.lite_llm import _enforce_strict_openai_schema +from google.adk.models.lite_llm import _extract_reasoning_value from google.adk.models.lite_llm import _FILE_ID_REQUIRED_PROVIDERS from google.adk.models.lite_llm import _FINISH_REASON_MAPPING from google.adk.models.lite_llm import _function_declaration_to_tool_param @@ -2285,6 +2286,139 @@ def test_model_response_to_generate_content_response_reasoning_content(): assert response.content.parts[1].text == "Answer" +def test_message_to_generate_content_response_reasoning_field(): + """Test that the 'reasoning' field is supported (LM Studio, vLLM).""" + message = { + "role": "assistant", + "content": "Final answer", + "reasoning": "Thinking process", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + text_part = response.content.parts[1] + assert thought_part.text == "Thinking process" + assert thought_part.thought is True + assert text_part.text == "Final answer" + + +def 
test_model_response_to_generate_content_response_reasoning_field(): + """Test that 'reasoning' field is supported in ModelResponse.""" + model_response = ModelResponse( + model="test-model", + choices=[{ + "message": { + "role": "assistant", + "content": "Result", + "reasoning": "Chain of thought", + }, + "finish_reason": "stop", + }], + ) + + response = _model_response_to_generate_content_response(model_response) + + assert response.content.parts[0].text == "Chain of thought" + assert response.content.parts[0].thought is True + assert response.content.parts[1].text == "Result" + + +def test_reasoning_content_takes_precedence_over_reasoning(): + """Test that 'reasoning_content' is prioritized over 'reasoning'.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "LiteLLM standard reasoning", + "reasoning": "Alternative reasoning", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + assert thought_part.text == "LiteLLM standard reasoning" + assert thought_part.thought is True + + +def test_extract_reasoning_value_from_reasoning_content(): + """Test extraction from reasoning_content (LiteLLM standard).""" + message = ChatCompletionAssistantMessage( + role="assistant", + content="Answer", + reasoning_content="LiteLLM reasoning", + ) + result = _extract_reasoning_value(message) + assert result == "LiteLLM reasoning" + + +def test_extract_reasoning_value_from_reasoning(): + """Test extraction from reasoning (LM Studio, vLLM).""" + + class MockMessage: + + def __init__(self): + self.role = "assistant" + self.content = "Answer" + self.reasoning = "Alternative reasoning" + + def get(self, key, default=None): + return getattr(self, key, default) + + message = MockMessage() + result = _extract_reasoning_value(message) + assert result == "Alternative reasoning" + + +def test_extract_reasoning_value_dict_reasoning_content(): + """Test extraction 
from dict with reasoning_content field.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Dict reasoning content", + } + result = _extract_reasoning_value(message) + assert result == "Dict reasoning content" + + +def test_extract_reasoning_value_dict_reasoning(): + """Test extraction from dict with reasoning field.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning": "Dict reasoning", + } + result = _extract_reasoning_value(message) + assert result == "Dict reasoning" + + +def test_extract_reasoning_value_dict_prefers_reasoning_content(): + """Test that reasoning_content takes precedence over reasoning in dicts.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Primary", + "reasoning": "Secondary", + } + result = _extract_reasoning_value(message) + assert result == "Primary" + + +def test_extract_reasoning_value_none_message(): + """Test that None message returns None.""" + result = _extract_reasoning_value(None) + assert result is None + + +def test_extract_reasoning_value_no_reasoning_fields(): + """Test that None is returned when no reasoning fields exist.""" + message = { + "role": "assistant", + "content": "Answer only", + } + result = _extract_reasoning_value(message) + assert result is None + + def test_parse_tool_calls_from_text_multiple_calls(): text = ( '{"name":"alpha","arguments":{"value":1}}\n' From 3b5937f022adf9286dc41e01e3618071a23eb992 Mon Sep 17 00:00:00 2001 From: Mark Nawar Date: Wed, 4 Mar 2026 14:40:39 -0800 Subject: [PATCH 095/102] fix: filter non-agent directories from list_agents() Merge https://github.com/google/adk-python/pull/4648 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #4647 - Related: #3429, #3430 **2. 
Or, if no issue exists, describe the change:** **Problem:** `AgentLoader.list_agents()` returns every non-hidden subdirectory in the agents directory, regardless of whether it contains a valid agent definition. This causes non-agent directories (e.g. `tmp/`, `data/`, `utils/`) to appear in the `/list-apps` API response. This affects both the ADK web UI agent selector and any production deployment depending on this API. **Solution:** Reuse the existing `_determine_agent_language()` method inside `list_agents()` to verify each candidate directory contains at least one recognized agent file (`root_agent.yaml`, `agent.py`, or `__init__.py`). Directories that fail this check are excluded from the result. This avoids introducing any new methods or abstractions and keeps the check lightweight (filesystem only, no agent imports). ### Testing Plan **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. 27 passed in 2.85s: pytest tests/unittests/cli/utils/test_agent_loader.py -v ======================= 27 passed, 14 warnings in 2.85s ======================== Added `test_list_agents_excludes_non_agent_directories` which creates a temp directory with three valid agent types (package with `__init__.py`, module with `agent.py`, YAML with `root_agent.yaml`) and three non-agent directories, and asserts only the valid agents are listed. **Screenshots / Video:** | Before (non-agent directories listed) | After (only valid agents listed) | |----------------------------------------|----------------------------------| |Image|Image| **Manual End-to-End (E2E) Tests:** 1. Create a project directory containing both valid agent subdirectories and non-agent subdirectories 2. Run `adk web .` 3. Open the web UI and verify only valid agents appear in the agent selector 4. 
See screenshots below for before/after comparison ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. ### Additional context COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4648 from markadelnawar:fix/list-agents-filter-non-agents-dirs 041895610fa0c52f2bf3cf7ba0d072a5c580c1b6 PiperOrigin-RevId: 878674609 --- src/google/adk/cli/utils/agent_loader.py | 19 +++++--- .../unittests/cli/utils/test_agent_loader.py | 45 +++++++++++++++++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index 8b5805c5..efd24648 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -335,13 +335,18 @@ class AgentLoader(BaseAgentLoader): def list_agents(self) -> list[str]: """Lists all agents available in the agent loader (sorted alphabetically).""" base_path = Path.cwd() / self.agents_dir - agent_names = [ - x - for x in os.listdir(base_path) - if os.path.isdir(os.path.join(base_path, x)) - and not x.startswith(".") - and x != "__pycache__" - ] + agent_names = [] + for x in os.listdir(base_path): + if ( + os.path.isdir(os.path.join(base_path, x)) + and not x.startswith(".") + and x != "__pycache__" + ): + try: + self._determine_agent_language(x) + agent_names.append(x) + except ValueError: + continue agent_names.sort() return agent_names diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py index 
f3eb3396..0a7f9fc0 100644 --- a/tests/unittests/cli/utils/test_agent_loader.py +++ b/tests/unittests/cli/utils/test_agent_loader.py @@ -993,3 +993,48 @@ class TestAgentLoader: assert len(detailed_list) == 1 assert detailed_list[0]["name"] == agent_name assert not detailed_list[0]["is_computer_use"] + + def test_list_agents_excludes_non_agent_directories(self): + """Test that list_agents filters out directories without agent definitions.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + valid_package = temp_path / "valid_agent" + valid_package.mkdir() + (valid_package / "__init__.py").write_text(dedent(""" + from google.adk.agents.base_agent import BaseAgent + + class ValidAgent(BaseAgent): + def __init__(self): + super().__init__(name="valid_agent") + + root_agent = ValidAgent() + """)) + + valid_module = temp_path / "module_agent" + valid_module.mkdir() + (valid_module / "agent.py").write_text(dedent(""" + from google.adk.agents.base_agent import BaseAgent + + class ModuleAgent(BaseAgent): + def __init__(self): + super().__init__(name="module_agent") + + root_agent = ModuleAgent() + """)) + + valid_yaml = temp_path / "yaml_agent" + valid_yaml.mkdir() + (valid_yaml / "root_agent.yaml").write_text("name: yaml_agent\n") + + (temp_path / "random_folder").mkdir() + (temp_path / "data").mkdir() + (temp_path / "tmp").mkdir() + + loader = AgentLoader(str(temp_path)) + agents = loader.list_agents() + + assert agents == ["module_agent", "valid_agent", "yaml_agent"] + assert "random_folder" not in agents + assert "data" not in agents + assert "tmp" not in agents From 2addf6b9dacfe87344aeec0101df98d99c23bdb1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 14:59:18 -0800 Subject: [PATCH 096/102] fix: Fix Type Error by initializing user_content as a Content object PiperOrigin-RevId: 878682788 --- src/google/adk/evaluation/evaluation_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 1d9662bd..725bddc1 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -280,7 +280,7 @@ class EvaluationGenerator: invocations = [] for invocation_id, events in events_by_invocation_id.items(): final_response = None - user_content = "" + user_content = Content(parts=[]) invocation_timestamp = 0 app_details = None if ( From ab4b9526fc9e6c5b651434aa112340eac633f129 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 4 Mar 2026 17:23:58 -0800 Subject: [PATCH 097/102] chore: Move spanner tools to integration folder Added a deprecation warning in the old tools/spanner/__init__.py Co-authored-by: Kathy Wu PiperOrigin-RevId: 878742289 --- contributing/samples/spanner/README.md | 2 +- contributing/samples/spanner/agent.py | 12 +++--- .../samples/spanner_rag_agent/README.md | 18 ++++----- .../samples/spanner_rag_agent/agent.py | 10 ++--- .../adk/integrations/spanner/__init__.py | 40 +++++++++++++++++++ .../{tools => integrations}/spanner/client.py | 0 .../spanner/metadata_tool.py | 0 .../spanner/query_tool.py | 2 +- .../spanner/search_tool.py | 0 .../spanner/settings.py | 0 .../spanner/spanner_credentials.py | 2 +- .../spanner/spanner_toolset.py | 6 +-- .../{tools => integrations}/spanner/utils.py | 2 +- src/google/adk/tools/spanner/__init__.py | 26 +++++------- .../spanner/__init__.py} | 2 +- .../spanner/test_metadata_tool.py | 16 ++++---- .../spanner/test_search_tool.py | 6 +-- .../spanner/test_spanner_client.py | 2 +- .../spanner/test_spanner_credentials.py | 2 +- .../spanner/test_spanner_query_tool.py | 12 +++--- .../spanner/test_spanner_tool_settings.py | 8 ++-- .../spanner/test_spanner_toolset.py | 8 ++-- .../spanner/test_utils.py | 10 ++--- tests/unittests/tools/test_google_tool.py | 2 +- 24 files changed, 110 insertions(+), 78 deletions(-) create mode 100644 
src/google/adk/integrations/spanner/__init__.py rename src/google/adk/{tools => integrations}/spanner/client.py (100%) rename src/google/adk/{tools => integrations}/spanner/metadata_tool.py (100%) rename src/google/adk/{tools => integrations}/spanner/query_tool.py (99%) rename src/google/adk/{tools => integrations}/spanner/search_tool.py (100%) rename src/google/adk/{tools => integrations}/spanner/settings.py (100%) rename src/google/adk/{tools => integrations}/spanner/spanner_credentials.py (95%) rename src/google/adk/{tools => integrations}/spanner/spanner_toolset.py (96%) rename src/google/adk/{tools => integrations}/spanner/utils.py (99%) rename tests/{unittests/tools/spanner/__init__ => integration/spanner/__init__.py} (94%) rename tests/{unittests/tools => integration}/spanner/test_metadata_tool.py (94%) rename tests/{unittests/tools => integration}/spanner/test_search_tool.py (99%) rename tests/{unittests/tools => integration}/spanner/test_spanner_client.py (98%) rename tests/{unittests/tools => integration}/spanner/test_spanner_credentials.py (95%) rename tests/{unittests/tools => integration}/spanner/test_spanner_query_tool.py (94%) rename tests/{unittests/tools => integration}/spanner/test_spanner_tool_settings.py (92%) rename tests/{unittests/tools => integration}/spanner/test_spanner_toolset.py (96%) rename tests/{unittests/tools => integration}/spanner/test_utils.py (97%) diff --git a/contributing/samples/spanner/README.md b/contributing/samples/spanner/README.md index ea7f9d83..7b51f833 100644 --- a/contributing/samples/spanner/README.md +++ b/contributing/samples/spanner/README.md @@ -3,7 +3,7 @@ ## Introduction This sample agent demonstrates the Spanner first-party tools in ADK, -distributed via the `google.adk.tools.spanner` module. These tools include: +distributed via the `google.adk.integrations.spanner` module. These tools include: 1. 
`list_table_names` diff --git a/contributing/samples/spanner/agent.py b/contributing/samples/spanner/agent.py index 36dde572..ec3b1c16 100644 --- a/contributing/samples/spanner/agent.py +++ b/contributing/samples/spanner/agent.py @@ -16,13 +16,13 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.integrations.spanner.settings import Capabilities +from google.adk.integrations.spanner.settings import QueryResultMode +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.integrations.spanner.spanner_toolset import SpannerToolset +import google.adk.integrations.spanner.utils as spanner_tool_utils from google.adk.tools.google_tool import GoogleTool -from google.adk.tools.spanner.settings import Capabilities -from google.adk.tools.spanner.settings import QueryResultMode -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.tools.spanner.spanner_toolset import SpannerToolset -import google.adk.tools.spanner.utils as spanner_tool_utils from google.adk.tools.tool_context import ToolContext import google.auth from google.auth.credentials import Credentials diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 08d134b9..5a473962 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -4,7 +4,7 @@ This sample demonstrates how to build an intelligent Retrieval Augmented Generation (RAG) agent using the flexible, built-in Spanner tools available -in the ADK's `google.adk.tools.spanner` module, including how to create +in the ADK's `google.adk.integrations.spanner` module, including how to create customized Spanner tools by 
extending the existing ones. [Spanner](https://cloud.google.com/spanner/docs) is a fully managed, @@ -199,10 +199,10 @@ There are a few options to perform similarity search: ```py from google.adk.agents.llm_agent import LlmAgent - from google.adk.tools.spanner.settings import Capabilities - from google.adk.tools.spanner.settings import SpannerToolSettings - from google.adk.tools.spanner.settings import SpannerVectorStoreSettings - from google.adk.tools.spanner.spanner_toolset import SpannerToolset + from google.adk.integrations.spanner.settings import Capabilities + from google.adk.integrations.spanner.settings import SpannerToolSettings + from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings + from google.adk.integrations.spanner.spanner_toolset import SpannerToolset # credentials_config = SpannerCredentialsConfig(...) @@ -267,9 +267,9 @@ There are a few options to perform similarity search: ```py from google.adk.agents.llm_agent import LlmAgent - from google.adk.tools.spanner.settings import Capabilities - from google.adk.tools.spanner.settings import SpannerToolSettings - from google.adk.tools.spanner.spanner_toolset import SpannerToolset + from google.adk.integrations.spanner.settings import Capabilities + from google.adk.integrations.spanner.settings import SpannerToolSettings + from google.adk.integrations.spanner.spanner_toolset import SpannerToolset # credentials_config = SpannerCredentialsConfig(...) 
@@ -312,7 +312,7 @@ There are a few options to perform similarity search: from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.google_tool import GoogleTool - from google.adk.tools.spanner import search_tool + from google.adk.integrations.spanner import search_tool import google.auth from google.auth.credentials import Credentials diff --git a/contributing/samples/spanner_rag_agent/agent.py b/contributing/samples/spanner_rag_agent/agent.py index cd479c0c..d7233345 100644 --- a/contributing/samples/spanner_rag_agent/agent.py +++ b/contributing/samples/spanner_rag_agent/agent.py @@ -16,11 +16,11 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.tools.spanner.settings import Capabilities -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.settings import SpannerVectorStoreSettings -from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.tools.spanner.spanner_toolset import SpannerToolset +from google.adk.integrations.spanner.settings import Capabilities +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings +from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.integrations.spanner.spanner_toolset import SpannerToolset import google.auth # Define an appropriate credential type diff --git a/src/google/adk/integrations/spanner/__init__.py b/src/google/adk/integrations/spanner/__init__.py new file mode 100644 index 00000000..e41b9b79 --- /dev/null +++ b/src/google/adk/integrations/spanner/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spanner Tools (Experimental). + +Spanner Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. A dedicated Spanner toolset to provide an easier, integrated way to interact +with Spanner database and tables for building AI Agent applications quickly. +2. We want to provide more high-level tools like Search, ML.Predict, and Graph +etc. +3. We want to provide extra access guardrails and controls in those tools. +For example, execute_sql can't arbitrarily mutate existing data. +4. We want to provide Spanner best practices and knowledge assistants for ad-hoc +analytics queries. +5. Use Spanner Toolset for more customization and control to interact with +Spanner database and tables. +""" + +from . 
import spanner_credentials +from .spanner_toolset import SpannerToolset + +SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig +__all__ = [ + "SpannerToolset", + "SpannerCredentialsConfig", +] diff --git a/src/google/adk/tools/spanner/client.py b/src/google/adk/integrations/spanner/client.py similarity index 100% rename from src/google/adk/tools/spanner/client.py rename to src/google/adk/integrations/spanner/client.py diff --git a/src/google/adk/tools/spanner/metadata_tool.py b/src/google/adk/integrations/spanner/metadata_tool.py similarity index 100% rename from src/google/adk/tools/spanner/metadata_tool.py rename to src/google/adk/integrations/spanner/metadata_tool.py diff --git a/src/google/adk/tools/spanner/query_tool.py b/src/google/adk/integrations/spanner/query_tool.py similarity index 99% rename from src/google/adk/tools/spanner/query_tool.py rename to src/google/adk/integrations/spanner/query_tool.py index 24c1be60..18d0a789 100644 --- a/src/google/adk/tools/spanner/query_tool.py +++ b/src/google/adk/integrations/spanner/query_tool.py @@ -22,7 +22,7 @@ from typing import Callable from google.auth.credentials import Credentials from . 
import utils -from ..tool_context import ToolContext +from ...tools.tool_context import ToolContext from .settings import QueryResultMode from .settings import SpannerToolSettings diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/integrations/spanner/search_tool.py similarity index 100% rename from src/google/adk/tools/spanner/search_tool.py rename to src/google/adk/integrations/spanner/search_tool.py diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/integrations/spanner/settings.py similarity index 100% rename from src/google/adk/tools/spanner/settings.py rename to src/google/adk/integrations/spanner/settings.py diff --git a/src/google/adk/tools/spanner/spanner_credentials.py b/src/google/adk/integrations/spanner/spanner_credentials.py similarity index 95% rename from src/google/adk/tools/spanner/spanner_credentials.py rename to src/google/adk/integrations/spanner/spanner_credentials.py index 84d78bd0..e6909524 100644 --- a/src/google/adk/tools/spanner/spanner_credentials.py +++ b/src/google/adk/integrations/spanner/spanner_credentials.py @@ -16,7 +16,7 @@ from __future__ import annotations from ...features import experimental from ...features import FeatureName -from .._google_credentials import BaseGoogleCredentialsConfig +from ...tools._google_credentials import BaseGoogleCredentialsConfig SPANNER_TOKEN_CACHE_KEY = "spanner_token_cache" SPANNER_DEFAULT_SCOPE = [ diff --git a/src/google/adk/tools/spanner/spanner_toolset.py b/src/google/adk/integrations/spanner/spanner_toolset.py similarity index 96% rename from src/google/adk/tools/spanner/spanner_toolset.py rename to src/google/adk/integrations/spanner/spanner_toolset.py index 089dd1b3..7d1e32a3 100644 --- a/src/google/adk/tools/spanner/spanner_toolset.py +++ b/src/google/adk/integrations/spanner/spanner_toolset.py @@ -19,11 +19,11 @@ from typing import Optional from typing import Union from google.adk.agents.readonly_context import ReadonlyContext -from 
google.adk.tools.spanner import metadata_tool -from google.adk.tools.spanner import query_tool -from google.adk.tools.spanner import search_tool from typing_extensions import override +from . import metadata_tool +from . import query_tool +from . import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/integrations/spanner/utils.py similarity index 99% rename from src/google/adk/tools/spanner/utils.py rename to src/google/adk/integrations/spanner/utils.py index 9f5efdb7..7818cee7 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/integrations/spanner/utils.py @@ -29,7 +29,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from . import client from ...features import experimental from ...features import FeatureName -from ..tool_context import ToolContext +from ...tools.tool_context import ToolContext from .settings import QueryResultMode from .settings import SpannerToolSettings from .settings import SpannerVectorStoreSettings diff --git a/src/google/adk/tools/spanner/__init__.py b/src/google/adk/tools/spanner/__init__.py index e41b9b79..f4a5eec8 100644 --- a/src/google/adk/tools/spanner/__init__.py +++ b/src/google/adk/tools/spanner/__init__.py @@ -12,26 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Spanner Tools (Experimental). +"""Backward compatibility for google.adk.tools.spanner.""" -Spanner Tools under this module are hand crafted and customized while the tools -under google.adk.tools.google_api_tool are auto generated based on API -definition. The rationales to have customized tool are: +import warnings -1. A dedicated Spanner toolset to provide an easier, integrated way to interact -with Spanner database and tables for building AI Agent applications quickly. -2. 
We want to provide more high-level tools like Search, ML.Predict, and Graph -etc. -3. We want to provide extra access guardrails and controls in those tools. -For example, execute_sql can't arbitrarily mutate existing data. -4. We want to provide Spanner best practices and knowledge assistants for ad-hoc -analytics queries. -5. Use Spanner Toolset for more customization and control to interact with -Spanner database and tables. -""" +from google.adk.integrations.spanner import spanner_credentials +from google.adk.integrations.spanner.spanner_toolset import SpannerToolset -from . import spanner_credentials -from .spanner_toolset import SpannerToolset +warnings.warn( + "google.adk.tools.spanner is moved to google.adk.integrations.spanner.", + DeprecationWarning, + stacklevel=2, +) SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig __all__ = [ diff --git a/tests/unittests/tools/spanner/__init__ b/tests/integration/spanner/__init__.py similarity index 94% rename from tests/unittests/tools/spanner/__init__ rename to tests/integration/spanner/__init__.py index 30cb974f..58d482ea 100644 --- a/tests/unittests/tools/spanner/__init__ +++ b/tests/integration/spanner/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
diff --git a/tests/unittests/tools/spanner/test_metadata_tool.py b/tests/integration/spanner/test_metadata_tool.py similarity index 94% rename from tests/unittests/tools/spanner/test_metadata_tool.py rename to tests/integration/spanner/test_metadata_tool.py index fcfcd4bd..75bdcde6 100644 --- a/tests/unittests/tools/spanner/test_metadata_tool.py +++ b/tests/integration/spanner/test_metadata_tool.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock from unittest.mock import patch -from google.adk.tools.spanner import metadata_tool +from google.adk.integrations.spanner import metadata_tool from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest @@ -35,7 +35,7 @@ def mock_spanner_ids(): } -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_table_names_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -60,7 +60,7 @@ def test_list_table_names_success( assert result["results"] == ["table1"] -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_table_names_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -76,7 +76,7 @@ def test_list_table_names_error( assert result["error_details"] == "Test Exception" -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_get_table_schema_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -154,7 +154,7 @@ def test_get_table_schema_success( ) -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_table_indexes_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -192,7 +192,7 @@ def test_list_table_indexes_success( assert result["results"][0]["INDEX_NAME"] == 
"PRIMARY_KEY" -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_table_indexes_circular_row_fallback_to_string( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -231,7 +231,7 @@ def test_list_table_indexes_circular_row_fallback_to_string( assert isinstance(result["results"][0], str) -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_table_index_columns_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -268,7 +268,7 @@ def test_list_table_index_columns_success( assert result["results"][0]["COLUMN_NAME"] == "col1" -@patch("google.adk.tools.spanner.client.get_spanner_client") +@patch("google.adk.integrations.spanner.client.get_spanner_client") def test_list_named_schemas_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/integration/spanner/test_search_tool.py similarity index 99% rename from tests/unittests/tools/spanner/test_search_tool.py rename to tests/integration/spanner/test_search_tool.py index c6a6c742..6eb69f45 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/integration/spanner/test_search_tool.py @@ -15,9 +15,9 @@ from unittest import mock from unittest.mock import MagicMock -from google.adk.tools.spanner import client -from google.adk.tools.spanner import search_tool -from google.adk.tools.spanner import utils +from google.adk.integrations.spanner import client +from google.adk.integrations.spanner import search_tool +from google.adk.integrations.spanner import utils from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest diff --git a/tests/unittests/tools/spanner/test_spanner_client.py b/tests/integration/spanner/test_spanner_client.py similarity index 98% rename from 
tests/unittests/tools/spanner/test_spanner_client.py rename to tests/integration/spanner/test_spanner_client.py index 142a3796..53430c76 100644 --- a/tests/unittests/tools/spanner/test_spanner_client.py +++ b/tests/integration/spanner/test_spanner_client.py @@ -18,7 +18,7 @@ import os import re from unittest import mock -from google.adk.tools.spanner.client import get_spanner_client +from google.adk.integrations.spanner.client import get_spanner_client from google.auth.exceptions import DefaultCredentialsError from google.oauth2.credentials import Credentials import pytest diff --git a/tests/unittests/tools/spanner/test_spanner_credentials.py b/tests/integration/spanner/test_spanner_credentials.py similarity index 95% rename from tests/unittests/tools/spanner/test_spanner_credentials.py rename to tests/integration/spanner/test_spanner_credentials.py index 84d355f5..713c4367 100644 --- a/tests/unittests/tools/spanner/test_spanner_credentials.py +++ b/tests/integration/spanner/test_spanner_credentials.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig # Mock the Google OAuth and API dependencies import google.auth.credentials import google.oauth2.credentials diff --git a/tests/unittests/tools/spanner/test_spanner_query_tool.py b/tests/integration/spanner/test_spanner_query_tool.py similarity index 94% rename from tests/unittests/tools/spanner/test_spanner_query_tool.py rename to tests/integration/spanner/test_spanner_query_tool.py index 928c207d..676cbe65 100644 --- a/tests/unittests/tools/spanner/test_spanner_query_tool.py +++ b/tests/integration/spanner/test_spanner_query_tool.py @@ -17,13 +17,13 @@ from __future__ import annotations import textwrap from unittest import mock +from google.adk.integrations.spanner import query_tool +from google.adk.integrations.spanner import settings +from google.adk.integrations.spanner.settings import QueryResultMode +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.integrations.spanner.spanner_toolset import SpannerToolset from google.adk.tools.base_tool import BaseTool -from google.adk.tools.spanner import query_tool -from google.adk.tools.spanner import settings -from google.adk.tools.spanner.settings import QueryResultMode -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.tools.spanner.spanner_toolset import SpannerToolset from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials import pytest diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/integration/spanner/test_spanner_tool_settings.py similarity index 92% rename from tests/unittests/tools/spanner/test_spanner_tool_settings.py 
rename to tests/integration/spanner/test_spanner_tool_settings.py index df71dba2..45e31f1a 100644 --- a/tests/unittests/tools/spanner/test_spanner_tool_settings.py +++ b/tests/integration/spanner/test_spanner_tool_settings.py @@ -17,10 +17,10 @@ from __future__ import annotations import warnings from google.adk.features._feature_registry import _WARNED_FEATURES -from google.adk.tools.spanner.settings import Capabilities -from google.adk.tools.spanner.settings import QueryResultMode -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from google.adk.integrations.spanner.settings import Capabilities +from google.adk.integrations.spanner.settings import QueryResultMode +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings from pydantic import ValidationError import pytest diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/integration/spanner/test_spanner_toolset.py similarity index 96% rename from tests/unittests/tools/spanner/test_spanner_toolset.py rename to tests/integration/spanner/test_spanner_toolset.py index fe8422e9..92478c7c 100644 --- a/tests/unittests/tools/spanner/test_spanner_toolset.py +++ b/tests/integration/spanner/test_spanner_toolset.py @@ -14,11 +14,11 @@ from __future__ import annotations +from google.adk.integrations.spanner import SpannerCredentialsConfig +from google.adk.integrations.spanner import SpannerToolset +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings from google.adk.tools.google_tool import GoogleTool -from google.adk.tools.spanner import SpannerCredentialsConfig -from google.adk.tools.spanner import SpannerToolset -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.settings 
import SpannerVectorStoreSettings import pytest diff --git a/tests/unittests/tools/spanner/test_utils.py b/tests/integration/spanner/test_utils.py similarity index 97% rename from tests/unittests/tools/spanner/test_utils.py rename to tests/integration/spanner/test_utils.py index fe8d7db4..6986c1b9 100644 --- a/tests/unittests/tools/spanner/test_utils.py +++ b/tests/integration/spanner/test_utils.py @@ -16,11 +16,11 @@ from __future__ import annotations from unittest import mock -from google.adk.tools.spanner import utils as spanner_utils -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.spanner.settings import SpannerVectorStoreSettings -from google.adk.tools.spanner.settings import TableColumn -from google.adk.tools.spanner.settings import VectorSearchIndexSettings +from google.adk.integrations.spanner import utils as spanner_utils +from google.adk.integrations.spanner.settings import SpannerToolSettings +from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings +from google.adk.integrations.spanner.settings import TableColumn +from google.adk.integrations.spanner.settings import VectorSearchIndexSettings from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1 import batch as spanner_batch from google.cloud.spanner_v1 import client as spanner_client_v1 diff --git a/tests/unittests/tools/test_google_tool.py b/tests/unittests/tools/test_google_tool.py index 738edbae..ba242306 100644 --- a/tests/unittests/tools/test_google_tool.py +++ b/tests/unittests/tools/test_google_tool.py @@ -16,11 +16,11 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.integrations.spanner.settings import SpannerToolSettings from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig from google.adk.tools.bigquery.config import BigQueryToolConfig from 
google.adk.tools.google_tool import GoogleTool -from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.tool_context import ToolContext # Mock the Google OAuth and API dependencies from google.oauth2.credentials import Credentials From 36e76b98b31fe4a8e5a60d1b5704fcbdd38994c3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Mar 2026 18:48:27 -0800 Subject: [PATCH 098/102] ADK changes PiperOrigin-RevId: 878768583 --- contributing/samples/spanner/README.md | 2 +- contributing/samples/spanner/agent.py | 12 +++--- .../samples/spanner_rag_agent/README.md | 18 ++++----- .../samples/spanner_rag_agent/agent.py | 10 ++--- .../adk/integrations/spanner/__init__.py | 40 ------------------- src/google/adk/tools/spanner/__init__.py | 26 +++++++----- .../{integrations => tools}/spanner/client.py | 0 .../spanner/metadata_tool.py | 0 .../spanner/query_tool.py | 2 +- .../spanner/search_tool.py | 0 .../spanner/settings.py | 0 .../spanner/spanner_credentials.py | 2 +- .../spanner/spanner_toolset.py | 6 +-- .../{integrations => tools}/spanner/utils.py | 2 +- .../tools/spanner/__init__} | 2 +- .../tools}/spanner/test_metadata_tool.py | 16 ++++---- .../tools}/spanner/test_search_tool.py | 6 +-- .../tools}/spanner/test_spanner_client.py | 2 +- .../spanner/test_spanner_credentials.py | 2 +- .../tools}/spanner/test_spanner_query_tool.py | 12 +++--- .../spanner/test_spanner_tool_settings.py | 8 ++-- .../tools}/spanner/test_spanner_toolset.py | 8 ++-- .../tools}/spanner/test_utils.py | 10 ++--- tests/unittests/tools/test_google_tool.py | 2 +- 24 files changed, 78 insertions(+), 110 deletions(-) delete mode 100644 src/google/adk/integrations/spanner/__init__.py rename src/google/adk/{integrations => tools}/spanner/client.py (100%) rename src/google/adk/{integrations => tools}/spanner/metadata_tool.py (100%) rename src/google/adk/{integrations => tools}/spanner/query_tool.py (99%) rename src/google/adk/{integrations => tools}/spanner/search_tool.py 
(100%) rename src/google/adk/{integrations => tools}/spanner/settings.py (100%) rename src/google/adk/{integrations => tools}/spanner/spanner_credentials.py (95%) rename src/google/adk/{integrations => tools}/spanner/spanner_toolset.py (96%) rename src/google/adk/{integrations => tools}/spanner/utils.py (99%) rename tests/{integration/spanner/__init__.py => unittests/tools/spanner/__init__} (94%) rename tests/{integration => unittests/tools}/spanner/test_metadata_tool.py (94%) rename tests/{integration => unittests/tools}/spanner/test_search_tool.py (99%) rename tests/{integration => unittests/tools}/spanner/test_spanner_client.py (98%) rename tests/{integration => unittests/tools}/spanner/test_spanner_credentials.py (95%) rename tests/{integration => unittests/tools}/spanner/test_spanner_query_tool.py (94%) rename tests/{integration => unittests/tools}/spanner/test_spanner_tool_settings.py (92%) rename tests/{integration => unittests/tools}/spanner/test_spanner_toolset.py (96%) rename tests/{integration => unittests/tools}/spanner/test_utils.py (97%) diff --git a/contributing/samples/spanner/README.md b/contributing/samples/spanner/README.md index 7b51f833..ea7f9d83 100644 --- a/contributing/samples/spanner/README.md +++ b/contributing/samples/spanner/README.md @@ -3,7 +3,7 @@ ## Introduction This sample agent demonstrates the Spanner first-party tools in ADK, -distributed via the `google.adk.integrations.spanner` module. These tools include: +distributed via the `google.adk.tools.spanner` module. These tools include: 1. 
`list_table_names` diff --git a/contributing/samples/spanner/agent.py b/contributing/samples/spanner/agent.py index ec3b1c16..36dde572 100644 --- a/contributing/samples/spanner/agent.py +++ b/contributing/samples/spanner/agent.py @@ -16,13 +16,13 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.integrations.spanner.settings import Capabilities -from google.adk.integrations.spanner.settings import QueryResultMode -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.integrations.spanner.spanner_toolset import SpannerToolset -import google.adk.integrations.spanner.utils as spanner_tool_utils from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner.settings import Capabilities +from google.adk.tools.spanner.settings import QueryResultMode +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.tools.spanner.spanner_toolset import SpannerToolset +import google.adk.tools.spanner.utils as spanner_tool_utils from google.adk.tools.tool_context import ToolContext import google.auth from google.auth.credentials import Credentials diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 5a473962..08d134b9 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -4,7 +4,7 @@ This sample demonstrates how to build an intelligent Retrieval Augmented Generation (RAG) agent using the flexible, built-in Spanner tools available -in the ADK's `google.adk.integrations.spanner` module, including how to create +in the ADK's `google.adk.tools.spanner` module, including how to create customized Spanner tools by 
extending the existing ones. [Spanner](https://cloud.google.com/spanner/docs) is a fully managed, @@ -199,10 +199,10 @@ There are a few options to perform similarity search: ```py from google.adk.agents.llm_agent import LlmAgent - from google.adk.integrations.spanner.settings import Capabilities - from google.adk.integrations.spanner.settings import SpannerToolSettings - from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings - from google.adk.integrations.spanner.spanner_toolset import SpannerToolset + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.settings import SpannerVectorStoreSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset # credentials_config = SpannerCredentialsConfig(...) @@ -267,9 +267,9 @@ There are a few options to perform similarity search: ```py from google.adk.agents.llm_agent import LlmAgent - from google.adk.integrations.spanner.settings import Capabilities - from google.adk.integrations.spanner.settings import SpannerToolSettings - from google.adk.integrations.spanner.spanner_toolset import SpannerToolset + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset # credentials_config = SpannerCredentialsConfig(...) 
@@ -312,7 +312,7 @@ There are a few options to perform similarity search: from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.google_tool import GoogleTool - from google.adk.integrations.spanner import search_tool + from google.adk.tools.spanner import search_tool import google.auth from google.auth.credentials import Credentials diff --git a/contributing/samples/spanner_rag_agent/agent.py b/contributing/samples/spanner_rag_agent/agent.py index d7233345..cd479c0c 100644 --- a/contributing/samples/spanner_rag_agent/agent.py +++ b/contributing/samples/spanner_rag_agent/agent.py @@ -16,11 +16,11 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.integrations.spanner.settings import Capabilities -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings -from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.integrations.spanner.spanner_toolset import SpannerToolset +from google.adk.tools.spanner.settings import Capabilities +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.tools.spanner.spanner_toolset import SpannerToolset import google.auth # Define an appropriate credential type diff --git a/src/google/adk/integrations/spanner/__init__.py b/src/google/adk/integrations/spanner/__init__.py deleted file mode 100644 index e41b9b79..00000000 --- a/src/google/adk/integrations/spanner/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Spanner Tools (Experimental). - -Spanner Tools under this module are hand crafted and customized while the tools -under google.adk.tools.google_api_tool are auto generated based on API -definition. The rationales to have customized tool are: - -1. A dedicated Spanner toolset to provide an easier, integrated way to interact -with Spanner database and tables for building AI Agent applications quickly. -2. We want to provide more high-level tools like Search, ML.Predict, and Graph -etc. -3. We want to provide extra access guardrails and controls in those tools. -For example, execute_sql can't arbitrarily mutate existing data. -4. We want to provide Spanner best practices and knowledge assistants for ad-hoc -analytics queries. -5. Use Spanner Toolset for more customization and control to interact with -Spanner database and tables. -""" - -from . import spanner_credentials -from .spanner_toolset import SpannerToolset - -SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig -__all__ = [ - "SpannerToolset", - "SpannerCredentialsConfig", -] diff --git a/src/google/adk/tools/spanner/__init__.py b/src/google/adk/tools/spanner/__init__.py index f4a5eec8..e41b9b79 100644 --- a/src/google/adk/tools/spanner/__init__.py +++ b/src/google/adk/tools/spanner/__init__.py @@ -12,18 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Backward compatibility for google.adk.tools.spanner.""" +"""Spanner Tools (Experimental). 
-import warnings +Spanner Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: -from google.adk.integrations.spanner import spanner_credentials -from google.adk.integrations.spanner.spanner_toolset import SpannerToolset +1. A dedicated Spanner toolset to provide an easier, integrated way to interact +with Spanner database and tables for building AI Agent applications quickly. +2. We want to provide more high-level tools like Search, ML.Predict, and Graph +etc. +3. We want to provide extra access guardrails and controls in those tools. +For example, execute_sql can't arbitrarily mutate existing data. +4. We want to provide Spanner best practices and knowledge assistants for ad-hoc +analytics queries. +5. Use Spanner Toolset for more customization and control to interact with +Spanner database and tables. +""" -warnings.warn( - "google.adk.tools.spanner is moved to google.adk.integrations.spanner.", - DeprecationWarning, - stacklevel=2, -) +from . 
import spanner_credentials +from .spanner_toolset import SpannerToolset SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig __all__ = [ diff --git a/src/google/adk/integrations/spanner/client.py b/src/google/adk/tools/spanner/client.py similarity index 100% rename from src/google/adk/integrations/spanner/client.py rename to src/google/adk/tools/spanner/client.py diff --git a/src/google/adk/integrations/spanner/metadata_tool.py b/src/google/adk/tools/spanner/metadata_tool.py similarity index 100% rename from src/google/adk/integrations/spanner/metadata_tool.py rename to src/google/adk/tools/spanner/metadata_tool.py diff --git a/src/google/adk/integrations/spanner/query_tool.py b/src/google/adk/tools/spanner/query_tool.py similarity index 99% rename from src/google/adk/integrations/spanner/query_tool.py rename to src/google/adk/tools/spanner/query_tool.py index 18d0a789..24c1be60 100644 --- a/src/google/adk/integrations/spanner/query_tool.py +++ b/src/google/adk/tools/spanner/query_tool.py @@ -22,7 +22,7 @@ from typing import Callable from google.auth.credentials import Credentials from . 
import utils -from ...tools.tool_context import ToolContext +from ..tool_context import ToolContext from .settings import QueryResultMode from .settings import SpannerToolSettings diff --git a/src/google/adk/integrations/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py similarity index 100% rename from src/google/adk/integrations/spanner/search_tool.py rename to src/google/adk/tools/spanner/search_tool.py diff --git a/src/google/adk/integrations/spanner/settings.py b/src/google/adk/tools/spanner/settings.py similarity index 100% rename from src/google/adk/integrations/spanner/settings.py rename to src/google/adk/tools/spanner/settings.py diff --git a/src/google/adk/integrations/spanner/spanner_credentials.py b/src/google/adk/tools/spanner/spanner_credentials.py similarity index 95% rename from src/google/adk/integrations/spanner/spanner_credentials.py rename to src/google/adk/tools/spanner/spanner_credentials.py index e6909524..84d78bd0 100644 --- a/src/google/adk/integrations/spanner/spanner_credentials.py +++ b/src/google/adk/tools/spanner/spanner_credentials.py @@ -16,7 +16,7 @@ from __future__ import annotations from ...features import experimental from ...features import FeatureName -from ...tools._google_credentials import BaseGoogleCredentialsConfig +from .._google_credentials import BaseGoogleCredentialsConfig SPANNER_TOKEN_CACHE_KEY = "spanner_token_cache" SPANNER_DEFAULT_SCOPE = [ diff --git a/src/google/adk/integrations/spanner/spanner_toolset.py b/src/google/adk/tools/spanner/spanner_toolset.py similarity index 96% rename from src/google/adk/integrations/spanner/spanner_toolset.py rename to src/google/adk/tools/spanner/spanner_toolset.py index 7d1e32a3..089dd1b3 100644 --- a/src/google/adk/integrations/spanner/spanner_toolset.py +++ b/src/google/adk/tools/spanner/spanner_toolset.py @@ -19,11 +19,11 @@ from typing import Optional from typing import Union from google.adk.agents.readonly_context import ReadonlyContext +from 
google.adk.tools.spanner import metadata_tool +from google.adk.tools.spanner import query_tool +from google.adk.tools.spanner import search_tool from typing_extensions import override -from . import metadata_tool -from . import query_tool -from . import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool diff --git a/src/google/adk/integrations/spanner/utils.py b/src/google/adk/tools/spanner/utils.py similarity index 99% rename from src/google/adk/integrations/spanner/utils.py rename to src/google/adk/tools/spanner/utils.py index 7818cee7..9f5efdb7 100644 --- a/src/google/adk/integrations/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -29,7 +29,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from . import client from ...features import experimental from ...features import FeatureName -from ...tools.tool_context import ToolContext +from ..tool_context import ToolContext from .settings import QueryResultMode from .settings import SpannerToolSettings from .settings import SpannerVectorStoreSettings diff --git a/tests/integration/spanner/__init__.py b/tests/unittests/tools/spanner/__init__ similarity index 94% rename from tests/integration/spanner/__init__.py rename to tests/unittests/tools/spanner/__init__ index 58d482ea..30cb974f 100644 --- a/tests/integration/spanner/__init__.py +++ b/tests/unittests/tools/spanner/__init__ @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. 
\ No newline at end of file diff --git a/tests/integration/spanner/test_metadata_tool.py b/tests/unittests/tools/spanner/test_metadata_tool.py similarity index 94% rename from tests/integration/spanner/test_metadata_tool.py rename to tests/unittests/tools/spanner/test_metadata_tool.py index 75bdcde6..fcfcd4bd 100644 --- a/tests/integration/spanner/test_metadata_tool.py +++ b/tests/unittests/tools/spanner/test_metadata_tool.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock from unittest.mock import patch -from google.adk.integrations.spanner import metadata_tool +from google.adk.tools.spanner import metadata_tool from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest @@ -35,7 +35,7 @@ def mock_spanner_ids(): } -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_names_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -60,7 +60,7 @@ def test_list_table_names_success( assert result["results"] == ["table1"] -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_names_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -76,7 +76,7 @@ def test_list_table_names_error( assert result["error_details"] == "Test Exception" -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_get_table_schema_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -154,7 +154,7 @@ def test_get_table_schema_success( ) -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_indexes_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -192,7 +192,7 @@ def test_list_table_indexes_success( assert 
result["results"][0]["INDEX_NAME"] == "PRIMARY_KEY" -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_indexes_circular_row_fallback_to_string( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -231,7 +231,7 @@ def test_list_table_indexes_circular_row_fallback_to_string( assert isinstance(result["results"][0], str) -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_index_columns_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -268,7 +268,7 @@ def test_list_table_index_columns_success( assert result["results"][0]["COLUMN_NAME"] == "col1" -@patch("google.adk.integrations.spanner.client.get_spanner_client") +@patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_named_schemas_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): diff --git a/tests/integration/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py similarity index 99% rename from tests/integration/spanner/test_search_tool.py rename to tests/unittests/tools/spanner/test_search_tool.py index 6eb69f45..c6a6c742 100644 --- a/tests/integration/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -15,9 +15,9 @@ from unittest import mock from unittest.mock import MagicMock -from google.adk.integrations.spanner import client -from google.adk.integrations.spanner import search_tool -from google.adk.integrations.spanner import utils +from google.adk.tools.spanner import client +from google.adk.tools.spanner import search_tool +from google.adk.tools.spanner import utils from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest diff --git a/tests/integration/spanner/test_spanner_client.py b/tests/unittests/tools/spanner/test_spanner_client.py 
similarity index 98% rename from tests/integration/spanner/test_spanner_client.py rename to tests/unittests/tools/spanner/test_spanner_client.py index 53430c76..142a3796 100644 --- a/tests/integration/spanner/test_spanner_client.py +++ b/tests/unittests/tools/spanner/test_spanner_client.py @@ -18,7 +18,7 @@ import os import re from unittest import mock -from google.adk.integrations.spanner.client import get_spanner_client +from google.adk.tools.spanner.client import get_spanner_client from google.auth.exceptions import DefaultCredentialsError from google.oauth2.credentials import Credentials import pytest diff --git a/tests/integration/spanner/test_spanner_credentials.py b/tests/unittests/tools/spanner/test_spanner_credentials.py similarity index 95% rename from tests/integration/spanner/test_spanner_credentials.py rename to tests/unittests/tools/spanner/test_spanner_credentials.py index 713c4367..84d355f5 100644 --- a/tests/integration/spanner/test_spanner_credentials.py +++ b/tests/unittests/tools/spanner/test_spanner_credentials.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig # Mock the Google OAuth and API dependencies import google.auth.credentials import google.oauth2.credentials diff --git a/tests/integration/spanner/test_spanner_query_tool.py b/tests/unittests/tools/spanner/test_spanner_query_tool.py similarity index 94% rename from tests/integration/spanner/test_spanner_query_tool.py rename to tests/unittests/tools/spanner/test_spanner_query_tool.py index 676cbe65..928c207d 100644 --- a/tests/integration/spanner/test_spanner_query_tool.py +++ b/tests/unittests/tools/spanner/test_spanner_query_tool.py @@ -17,13 +17,13 @@ from __future__ import annotations import textwrap from unittest import mock -from google.adk.integrations.spanner import query_tool -from google.adk.integrations.spanner import settings -from google.adk.integrations.spanner.settings import QueryResultMode -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.integrations.spanner.spanner_toolset import SpannerToolset from google.adk.tools.base_tool import BaseTool +from google.adk.tools.spanner import query_tool +from google.adk.tools.spanner import settings +from google.adk.tools.spanner.settings import QueryResultMode +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.tools.spanner.spanner_toolset import SpannerToolset from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials import pytest diff --git a/tests/integration/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py similarity index 92% rename from tests/integration/spanner/test_spanner_tool_settings.py rename 
to tests/unittests/tools/spanner/test_spanner_tool_settings.py index 45e31f1a..df71dba2 100644 --- a/tests/integration/spanner/test_spanner_tool_settings.py +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -17,10 +17,10 @@ from __future__ import annotations import warnings from google.adk.features._feature_registry import _WARNED_FEATURES -from google.adk.integrations.spanner.settings import Capabilities -from google.adk.integrations.spanner.settings import QueryResultMode -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings +from google.adk.tools.spanner.settings import Capabilities +from google.adk.tools.spanner.settings import QueryResultMode +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings from pydantic import ValidationError import pytest diff --git a/tests/integration/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py similarity index 96% rename from tests/integration/spanner/test_spanner_toolset.py rename to tests/unittests/tools/spanner/test_spanner_toolset.py index 92478c7c..fe8422e9 100644 --- a/tests/integration/spanner/test_spanner_toolset.py +++ b/tests/unittests/tools/spanner/test_spanner_toolset.py @@ -14,11 +14,11 @@ from __future__ import annotations -from google.adk.integrations.spanner import SpannerCredentialsConfig -from google.adk.integrations.spanner import SpannerToolset -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner import SpannerCredentialsConfig +from google.adk.tools.spanner import SpannerToolset +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings 
import SpannerVectorStoreSettings import pytest diff --git a/tests/integration/spanner/test_utils.py b/tests/unittests/tools/spanner/test_utils.py similarity index 97% rename from tests/integration/spanner/test_utils.py rename to tests/unittests/tools/spanner/test_utils.py index 6986c1b9..fe8d7db4 100644 --- a/tests/integration/spanner/test_utils.py +++ b/tests/unittests/tools/spanner/test_utils.py @@ -16,11 +16,11 @@ from __future__ import annotations from unittest import mock -from google.adk.integrations.spanner import utils as spanner_utils -from google.adk.integrations.spanner.settings import SpannerToolSettings -from google.adk.integrations.spanner.settings import SpannerVectorStoreSettings -from google.adk.integrations.spanner.settings import TableColumn -from google.adk.integrations.spanner.settings import VectorSearchIndexSettings +from google.adk.tools.spanner import utils as spanner_utils +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from google.adk.tools.spanner.settings import TableColumn +from google.adk.tools.spanner.settings import VectorSearchIndexSettings from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1 import batch as spanner_batch from google.cloud.spanner_v1 import client as spanner_client_v1 diff --git a/tests/unittests/tools/test_google_tool.py b/tests/unittests/tools/test_google_tool.py index ba242306..738edbae 100644 --- a/tests/unittests/tools/test_google_tool.py +++ b/tests/unittests/tools/test_google_tool.py @@ -16,11 +16,11 @@ from unittest.mock import Mock from unittest.mock import patch -from google.adk.integrations.spanner.settings import SpannerToolSettings from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig from google.adk.tools.bigquery.config import BigQueryToolConfig from 
google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.tool_context import ToolContext # Mock the Google OAuth and API dependencies from google.oauth2.credentials import Credentials From a61ccf36a2e19a9cb4151c3b73bbf6618c1af451 Mon Sep 17 00:00:00 2001 From: Achuth Narayan Rajagopal Date: Wed, 4 Mar 2026 22:10:44 -0800 Subject: [PATCH 099/102] feat(telemetry): add new gen_ai.agent.version span attribute Link to Issue or Description of Change 1. Link to an existing issue (if applicable): Closes: #issue_number Related: #issue_number 2. Or, if no issue exists, describe the change: Problem: The agent's specific version wasn't being tracked in our telemetry data, limiting our ability to trace issues to specific agent versions. This change introduces the gen_ai.agent.version attribute to span context, defaulting to an empty string if omitted for backwards compatibility. Solution: We want to capture the specific version of an agent during execution by adding an optional version field to the base agent configurations (BaseAgent, BaseAgentConfig). This solution was chosen because exposing this field directly to OpenTelemetry span attributes (gen_ai.agent.version) ensures the version is automatically recorded alongside other existing metadata (like name and description) during invocation. Defaulting the value to an empty string ensures backwards compatibility without breaking existing agent implementations that do not specify a version. Testing Plan - Added test_trace_agent_invocation_with_version to verify that the gen_ai.agent.version attribute is correctly captured when agent.version is populated. - Updated existing telemetry span tests to ensure gen_ai.agent.version safely defaults to an empty string ('') when no version is provided. Unit Tests: - I have added or updated unit tests for my change. - All unit tests pass locally. 
Manual End-to-End (E2E) Tests: - Tested on Agent Engine and in a local deployment. Checklist [x] I have read the CONTRIBUTING.md document. [x] I have performed a self-review of my own code. [x] I have commented my code, particularly in hard-to-understand areas. [x] I have added tests that prove my fix is effective or that my feature works. [x] New and existing unit tests pass locally with my changes. [x] I have manually tested my changes end-to-end. [x] Any dependent changes have been merged and published in downstream modules. Additional context Add any other context or screenshots about the feature request here. Co-authored-by: Achuth Narayan Rajagopal PiperOrigin-RevId: 878835568 --- src/google/adk/agents/base_agent.py | 7 +++++ src/google/adk/agents/base_agent_config.py | 4 +++ src/google/adk/telemetry/tracing.py | 5 ++++ tests/unittests/telemetry/test_spans.py | 30 ++++++++++++++++++++++ 4 files changed, 46 insertions(+) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index dec85690..ac54bb68 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -115,6 +115,12 @@ class BaseAgent(BaseModel): Agent name cannot be "user", since it's reserved for end-user's input. """ + version: str = '' + """The agent's version. + + Version of the agent being invoked. Used to identify the Agent involved in telemetry. + """ + description: str = '' """Description about the agent's capability. @@ -680,6 +686,7 @@ class BaseAgent(BaseModel): kwargs: Dict[str, Any] = { 'name': config.name, + 'version': config.version, 'description': config.description, } if config.sub_agents: diff --git a/src/google/adk/agents/base_agent_config.py b/src/google/adk/agents/base_agent_config.py index 3859cb35..9dca68c5 100644 --- a/src/google/adk/agents/base_agent_config.py +++ b/src/google/adk/agents/base_agent_config.py @@ -55,6 +55,10 @@ class BaseAgentConfig(BaseModel): name: str = Field(description='Required. 
The name of the agent.') + version: str = Field( + default='', description='Optional. The version of the agent.' + ) + description: str = Field( default='', description='Optional. The description of the agent.' ) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 707bc313..8b777ea8 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -82,6 +82,8 @@ OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( USER_CONTENT_ELIDED = '' +GEN_AI_AGENT_VERSION = 'gen_ai.agent.version' + # Needed to avoid circular imports if TYPE_CHECKING: from ..agents.base_agent import BaseAgent @@ -155,6 +157,7 @@ def trace_agent_invocation( span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) span.set_attribute(GEN_AI_AGENT_NAME, agent.name) + span.set_attribute(GEN_AI_AGENT_VERSION, agent.version) span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) @@ -455,6 +458,7 @@ def use_generate_content_span( USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': model_response_event.id, 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + GEN_AI_AGENT_VERSION: invocation_context.agent.version, } if ( _is_gemini_agent(invocation_context.agent) @@ -489,6 +493,7 @@ async def use_inference_span( USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': model_response_event.id, 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + GEN_AI_AGENT_VERSION: invocation_context.agent.version, } if ( _is_gemini_agent(invocation_context.agent) diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index 3c061e42..81897ba3 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -24,6 +24,7 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.sessions.in_memory_session_service import 
InMemorySessionService from google.adk.telemetry.tracing import ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS +from google.adk.telemetry.tracing import GEN_AI_AGENT_VERSION from google.adk.telemetry.tracing import trace_agent_invocation from google.adk.telemetry.tracing import trace_call_llm from google.adk.telemetry.tracing import trace_inference_result @@ -119,6 +120,33 @@ async def test_trace_agent_invocation(mock_span_fixture): mock.call('gen_ai.operation.name', 'invoke_agent'), mock.call('gen_ai.agent.description', agent.description), mock.call('gen_ai.agent.name', agent.name), + mock.call(GEN_AI_AGENT_VERSION, ''), + mock.call( + 'gen_ai.conversation.id', + invocation_context.session.id, + ), + ] + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + assert mock_span_fixture.set_attribute.call_count == len(expected_calls) + + +@pytest.mark.asyncio +async def test_trace_agent_invocation_with_version(mock_span_fixture): + """Test trace_agent_invocation sets span attributes correctly when version is provided.""" + agent = LlmAgent(name='test_llm_agent', model='gemini-pro') + agent.description = 'Test agent description' + agent.version = '1.0.0' + invocation_context = await _create_invocation_context(agent) + + trace_agent_invocation(mock_span_fixture, agent, invocation_context) + + expected_calls = [ + mock.call('gen_ai.operation.name', 'invoke_agent'), + mock.call('gen_ai.agent.description', agent.description), + mock.call('gen_ai.agent.name', agent.name), + mock.call(GEN_AI_AGENT_VERSION, agent.version), mock.call( 'gen_ai.conversation.id', invocation_context.session.id, @@ -767,6 +795,7 @@ async def test_generate_content_span( mock_span.set_attributes.assert_called_once_with({ GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_AGENT_VERSION: '', GEN_AI_CONVERSATION_ID: invocation_context.session.id, USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': 'event-123', @@ -1087,6 +1116,7 @@ async def 
test_generate_content_span_with_experimental_semconv( mock_span.set_attributes.assert_called_once_with({ GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_AGENT_VERSION: '', GEN_AI_CONVERSATION_ID: invocation_context.session.id, USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': 'event-123', From feefadfcc9e4ccc8379a1da35a8e36451ab08d46 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Mar 2026 06:20:06 -0800 Subject: [PATCH 100/102] ADK changes PiperOrigin-RevId: 879030011 --- src/google/adk/agents/base_agent.py | 7 ----- src/google/adk/agents/base_agent_config.py | 4 --- src/google/adk/telemetry/tracing.py | 5 ---- tests/unittests/telemetry/test_spans.py | 30 ---------------------- 4 files changed, 46 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index ac54bb68..dec85690 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -115,12 +115,6 @@ class BaseAgent(BaseModel): Agent name cannot be "user", since it's reserved for end-user's input. """ - version: str = '' - """The agent's version. - - Version of the agent being invoked. Used to identify the Agent involved in telemetry. - """ - description: str = '' """Description about the agent's capability. @@ -686,7 +680,6 @@ class BaseAgent(BaseModel): kwargs: Dict[str, Any] = { 'name': config.name, - 'version': config.version, 'description': config.description, } if config.sub_agents: diff --git a/src/google/adk/agents/base_agent_config.py b/src/google/adk/agents/base_agent_config.py index 9dca68c5..3859cb35 100644 --- a/src/google/adk/agents/base_agent_config.py +++ b/src/google/adk/agents/base_agent_config.py @@ -55,10 +55,6 @@ class BaseAgentConfig(BaseModel): name: str = Field(description='Required. The name of the agent.') - version: str = Field( - default='', description='Optional. The version of the agent.' - ) - description: str = Field( default='', description='Optional. 
The description of the agent.' ) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 8b777ea8..707bc313 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -82,8 +82,6 @@ OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( USER_CONTENT_ELIDED = '' -GEN_AI_AGENT_VERSION = 'gen_ai.agent.version' - # Needed to avoid circular imports if TYPE_CHECKING: from ..agents.base_agent import BaseAgent @@ -157,7 +155,6 @@ def trace_agent_invocation( span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) span.set_attribute(GEN_AI_AGENT_NAME, agent.name) - span.set_attribute(GEN_AI_AGENT_VERSION, agent.version) span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) @@ -458,7 +455,6 @@ def use_generate_content_span( USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': model_response_event.id, 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, - GEN_AI_AGENT_VERSION: invocation_context.agent.version, } if ( _is_gemini_agent(invocation_context.agent) @@ -493,7 +489,6 @@ async def use_inference_span( USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': model_response_event.id, 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, - GEN_AI_AGENT_VERSION: invocation_context.agent.version, } if ( _is_gemini_agent(invocation_context.agent) diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index 81897ba3..3c061e42 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -24,7 +24,6 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.telemetry.tracing import ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS -from google.adk.telemetry.tracing import GEN_AI_AGENT_VERSION from 
google.adk.telemetry.tracing import trace_agent_invocation from google.adk.telemetry.tracing import trace_call_llm from google.adk.telemetry.tracing import trace_inference_result @@ -120,33 +119,6 @@ async def test_trace_agent_invocation(mock_span_fixture): mock.call('gen_ai.operation.name', 'invoke_agent'), mock.call('gen_ai.agent.description', agent.description), mock.call('gen_ai.agent.name', agent.name), - mock.call(GEN_AI_AGENT_VERSION, ''), - mock.call( - 'gen_ai.conversation.id', - invocation_context.session.id, - ), - ] - mock_span_fixture.set_attribute.assert_has_calls( - expected_calls, any_order=True - ) - assert mock_span_fixture.set_attribute.call_count == len(expected_calls) - - -@pytest.mark.asyncio -async def test_trace_agent_invocation_with_version(mock_span_fixture): - """Test trace_agent_invocation sets span attributes correctly when version is provided.""" - agent = LlmAgent(name='test_llm_agent', model='gemini-pro') - agent.description = 'Test agent description' - agent.version = '1.0.0' - invocation_context = await _create_invocation_context(agent) - - trace_agent_invocation(mock_span_fixture, agent, invocation_context) - - expected_calls = [ - mock.call('gen_ai.operation.name', 'invoke_agent'), - mock.call('gen_ai.agent.description', agent.description), - mock.call('gen_ai.agent.name', agent.name), - mock.call(GEN_AI_AGENT_VERSION, agent.version), mock.call( 'gen_ai.conversation.id', invocation_context.session.id, @@ -795,7 +767,6 @@ async def test_generate_content_span( mock_span.set_attributes.assert_called_once_with({ GEN_AI_AGENT_NAME: invocation_context.agent.name, - GEN_AI_AGENT_VERSION: '', GEN_AI_CONVERSATION_ID: invocation_context.session.id, USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': 'event-123', @@ -1116,7 +1087,6 @@ async def test_generate_content_span_with_experimental_semconv( mock_span.set_attributes.assert_called_once_with({ GEN_AI_AGENT_NAME: invocation_context.agent.name, - 
GEN_AI_AGENT_VERSION: '', GEN_AI_CONVERSATION_ID: invocation_context.session.id, USER_ID: invocation_context.session.user_id, 'gcp.vertex.agent.event_id': 'event-123', From bcf38fa2bac2f0d1ab74e07e01eb5160bad1d6dc Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Thu, 5 Mar 2026 10:52:52 -0800 Subject: [PATCH 101/102] feat: Enhance BigQuery plugin schema upgrades and error reporting This change introduces several improvements to the BigQuery Agent Analytics Plugin: * **Fix 1 (High):** Error callbacks (`on_model_error_callback`, `on_tool_error_callback`) now emit `status="ERROR"` instead of defaulting to `"OK"`. * **Fix 2 (Medium):** Schema upgrade now detects missing sub-fields in nested RECORD columns via a new recursive helper. The version label is now stamped only after the `update_table` call succeeds, ensuring failures can be retried. * **Fix 3 (Medium):** Multi-loop `shutdown()` now drains batch processors on non-current event loops using `run_coroutine_threadsafe` before closing transports. * **Fix 4 (Medium):** Session state is truncated before logging to prevent oversized payloads. * **Fix 5 (Low):** String system prompts are now truncated during content parsing. * **Fix 6 (Low):** Removed the unused `_HITL_TOOL_NAMES` frozenset. Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 879147684 --- .../bigquery_agent_analytics_plugin.py | 198 ++++++++- .../test_bigquery_agent_analytics_plugin.py | 401 +++++++++++++++++- 2 files changed, 576 insertions(+), 23 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 0f43de6b..ce028cf4 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -28,6 +28,13 @@ import json import logging import mimetypes import os + +# Enable gRPC fork support so child processes created via os.fork() +# can safely create new gRPC channels. 
Must be set before grpc's +# C-core is loaded (which happens through the google.api_core +# imports below). setdefault respects any explicit user override. +os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1") + import random import time from types import MappingProxyType @@ -76,19 +83,29 @@ tracer = trace.get_tracer( _SCHEMA_VERSION = "1" _SCHEMA_VERSION_LABEL_KEY = "adk_schema_version" -# Human-in-the-loop (HITL) tool names that receive additional -# dedicated event types alongside the normal TOOL_* events. -_HITL_TOOL_NAMES = frozenset({ - "adk_request_credential", - "adk_request_confirmation", - "adk_request_input", -}) _HITL_EVENT_MAP = MappingProxyType({ "adk_request_credential": "HITL_CREDENTIAL_REQUEST", "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST", "adk_request_input": "HITL_INPUT_REQUEST", }) +# Track all living plugin instances so the fork handler can reset +# them proactively in the child, before _ensure_started runs. +_LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet() + + +def _after_fork_in_child() -> None: + """Reset every living plugin instance after os.fork().""" + for plugin in list(_LIVE_PLUGINS): + try: + plugin._reset_runtime_state() + except Exception: + pass + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=_after_fork_in_child) + def _safe_callback(func): """Decorator that catches and logs exceptions in plugin callbacks. 
@@ -1407,7 +1424,10 @@ class HybridContentParser: if content.config and getattr(content.config, "system_instruction", None): si = content.config.system_instruction if isinstance(si, str): - json_payload["system_prompt"] = si + truncated_si, trunc = process_text(si) + if trunc: + is_truncated = True + json_payload["system_prompt"] = truncated_si else: summary, parts, trunc = await self._parse_content_object(si) if trunc: @@ -1855,6 +1875,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self._schema = None self.arrow_schema = None self._init_pid = os.getpid() + _LIVE_PLUGINS.add(self) def _cleanup_stale_loop_states(self) -> None: """Removes entries for event loops that have been closed.""" @@ -2142,9 +2163,73 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): exc_info=True, ) + @staticmethod + def _schema_fields_match( + existing: list[bq_schema.SchemaField], + desired: list[bq_schema.SchemaField], + ) -> tuple[ + list[bq_schema.SchemaField], + list[bq_schema.SchemaField], + ]: + """Compares existing vs desired schema fields recursively. + + Returns: + A tuple of (new_top_level_fields, updated_record_fields). + ``new_top_level_fields`` are fields in *desired* that are + entirely absent from *existing*. + ``updated_record_fields`` are RECORD fields that exist in + both but have new sub-fields in *desired*; each entry is a + copy of the existing field with the missing sub-fields + appended. + """ + existing_by_name = {f.name: f for f in existing} + new_fields: list[bq_schema.SchemaField] = [] + updated_records: list[bq_schema.SchemaField] = [] + + for desired_field in desired: + existing_field = existing_by_name.get(desired_field.name) + if existing_field is None: + new_fields.append(desired_field) + elif ( + desired_field.field_type == "RECORD" + and existing_field.field_type == "RECORD" + and desired_field.fields + ): + # Recurse into nested RECORD fields. 
+ sub_new, sub_updated = ( + BigQueryAgentAnalyticsPlugin._schema_fields_match( + list(existing_field.fields), + list(desired_field.fields), + ) + ) + if sub_new or sub_updated: + # Build a merged sub-field list. + merged_sub = list(existing_field.fields) + # Replace updated nested records in-place. + updated_names = {f.name for f in sub_updated} + merged_sub = [ + next(u for u in sub_updated if u.name == f.name) + if f.name in updated_names + else f + for f in merged_sub + ] + # Append entirely new sub-fields. + merged_sub.extend(sub_new) + # Rebuild via API representation to preserve all + # existing field attributes (policy_tags, etc.). + api_repr = existing_field.to_api_repr() + api_repr["fields"] = [sf.to_api_repr() for sf in merged_sub] + updated_records.append(bq_schema.SchemaField.from_api_repr(api_repr)) + + return new_fields, updated_records + def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: """Adds missing columns to an existing table (additive only). + Handles nested RECORD fields by recursing into sub-fields. + The version label is only stamped after a successful update + so that a failed attempt is retried on the next run. + Args: existing_table: The current BigQuery table object. """ @@ -2154,24 +2239,43 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): if stored_version == _SCHEMA_VERSION: return - existing_names = {f.name for f in existing_table.schema} - new_fields = [f for f in self._schema if f.name not in existing_names] + new_fields, updated_records = self._schema_fields_match( + list(existing_table.schema), list(self._schema) + ) - if new_fields: - merged = list(existing_table.schema) + new_fields + if new_fields or updated_records: + # Build merged top-level schema. 
+ updated_names = {f.name for f in updated_records} + merged = [ + next(u for u in updated_records if u.name == f.name) + if f.name in updated_names + else f + for f in existing_table.schema + ] + merged.extend(new_fields) existing_table.schema = merged + + change_desc = [] + if new_fields: + change_desc.append(f"new columns {[f.name for f in new_fields]}") + if updated_records: + change_desc.append( + f"updated RECORD fields {[f.name for f in updated_records]}" + ) logger.info( - "Auto-upgrading table %s: adding columns %s", + "Auto-upgrading table %s: %s", self.full_table_id, - [f.name for f in new_fields], + ", ".join(change_desc), ) - # Always stamp the version label so we skip on next run. - labels = dict(existing_table.labels or {}) - labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION - existing_table.labels = labels - try: + # Stamp the version label inside the try block so that + # on failure the label is NOT persisted and the next run + # retries the upgrade. + labels = dict(existing_table.labels or {}) + labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION + existing_table.labels = labels + update_fields = ["schema", "labels"] self.client.update_table(existing_table, update_fields) except Exception as e: @@ -2243,6 +2347,22 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): if loop in self._loop_state_by_loop: await self._loop_state_by_loop[loop].batch_processor.shutdown(timeout=t) + # 1b. Drain batch processors on other (non-current) loops. + for other_loop, state in self._loop_state_by_loop.items(): + if other_loop is loop or other_loop.is_closed(): + continue + try: + future = asyncio.run_coroutine_threadsafe( + state.batch_processor.shutdown(timeout=t), + other_loop, + ) + future.result(timeout=t) + except Exception: + logger.warning( + "Could not drain batch processor on loop %s", + other_loop, + ) + # 2. 
Close clients for all states for state in self._loop_state_by_loop.values(): if state.write_client and getattr( @@ -2298,6 +2418,38 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): process. Pure-data fields like ``_schema`` and ``arrow_schema`` are kept because they are safe across fork. """ + logger.warning( + "Fork detected (parent PID %s, child PID %s). Resetting" + " gRPC state for BigQuery analytics plugin. Note: gRPC" + " bidirectional streaming (used by the BigQuery Storage" + " Write API) is not fork-safe. If writes hang or time" + " out, configure the 'spawn' start method at your program" + " entry-point before creating child processes:" + " multiprocessing.set_start_method('spawn')", + self._init_pid, + os.getpid(), + ) + # Best-effort: close inherited gRPC channels so broken + # finalizers don't interfere with newly created channels. + # For grpc.aio channels, close() is a coroutine. We cannot + # await here (called from sync context / fork handler), so + # we skip async channels and only close sync ones. + for loop_state in self._loop_state_by_loop.values(): + wc = getattr(loop_state, "write_client", None) + transport = getattr(wc, "transport", None) + if transport is not None: + try: + channel = getattr(transport, "_grpc_channel", None) + if channel is not None and hasattr(channel, "close"): + result = channel.close() + # If close() returned a coroutine (grpc.aio channel), + # discard it to avoid unawaited-coroutine warnings. + if asyncio.iscoroutine(result): + result.close() + except Exception: + pass + + # Clear all runtime state. self._setup_lock = None self.client = None self._loop_state_by_loop = {} @@ -2442,7 +2594,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): # Include session state if non-empty (contains user-set metadata # like gchat thread-id, customer_id, etc.) 
if session.state: - session_meta["state"] = dict(session.state) + truncated_state, _ = _recursive_smart_truncate( + dict(session.state), + self.config.max_content_length, + ) + session_meta["state"] = truncated_state attrs["session_metadata"] = session_meta except Exception: pass @@ -2988,6 +3144,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "LLM_ERROR", callback_context, event_data=EventData( + status="ERROR", error_message=str(error), latency_ms=duration, span_id_override=None if has_ambient else span_id, @@ -3110,6 +3267,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): raw_content=content_dict, is_truncated=is_truncated, event_data=EventData( + status="ERROR", error_message=str(error), latency_ms=duration, span_id_override=None if has_ambient else span_id, diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 5d87a17c..a39eb932 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -17,6 +17,7 @@ import asyncio import contextlib import dataclasses import json +import os from unittest import mock from google.adk.agents import base_agent @@ -1734,6 +1735,7 @@ class TestBigQueryAgentAnalyticsPlugin: _assert_common_fields(log_entry, "LLM_ERROR") assert log_entry["content"] is None assert log_entry["error_message"] == "LLM failed" + assert log_entry["status"] == "ERROR" @pytest.mark.asyncio async def test_on_tool_error_callback_logs_correctly( @@ -1761,6 +1763,7 @@ class TestBigQueryAgentAnalyticsPlugin: assert content_dict["tool"] == "MyTool" assert content_dict["args"] == {"param": "value"} assert log_entry["error_message"] == "Tool timed out" + assert log_entry["status"] == "ERROR" @pytest.mark.asyncio async def test_table_creation_options( @@ -4829,7 +4832,6 @@ class TestForkSafety: # _ensure_started should detect PID mismatch and reset await 
plugin._ensure_started() # After reset + re-init, _init_pid should match current - import os assert plugin._init_pid == os.getpid() assert plugin._started is True @@ -4884,8 +4886,6 @@ class TestForkSafety: assert plugin._schema == ["kept"] assert plugin.arrow_schema == "kept_arrow" - import os - assert plugin._init_pid == os.getpid() def test_getstate_resets_pid(self): @@ -4920,6 +4920,134 @@ class TestForkSafety: await new_plugin.shutdown() +class TestForkGrpcSafety: + """Tests for gRPC fork safety enhancements.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + + def test_grpc_fork_env_var_set(self): + """GRPC_ENABLE_FORK_SUPPORT should be '1' after import.""" + + assert os.environ.get("GRPC_ENABLE_FORK_SUPPORT") == "1" + + def test_register_at_fork_resets_all_instances(self): + """_after_fork_in_child resets all living plugin instances.""" + p1 = self._make_plugin() + p2 = self._make_plugin() + p1._started = True + p2._started = True + p1._init_pid = -1 + p2._init_pid = -1 + + bigquery_agent_analytics_plugin._after_fork_in_child() + + assert p1._started is False + assert p2._started is False + assert p1._init_pid == os.getpid() + assert p2._init_pid == os.getpid() + + def test_dead_plugin_removed_from_live_set(self): + """WeakSet should not hold dead plugin references.""" + p = self._make_plugin() + assert p in bigquery_agent_analytics_plugin._LIVE_PLUGINS + pid = id(p) + del p + # After deletion, the WeakSet should no longer contain it. 
+ for alive in bigquery_agent_analytics_plugin._LIVE_PLUGINS: + assert id(alive) != pid + + def test_reset_closes_inherited_sync_transports(self): + """_reset_runtime_state closes inherited sync gRPC channels.""" + plugin = self._make_plugin() + mock_channel = mock.MagicMock() + mock_channel.close.return_value = None # sync close + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + plugin._reset_runtime_state() + + mock_channel.close.assert_called_once() + + def test_reset_discards_async_channel_close_coroutine(self): + """Async channel close() returns a coroutine; must not warn.""" + import warnings + + plugin = self._make_plugin() + + async def _async_close(): + pass + + mock_channel = mock.MagicMock() + mock_channel.close.return_value = _async_close() + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + # Must not raise RuntimeWarning for unawaited coroutine + plugin._reset_runtime_state() + + mock_channel.close.assert_called_once() + + def test_transport_close_exception_swallowed(self): + """close() raising should not prevent reset from completing.""" + plugin = self._make_plugin() + mock_channel = mock.MagicMock() + mock_channel.close.side_effect = RuntimeError("broken channel") + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = 
mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + # Should not raise + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._loop_state_by_loop == {} + + def test_reset_logs_fork_warning(self): + """_reset_runtime_state logs a warning with 'Fork detected'.""" + plugin = self._make_plugin() + plugin._init_pid = -1 + + with mock.patch.object( + bigquery_agent_analytics_plugin.logger, "warning" + ) as mock_warn: + plugin._reset_runtime_state() + + mock_warn.assert_called_once() + assert "Fork detected" in mock_warn.call_args[0][0] + + # ============================================================================== # Analytics Views Tests # ============================================================================== @@ -6057,3 +6185,270 @@ class TestAfterRunCleanupExceptionSafety: assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None provider.shutdown() + + +class TestStringSystemPromptTruncation: + """Tests that a string system prompt is truncated in parse().""" + + @pytest.mark.asyncio + async def test_long_string_system_prompt_is_truncated(self): + """A string system_instruction exceeding max_content_length is truncated.""" + parser = bigquery_agent_analytics_plugin.HybridContentParser( + offloader=None, + trace_id="test-trace", + span_id="test-span", + max_length=50, + ) + long_prompt = "A" * 200 + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction=long_prompt, + ), + ) + payload, _, is_truncated = await parser.parse(llm_request) + assert is_truncated + assert len(payload["system_prompt"]) < 200 + assert "TRUNCATED" in payload["system_prompt"] + + +class TestSessionStateTruncation: + """Tests that session state is truncated in _enrich_attributes.""" + + @pytest.mark.asyncio + async 
def test_oversized_session_state_is_truncated( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + mock_session, + invocation_context, + ): + """Session state with large values is truncated.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + max_content_length=30, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + await plugin._ensure_started() + + # Set a large session state value. + large_value = "X" * 200 + type(mock_session).state = mock.PropertyMock( + return_value={"big_key": large_value} + ) + + callback_ctx = CallbackContext(invocation_context=invocation_context) + event_data = bigquery_agent_analytics_plugin.EventData() + attrs = plugin._enrich_attributes(event_data, callback_ctx) + state = attrs["session_metadata"]["state"] + assert len(state["big_key"]) < 200 + assert "TRUNCATED" in state["big_key"] + await plugin.shutdown() + + +class TestSchemaUpgradeNestedFields: + """Tests for nested RECORD field detection in schema upgrade.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + auto_schema_upgrade=True, + ) + with mock.patch("google.cloud.bigquery.Client"): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + return plugin + + def test_nested_field_detected(self): + """A new sub-field in a RECORD triggers an upgrade.""" + plugin = self._make_plugin() + + existing_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", "STRING"), + ], + ) + desired_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", 
"STRING"), + bigquery.SchemaField("value", "STRING"), + ], + ) + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + desired_record, + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + existing_record, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + # Find the metadata field and check it has both sub-fields. + metadata_field = next( + f for f in updated_table.schema if f.name == "metadata" + ) + sub_names = {sf.name for sf in metadata_field.fields} + assert "key" in sub_names + assert "value" in sub_names + + def test_version_label_not_stamped_on_failure(self): + """A failed update_table does not persist the version label.""" + plugin = self._make_plugin() + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("new_col", "STRING"), + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin.client.update_table.side_effect = Exception("network error") + + # Should not raise. + plugin._ensure_schema_exists() + + # The label is set on the table object before update_table is + # called, but since update_table failed the label was never + # persisted remotely. On the next run the stored_version will + # still be None (from the real BQ table) so the upgrade retries. + # We verify that update_table was actually attempted. + plugin.client.update_table.assert_called_once() + + def test_nested_upgrade_preserves_policy_tags(self): + """RECORD field metadata (e.g. 
policy_tags) is preserved on upgrade.""" + from google.cloud.bigquery import schema as bq_schema + + plugin = self._make_plugin() + + existing_record = bigquery.SchemaField( + "metadata", + "RECORD", + policy_tags=bq_schema.PolicyTagList( + names=["projects/p/locations/us/taxonomies/t/policyTags/pt"] + ), + fields=[ + bigquery.SchemaField("key", "STRING"), + ], + ) + desired_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", "STRING"), + bigquery.SchemaField("value", "STRING"), + ], + ) + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + desired_record, + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + existing_record, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + metadata_field = next( + f for f in updated_table.schema if f.name == "metadata" + ) + # Sub-fields were merged. + sub_names = {sf.name for sf in metadata_field.fields} + assert "key" in sub_names + assert "value" in sub_names + # policy_tags preserved from the existing field. 
+ assert metadata_field.policy_tags is not None + assert ( + "projects/p/locations/us/taxonomies/t/policyTags/pt" + in metadata_field.policy_tags.names + ) + + +class TestMultiLoopShutdownDrainsOtherLoops: + """Tests that shutdown() drains batch processors on other loops.""" + + @pytest.mark.asyncio + async def test_other_loop_batch_processor_drained( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Shutdown drains batch_processor.shutdown on non-current loops.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await plugin._ensure_started() + + # Create a mock "other" loop with a mock batch processor. + other_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + other_loop.is_closed.return_value = False + + mock_other_bp = mock.AsyncMock() + mock_other_write_client = mock.MagicMock() + mock_other_write_client.transport = mock.AsyncMock() + + other_state = bigquery_agent_analytics_plugin._LoopState( + write_client=mock_other_write_client, + batch_processor=mock_other_bp, + ) + plugin._loop_state_by_loop[other_loop] = other_state + + # Patch run_coroutine_threadsafe to verify it's called for + # the other loop's batch_processor. Close the coroutine arg + # to avoid "coroutine was never awaited" RuntimeWarning. + mock_future = mock.MagicMock() + mock_future.result.return_value = None + + def _fake_run_coroutine_threadsafe(coro, loop): + coro.close() + return mock_future + + with mock.patch.object( + asyncio, + "run_coroutine_threadsafe", + side_effect=_fake_run_coroutine_threadsafe, + ) as mock_rcts: + await plugin.shutdown() + + # Verify run_coroutine_threadsafe was called with + # the other loop. 
+ mock_rcts.assert_called() + call_args = mock_rcts.call_args + assert call_args[0][1] is other_loop From 44a5e6bdb8e8f02891e72b65ef883f108c506f6a Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 5 Mar 2026 13:52:23 -0800 Subject: [PATCH 102/102] feat: Add support for ADK tools in SkillToolset To use ADK tools, users can specify the tool name in a skill object's `additional_tools` and pass the tool in when initializing a SkillToolset. Co-authored-by: Kathy Wu PiperOrigin-RevId: 879230409 --- contributing/samples/skills_agent/agent.py | 39 ++++++++++ .../skills/weather_skill/SKILL.md | 10 +++ src/google/adk/skills/models.py | 22 ++++-- src/google/adk/tools/skill_toolset.py | 71 ++++++++++++++++- tests/unittests/skills/test_models.py | 31 ++++++++ tests/unittests/tools/test_skill_toolset.py | 78 ++++++++++++++++++- 6 files changed, 239 insertions(+), 12 deletions(-) create mode 100644 contributing/samples/skills_agent/skills/weather_skill/SKILL.md diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 9232545a..6d5355db 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -20,7 +20,44 @@ from google.adk import Agent from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor from google.adk.skills import load_skill_from_dir from google.adk.skills import models +from google.adk.tools.base_tool import BaseTool from google.adk.tools.skill_toolset import SkillToolset +from google.genai import types + + +class GetTimezoneTool(BaseTool): + """A tool to get the timezone for a given location.""" + + def __init__(self): + super().__init__( + name="get_timezone", + description="Returns the timezone for a given location.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration | None: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { 
+ "location": { + "type": "string", + "description": "The location to get the timezone for.", + }, + }, + "required": ["location"], + }, + ) + + async def run_async(self, *, args: dict, tool_context) -> str: + return f"The timezone for {args['location']} is UTC+00:00." + + +def get_current_humidity(location: str) -> str: + """Returns the current humidity for a given location.""" + return f"The humidity in {location} is 45%." + greeting_skill = models.Skill( frontmatter=models.Frontmatter( @@ -28,6 +65,7 @@ greeting_skill = models.Skill( description=( "A friendly greeting skill that can say hello to a specific person." ), + metadata={"adk_additional_tools": ["get_timezone"]}, ), instructions=( "Step 1: Read the 'references/hello_world.txt' file to understand how" @@ -49,6 +87,7 @@ weather_skill = load_skill_from_dir( # be used in production environments. my_skill_toolset = SkillToolset( skills=[greeting_skill, weather_skill], + additional_tools=[GetTimezoneTool(), get_current_humidity], code_executor=UnsafeLocalCodeExecutor(), ) diff --git a/contributing/samples/skills_agent/skills/weather_skill/SKILL.md b/contributing/samples/skills_agent/skills/weather_skill/SKILL.md new file mode 100644 index 00000000..ea79220a --- /dev/null +++ b/contributing/samples/skills_agent/skills/weather_skill/SKILL.md @@ -0,0 +1,10 @@ +--- +name: weather-skill +description: A skill that provides weather information based on reference data. +metadata: + adk_additional_tools: + - get_current_humidity +--- + +Step 1: Check 'references/weather_info.md' for the current weather. +Step 2: Provide the weather update to the user. 
diff --git a/src/google/adk/skills/models.py b/src/google/adk/skills/models.py index f98b0f10..f7674cd9 100644 --- a/src/google/adk/skills/models.py +++ b/src/google/adk/skills/models.py @@ -17,6 +17,7 @@ from __future__ import annotations import re +from typing import Any from typing import Optional import unicodedata @@ -37,11 +38,13 @@ class Frontmatter(BaseModel): (required). license: License for the skill (optional). compatibility: Compatibility information for the skill (optional). - allowed_tools: Tool patterns the skill requires (optional, experimental). - Accepts both ``allowed_tools`` and the YAML-friendly ``allowed-tools`` - key. + allowed_tools: A space-delimited list of tools that are pre-approved to + run (optional, experimental). Accepts both ``allowed_tools`` and the + YAML-friendly ``allowed-tools`` key. For more details, see + https://agentskills.io/specification#allowed-tools-field. metadata: Key-value pairs for client-specific properties (defaults to - empty dict). + empty dict). For example, to include additional tools, use the + ``adk_additional_tools`` key with a list of tools. 
""" model_config = ConfigDict( @@ -58,7 +61,16 @@ class Frontmatter(BaseModel): alias="allowed-tools", serialization_alias="allowed-tools", ) - metadata: dict[str, str] = {} + metadata: dict[str, Any] = {} + + @field_validator("metadata") + @classmethod + def _validate_metadata(cls, v: dict[str, Any]) -> dict[str, Any]: + if "adk_additional_tools" in v: + tools = v["adk_additional_tools"] + if not isinstance(tools, list): + raise ValueError("adk_additional_tools must be a list of strings") + return v @field_validator("name") @classmethod diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 12411b41..81ce0c45 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -37,9 +37,11 @@ from ..skills import models from ..skills import prompt from .base_tool import BaseTool from .base_toolset import BaseToolset +from .function_tool import FunctionTool from .tool_context import ToolContext if TYPE_CHECKING: + from ..agents.llm_agent import ToolUnion from ..models.llm_request import LlmRequest logger = logging.getLogger("google_adk." + __name__) @@ -138,6 +140,15 @@ class LoadSkillTool(BaseTool): "error_code": "SKILL_NOT_FOUND", } + # Record skill activation in agent state for tool resolution. + agent_name = tool_context.agent_name + state_key = f"_adk_activated_skill_{agent_name}" + + activated_skills = list(tool_context.state.get(state_key, [])) + if skill_name not in activated_skills: + activated_skills.append(skill_name) + tool_context.state[state_key] = activated_skills + return { "skill_name": skill_name, "instructions": skill.instructions, @@ -586,6 +597,7 @@ class SkillToolset(BaseToolset): *, code_executor: Optional[BaseCodeExecutor] = None, script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT, + additional_tools: list[ToolUnion] | None = None, ): """Initializes the SkillToolset. 
@@ -609,20 +621,73 @@ class SkillToolset(BaseToolset): self._code_executor = code_executor self._script_timeout = script_timeout + self._provided_tools_by_name = {} + for tool_union in additional_tools or []: + if isinstance(tool_union, BaseTool): + self._provided_tools_by_name[tool_union.name] = tool_union + elif callable(tool_union): + ft = FunctionTool(tool_union) + self._provided_tools_by_name[ft.name] = ft + # Initialize core skill tools self._tools = [ ListSkillsTool(self), LoadSkillTool(self), LoadSkillResourceTool(self), + RunSkillScriptTool(self), ] - # Always add RunSkillScriptTool, relies on invocation_context fallback if _code_executor is None - self._tools.append(RunSkillScriptTool(self)) async def get_tools( self, readonly_context: ReadonlyContext | None = None ) -> list[BaseTool]: """Returns the list of tools in this toolset.""" - return self._tools + dynamic_tools = await self._resolve_additional_tools_from_state( + readonly_context + ) + return self._tools + dynamic_tools + + async def _resolve_additional_tools_from_state( + self, readonly_context: ReadonlyContext | None + ) -> list[BaseTool]: + """Resolves tools listed in the "adk_additional_tools" metadata of skills.""" + + if not readonly_context: + return [] + + agent_name = readonly_context.agent_name + state_key = f"_adk_activated_skill_{agent_name}" + activated_skills = readonly_context.state.get(state_key, []) + + if not activated_skills: + return [] + + additional_tool_names = set() + for skill_name in activated_skills: + skill = self._skills.get(skill_name) + if skill: + additional_tools = skill.frontmatter.metadata.get( + "adk_additional_tools" + ) + if additional_tools: + additional_tool_names.update(additional_tools) + + if not additional_tool_names: + return [] + + resolved_tools = [] + existing_tool_names = {t.name for t in self._tools} + for name in additional_tool_names: + if name in self._provided_tools_by_name: + tool = self._provided_tools_by_name[name] + if tool.name in 
existing_tool_names: + logger.error( + "Tool name collision: tool '%s' already exists.", tool.name + ) + continue + resolved_tools.append(tool) + existing_tool_names.add(tool.name) + + return resolved_tools def _get_skill(self, name: str) -> models.Skill | None: """Retrieves a skill by name.""" diff --git a/tests/unittests/skills/test_models.py b/tests/unittests/skills/test_models.py index 3bc7fd30..5685e9d8 100644 --- a/tests/unittests/skills/test_models.py +++ b/tests/unittests/skills/test_models.py @@ -173,3 +173,34 @@ def test_allowed_tools_serialization_alias(): dumped = fm.model_dump(by_alias=True) assert "allowed-tools" in dumped assert dumped["allowed-tools"] == "tool-pattern" + + +def test_metadata_adk_additional_tools_list(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": ["tool1", "tool2"]}, + }) + assert fm.metadata["adk_additional_tools"] == ["tool1", "tool2"] + + +def test_metadata_adk_additional_tools_rejected_as_string(): + with pytest.raises( + ValidationError, match="adk_additional_tools must be a list of strings" + ): + models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": "tool1 tool2"}, + }) + + +def test_metadata_adk_additional_tools_invalid_type(): + with pytest.raises( + ValidationError, match="adk_additional_tools must be a list of strings" + ): + models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": 123}, + }) diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index cbccecdb..7ebf4f40 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# pylint: disable=redefined-outer-name,g-import-not-at-top,protected-access - - +import logging from unittest import mock from google.adk.code_executors.base_code_executor import BaseCodeExecutor @@ -145,7 +143,13 @@ def mock_skill2(mock_skill2_frontmatter): @pytest.fixture def tool_context_instance(): """Fixture for tool context.""" - return mock.create_autospec(tool_context.ToolContext, instance=True) + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + ctx._invocation_context = mock.MagicMock() + ctx._invocation_context.agent = mock.MagicMock() + ctx._invocation_context.agent.name = "test_agent" + ctx._invocation_context.agent_states = {} + ctx.agent_name = "test_agent" + return ctx # SkillToolset tests @@ -361,6 +365,10 @@ def _make_tool_context_with_agent(agent=None): ctx = mock.MagicMock(spec=tool_context.ToolContext) ctx._invocation_context = mock.MagicMock() ctx._invocation_context.agent = agent or mock.MagicMock() + ctx._invocation_context.agent.name = "test_agent" + ctx._invocation_context.agent_states = {} + ctx.agent_name = "test_agent" + ctx.state = {} return ctx @@ -1202,3 +1210,65 @@ async def test_execute_script_binary_content_packaged(): assert "b'\\x00\\x01\\x02'" in code_input.code # Wrapper code handles binary with 'wb' mode assert "'wb' if isinstance(content, bytes)" in code_input.code + + +@pytest.mark.asyncio +async def test_skill_toolset_dynamic_tool_resolution(mock_skill1): + # Set up a skill with additional_tools in metadata + mock_skill1.frontmatter.metadata = { + "adk_additional_tools": ["my_custom_tool", "my_func"] + } + mock_skill1.name = "skill1" + + # Prepare additional tools + custom_tool = mock.create_autospec(skill_toolset.BaseTool, instance=True) + custom_tool.name = "my_custom_tool" + + def my_func(): + """My function description.""" + pass + + toolset = skill_toolset.SkillToolset( + [mock_skill1], + additional_tools=[custom_tool, my_func], + ) + + ctx = _make_tool_context_with_agent() + # Initial tools (only 
core) + tools = await toolset.get_tools(readonly_context=ctx) + assert len(tools) == 4 + + # Activate skill + load_tool = skill_toolset.LoadSkillTool(toolset) + await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx) + + # Dynamic tools should now be resolved + tools = await toolset.get_tools(readonly_context=ctx) + tool_names = {t.name for t in tools} + assert "my_custom_tool" in tool_names + assert "my_func" in tool_names + + # Check specific tool resolution details + my_func_tool = next(t for t in tools if t.name == "my_func") + assert isinstance(my_func_tool, skill_toolset.FunctionTool) + assert my_func_tool.description == "My function description." + + +@pytest.mark.asyncio +async def test_skill_toolset_resolution_error_handling(mock_skill1, caplog): + mock_skill1.frontmatter.metadata = { + "adk_additional_tools": ["nonexistent_tool"] + } + mock_skill1.name = "skill1" + toolset = skill_toolset.SkillToolset([mock_skill1]) + ctx = _make_tool_context_with_agent() + + # Activate skill + load_tool = skill_toolset.LoadSkillTool(toolset) + await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx) + + with caplog.at_level(logging.WARNING): + tools = await toolset.get_tools(readonly_context=ctx) + + # Should still return basic skill tools + assert len(tools) == 4