From e6da4172924ecc36ffc2535199c450a2a51c7bcc Mon Sep 17 00:00:00 2001
From: Sasha Sobran
Date: Tue, 10 Feb 2026 14:35:54 -0800
Subject: [PATCH] fix: propagate grounding and citation metadata in streaming responses

Co-authored-by: Sasha Sobran
PiperOrigin-RevId: 868324488
---
NOTE(review): the line structure of this patch was reconstructed from a
whitespace-collapsed copy. In the hunks that modify existing files
(streaming_utils.py, test_streaming_utils.py) the blank context lines and the
indentation of context lines are inferred from the hunk line counts, not
copied from the upstream files -- confirm against upstream, or apply with
`git apply --ignore-whitespace`. Docstrings/comments were added to the two
test files (hunk counts and diffstat updated accordingly), so the post-image
blob hashes on the `index` lines no longer match.

 src/google/adk/utils/streaming_utils.py       |  10 +
 ...st_vertex_ai_search_grounding_streaming.py | 350 ++++++++++++++++++
 tests/unittests/utils/test_streaming_utils.py | 105 ++++++
 3 files changed, 465 insertions(+)
 create mode 100644 tests/integration/test_vertex_ai_search_grounding_streaming.py

diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py
index 09689a4d..808541f2 100644
--- a/src/google/adk/utils/streaming_utils.py
+++ b/src/google/adk/utils/streaming_utils.py
@@ -36,6 +36,8 @@ class StreamingResponseAggregator:
     self._text = ''
     self._thought_text = ''
     self._usage_metadata = None
+    self._grounding_metadata: Optional[types.GroundingMetadata] = None
+    self._citation_metadata: Optional[types.CitationMetadata] = None
     self._response = None
 
     # For progressive SSE streaming mode: accumulate parts in order
@@ -251,6 +253,10 @@ class StreamingResponseAggregator:
     self._response = response
     llm_response = LlmResponse.create(response)
     self._usage_metadata = llm_response.usage_metadata
+    if llm_response.grounding_metadata:
+      self._grounding_metadata = llm_response.grounding_metadata
+    if llm_response.citation_metadata:
+      self._citation_metadata = llm_response.citation_metadata
 
     # ========== Progressive SSE Streaming (new feature) ==========
     # Save finish_reason for final aggregation
@@ -347,6 +353,8 @@ class StreamingResponseAggregator:
 
       return LlmResponse(
           content=types.ModelContent(parts=final_parts),
+          grounding_metadata=self._grounding_metadata,
+          citation_metadata=self._citation_metadata,
           error_code=None
           if finish_reason == types.FinishReason.STOP
           else finish_reason,
@@ -374,6 +382,8 @@ class StreamingResponseAggregator:
       candidate = self._response.candidates[0]
       return LlmResponse(
           content=types.ModelContent(parts=parts),
+          grounding_metadata=self._grounding_metadata,
+          citation_metadata=self._citation_metadata,
           error_code=None
           if candidate.finish_reason == types.FinishReason.STOP
           else candidate.finish_reason,
diff --git a/tests/integration/test_vertex_ai_search_grounding_streaming.py b/tests/integration/test_vertex_ai_search_grounding_streaming.py
new file mode 100644
index 00000000..226fd845
--- /dev/null
+++ b/tests/integration/test_vertex_ai_search_grounding_streaming.py
@@ -0,0 +1,350 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration tests for grounding metadata preservation in SSE streaming.
+
+Verifies that grounding_metadata from VertexAiSearchTool reaches the final
+non-partial event in both progressive and non-progressive SSE streaming modes.
+
+Prerequisites:
+  - GOOGLE_CLOUD_PROJECT env var set to a GCP project with Vertex AI enabled
+  - Discovery Engine API enabled (discoveryengine.googleapis.com)
+  - Authenticated via `gcloud auth application-default login`
+
+Usage:
+  GOOGLE_CLOUD_PROJECT=my-project pytest
+    tests/integration/test_vertex_ai_search_grounding_streaming.py -v -s
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import time
+import uuid
+
+from google.adk.features._feature_registry import FeatureName
+from google.adk.features._feature_registry import temporary_feature_override
+from google.genai import types
+import pytest
+
+_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
+_LOCATION = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
+_COLLECTION = "default_collection"
+# Random 8-hex-char suffix: each import of this module gets its own store ID.
+_DATA_STORE_ID = f"adk-grounding-test-{uuid.uuid4().hex[:8]}"
+_DATA_STORE_DISPLAY_NAME = "ADK Grounding Integration Test"
+_MODEL = "gemini-2.0-flash"
+
+_TEST_DOCUMENTS = (
+    {
+        "id": "doc-adk-overview",
+        "title": "ADK Overview",
+        "content": (
+            "The Agent Development Kit (ADK) is an open-source framework by"
+            " Google for building AI agents. ADK supports multi-agent"
+            " architectures, tool use, and integrates with Gemini models."
+            " ADK was first released in April 2025."
+        ),
+    },
+    {
+        "id": "doc-adk-tools",
+        "title": "ADK Built-in Tools",
+        "content": (
+            "ADK provides built-in tools including VertexAiSearchTool for"
+            " grounded search, GoogleSearchTool for web search, and"
+            " CodeExecutionTool for running code. The VertexAiSearchTool"
+            " returns grounding metadata with citations pointing to source"
+            " documents."
+        ),
+    },
+)
+
+
+def _parent_path() -> str:
+  """Returns the collection path for the configured project and location."""
+  return f"projects/{_PROJECT}/locations/{_LOCATION}/collections/{_COLLECTION}"
+
+
+def _data_store_path() -> str:
+  """Returns the resource path of this run's test data store."""
+  return f"{_parent_path()}/dataStores/{_DATA_STORE_ID}"
+
+
+@pytest.fixture(scope="module")
+def project_id():
+  """Returns GOOGLE_CLOUD_PROJECT, skipping dependent tests if it is unset."""
+  if not _PROJECT:
+    pytest.skip("GOOGLE_CLOUD_PROJECT env var not set")
+  return _PROJECT
+
+
+@pytest.fixture(scope="module")
+def data_store_resource(project_id) -> str:
+  """Create a Vertex AI Search data store with test documents.
+
+  Yields the data store path, then deletes the store (best-effort) on teardown.
+  """
+  from google.api_core.exceptions import AlreadyExists
+  from google.cloud import discoveryengine_v1beta as discoveryengine
+
+  ds_client = discoveryengine.DataStoreServiceClient()
+  doc_client = discoveryengine.DocumentServiceClient()
+
+  # Create data store
+  try:
+    request = discoveryengine.CreateDataStoreRequest(
+        parent=_parent_path(),
+        data_store=discoveryengine.DataStore(
+            display_name=_DATA_STORE_DISPLAY_NAME,
+            industry_vertical=discoveryengine.IndustryVertical.GENERIC,
+            solution_types=[discoveryengine.SolutionType.SOLUTION_TYPE_SEARCH],
+            content_config=discoveryengine.DataStore.ContentConfig.NO_CONTENT,
+        ),
+        data_store_id=_DATA_STORE_ID,
+    )
+    operation = ds_client.create_data_store(request=request)
+    print(f"\nCreating data store '{_DATA_STORE_ID}'...")
+    operation.result(timeout=120)
+    print("Data store created.")
+  except AlreadyExists:
+    print(f"\nData store '{_DATA_STORE_ID}' already exists, reusing.")
+
+  # Ingest test documents
+  branch = f"{_data_store_path()}/branches/default_branch"
+  for doc_data in _TEST_DOCUMENTS:
+    json_data = json.dumps({
+        "title": doc_data["title"],
+        "description": doc_data["content"],
+    })
+    doc = discoveryengine.Document(
+        id=doc_data["id"],
+        json_data=json_data,
+    )
+    try:
+      doc_client.create_document(
+          parent=branch,
+          document=doc,
+          document_id=doc_data["id"],
+      )
+      print(f" Created document: {doc_data['id']}")
+    except AlreadyExists:
+      doc_client.update_document(
+          document=discoveryengine.Document(
+              name=f"{branch}/documents/{doc_data['id']}",
+              json_data=json_data,
+          ),
+      )
+      print(f" Updated document: {doc_data['id']}")
+
+  # NOTE(review): fixed pause, no readiness check; flaky if indexing is slow.
+  print("Waiting 5s for indexing...")
+  time.sleep(5)
+
+  yield _data_store_path()
+
+  # Cleanup — best-effort, ignore errors from Discovery Engine LRO
+  try:
+    operation = ds_client.delete_data_store(name=_data_store_path())
+    operation.result(timeout=120)
+    print(f"\nDeleted data store '{_DATA_STORE_ID}'.")
+  except Exception as e:
+    print(f"\nFailed to delete data store '{_DATA_STORE_ID}': {e}")
+
+
+class TestIntegrationVertexAiSearchGrounding:
+  """Integration tests hitting real Vertex AI with VertexAiSearchTool."""
+
+  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
+  @pytest.mark.parametrize(
+      "progressive_sse, label",
+      [
+          (True, "Progressive SSE"),
+          (False, "Non-Progressive SSE"),
+      ],
+  )
+  @pytest.mark.asyncio
+  async def test_grounding_metadata_with_sse_streaming(
+      self, project_id, data_store_resource, progressive_sse, label
+  ):
+    """Verifies grounding_metadata in SSE streaming modes."""
+    from google.adk.agents.llm_agent import LlmAgent
+    from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
+
+    agent = LlmAgent(
+        name="test_agent",
+        model=_MODEL,
+        tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
+        instruction="Answer questions using the search tool.",
+    )
+
+    with temporary_feature_override(
+        FeatureName.PROGRESSIVE_SSE_STREAMING, progressive_sse
+    ):
+      all_events, saved_events = await self._run_agent_streaming(
+          agent, project_id
+      )
+
+    self._report_events(label, all_events, saved_events)
+
+    saved_with_grounding = [e for e in saved_events if e["has_grounding"]]
+    assert (
+        saved_with_grounding
+    ), f"No saved (non-partial) events have grounding_metadata with {label}."
+
+  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
+  @pytest.mark.asyncio
+  async def test_grounding_metadata_without_streaming(
+      self, project_id, data_store_resource
+  ):
+    """Without streaming, grounding_metadata should always be present."""
+    from google.adk.agents.llm_agent import LlmAgent
+    from google.adk.agents.run_config import RunConfig
+    from google.adk.agents.run_config import StreamingMode
+    from google.adk.runners import Runner
+    from google.adk.sessions.in_memory_session_service import InMemorySessionService
+    from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
+    from google.adk.utils.context_utils import Aclosing
+
+    agent = LlmAgent(
+        name="test_agent",
+        model=_MODEL,
+        tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
+        instruction="Answer questions using the search tool.",
+    )
+
+    session_service = InMemorySessionService()
+    runner = Runner(
+        app_name="test_app",
+        agent=agent,
+        session_service=session_service,
+    )
+    session = await session_service.create_session(
+        app_name="test_app", user_id="test_user"
+    )
+
+    run_config = RunConfig(streaming_mode=StreamingMode.NONE)
+    events = []
+    async with Aclosing(
+        runner.run_async(
+            user_id="test_user",
+            session_id=session.id,
+            new_message=types.Content(
+                role="user",
+                parts=[
+                    types.Part.from_text(
+                        text="What built-in tools does ADK provide?"
+                    )
+                ],
+            ),
+            run_config=run_config,
+        )
+    ) as agen:
+      async for event in agen:
+        events.append({
+            "author": event.author,
+            "partial": event.partial,
+            "has_grounding": event.grounding_metadata is not None,
+            "has_content": bool(event.content and event.content.parts),
+        })
+
+    print("\n=== No Streaming ===")
+    for i, e in enumerate(events):
+      print(
+          f" Event {i}: author={e['author']}, partial={e['partial']},"
+          f" grounding={e['has_grounding']}, content={e['has_content']}"
+      )
+
+    model_events = [e for e in events if e["author"] == "test_agent"]
+    with_grounding = [e for e in model_events if e["has_grounding"]]
+    assert (
+        with_grounding
+    ), "No events have grounding_metadata even without streaming."
+
+  async def _run_agent_streaming(self, agent, project_id):
+    """Runs `agent` once in SSE streaming mode and summarizes each event.
+
+    Returns:
+      A pair (all_events, saved_events) of dict summaries: every event, and
+      those whose `partial` is not True. `project_id` is currently unused.
+    """
+    from google.adk.agents.run_config import RunConfig
+    from google.adk.agents.run_config import StreamingMode
+    from google.adk.runners import Runner
+    from google.adk.sessions.in_memory_session_service import InMemorySessionService
+    from google.adk.utils.context_utils import Aclosing
+
+    session_service = InMemorySessionService()
+    runner = Runner(
+        app_name="test_app",
+        agent=agent,
+        session_service=session_service,
+    )
+    session = await session_service.create_session(
+        app_name="test_app", user_id="test_user"
+    )
+
+    run_config = RunConfig(streaming_mode=StreamingMode.SSE)
+    all_events = []
+    async with Aclosing(
+        runner.run_async(
+            user_id="test_user",
+            session_id=session.id,
+            new_message=types.Content(
+                role="user",
+                parts=[
+                    types.Part.from_text(
+                        text="What is ADK and when was it first released?"
+                    )
+                ],
+            ),
+            run_config=run_config,
+        )
+    ) as agen:
+      async for event in agen:
+        all_events.append({
+            "author": event.author,
+            "partial": event.partial,
+            "has_grounding": event.grounding_metadata is not None,
+            "has_content": bool(event.content and event.content.parts),
+        })
+
+    # Treats both partial=False and partial=None as saved (non-partial).
+    saved_events = [e for e in all_events if e["partial"] is not True]
+    return all_events, saved_events
+
+  def _report_events(self, label, all_events, saved_events):
+    """Prints both event lists and flags partial events that had grounding."""
+    print(f"\n=== {label} — All Events ===")
+    for i, e in enumerate(all_events):
+      print(
+          f" Event {i}: author={e['author']}, partial={e['partial']},"
+          f" grounding={e['has_grounding']},"
+          f" content={e['has_content']}"
+      )
+    print(f"\n=== {label} — Saved (non-partial) Events ===")
+    for i, e in enumerate(saved_events):
+      print(
+          f" Event {i}: author={e['author']}, partial={e['partial']},"
+          f" grounding={e['has_grounding']},"
+          f" content={e['has_content']}"
+      )
+    partial_with_grounding = [
+        e for e in all_events if e["partial"] is True and e["has_grounding"]
+    ]
+    if partial_with_grounding:
+      print(
+          f"\n NOTE: {len(partial_with_grounding)} partial event(s)"
+          " had grounding_metadata but were NOT saved to session."
+      )
diff --git a/tests/unittests/utils/test_streaming_utils.py b/tests/unittests/utils/test_streaming_utils.py
index 3fd30469..c814ecd5 100644
--- a/tests/unittests/utils/test_streaming_utils.py
+++ b/tests/unittests/utils/test_streaming_utils.py
@@ -14,6 +14,8 @@
 
 from __future__ import annotations
 
+from google.adk.features._feature_registry import FeatureName
+from google.adk.features._feature_registry import temporary_feature_override
 from google.adk.utils import streaming_utils
 from google.genai import types
 import pytest
@@ -200,3 +202,106 @@ class TestStreamingResponseAggregator:
 
     closed_response = aggregator.close()
     assert closed_response is None
+
+  @pytest.mark.asyncio
+  @pytest.mark.parametrize(
+      "test_id, use_progressive_sse, metadata_type",
+      [
+          ("grounding_default", False, "grounding"),
+          ("grounding_progressive", True, "grounding"),
+          ("citation_default", False, "citation"),
+          ("citation_progressive", True, "citation"),
+      ],
+  )
+  async def test_close_preserves_metadata(
+      self, test_id, use_progressive_sse, metadata_type
+  ):
+    """close() should carry metadata into the aggregated response."""
+    aggregator = streaming_utils.StreamingResponseAggregator()
+
+    metadata = None
+    response1 = None
+    response2 = None
+
+    if metadata_type == "grounding":
+      metadata = types.GroundingMetadata(
+          grounding_chunks=[
+              types.GroundingChunk(
+                  retrieved_context=types.GroundingChunkRetrievedContext(
+                      uri="https://example.com/doc1",
+                      title="Source",
+                  )
+              )
+          ],
+      )
+      response1 = types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=types.Content(parts=[types.Part(text="Hello ")]),
+                  grounding_metadata=metadata,
+              )
+          ]
+      )
+      response2 = types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=types.Content(parts=[types.Part(text="World!")]),
+                  finish_reason=types.FinishReason.STOP,
+                  grounding_metadata=metadata,
+              )
+          ]
+      )
+    elif metadata_type == "citation":
+      metadata = types.CitationMetadata(
+          citations=[
+              types.Citation(
+                  start_index=0,
+                  end_index=10,
+                  uri="https://example.com/source",
+                  title="Source",
+              )
+          ]
+      )
+      response1 = types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=types.Content(parts=[types.Part(text="Cited text")]),
+              )
+          ]
+      )
+      response2 = types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=types.Content(parts=[]),
+                  finish_reason=types.FinishReason.STOP,
+                  citation_metadata=metadata,
+              )
+          ]
+      )
+
+    async def run_test():
+      async for _ in aggregator.process_response(response1):
+        pass
+      async for _ in aggregator.process_response(response2):
+        pass
+
+      closed_response = aggregator.close()
+      assert closed_response is not None
+      if use_progressive_sse:
+        assert closed_response.partial is False
+
+      if metadata_type == "grounding":
+        assert closed_response.grounding_metadata is not None
+        assert len(closed_response.grounding_metadata.grounding_chunks) == 1
+      elif metadata_type == "citation":
+        assert closed_response.citation_metadata is not None
+        assert len(closed_response.citation_metadata.citations) == 1
+
+    # The "*_default" cases run under the ambient feature flag (no override).
+    if use_progressive_sse:
+      with temporary_feature_override(
+          FeatureName.PROGRESSIVE_SSE_STREAMING, True
+      ):
+        await run_test()
+    else:
+      await run_test()