fix: propagate grounding and citation metadata in streaming responses

Co-authored-by: Sasha Sobran <asobran@google.com>
PiperOrigin-RevId: 868324488
This commit is contained in:
Sasha Sobran
2026-02-10 14:35:54 -08:00
committed by Copybara-Service
parent 6ee5126d1c
commit e6da417292
3 changed files with 448 additions and 0 deletions
+10
View File
@@ -36,6 +36,8 @@ class StreamingResponseAggregator:
self._text = ''
self._thought_text = ''
self._usage_metadata = None
self._grounding_metadata: Optional[types.GroundingMetadata] = None
self._citation_metadata: Optional[types.CitationMetadata] = None
self._response = None
# For progressive SSE streaming mode: accumulate parts in order
@@ -251,6 +253,10 @@ class StreamingResponseAggregator:
self._response = response
llm_response = LlmResponse.create(response)
self._usage_metadata = llm_response.usage_metadata
if llm_response.grounding_metadata:
self._grounding_metadata = llm_response.grounding_metadata
if llm_response.citation_metadata:
self._citation_metadata = llm_response.citation_metadata
# ========== Progressive SSE Streaming (new feature) ==========
# Save finish_reason for final aggregation
@@ -347,6 +353,8 @@ class StreamingResponseAggregator:
return LlmResponse(
content=types.ModelContent(parts=final_parts),
grounding_metadata=self._grounding_metadata,
citation_metadata=self._citation_metadata,
error_code=None
if finish_reason == types.FinishReason.STOP
else finish_reason,
@@ -374,6 +382,8 @@ class StreamingResponseAggregator:
candidate = self._response.candidates[0]
return LlmResponse(
content=types.ModelContent(parts=parts),
grounding_metadata=self._grounding_metadata,
citation_metadata=self._citation_metadata,
error_code=None
if candidate.finish_reason == types.FinishReason.STOP
else candidate.finish_reason,
@@ -0,0 +1,334 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for grounding metadata preservation in SSE streaming.
Verifies that grounding_metadata from VertexAiSearchTool reaches the final
non-partial event in both progressive and non-progressive SSE streaming modes.
Prerequisites:
- GOOGLE_CLOUD_PROJECT env var set to a GCP project with Vertex AI enabled
- Discovery Engine API enabled (discoveryengine.googleapis.com)
- Authenticated via `gcloud auth application-default login`
Usage:
GOOGLE_CLOUD_PROJECT=my-project pytest
tests/integration/test_vertex_ai_search_grounding_streaming.py -v -s
"""
from __future__ import annotations
import json
import os
import time
import uuid
from google.adk.features._feature_registry import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.genai import types
import pytest
# GCP project and location are read from the environment once, at import time.
# An empty project string is what the `project_id` fixture checks to skip.
_PROJECT = os.getenv("GOOGLE_CLOUD_PROJECT", "")
_LOCATION = os.getenv("GOOGLE_CLOUD_LOCATION", "global")
_COLLECTION = "default_collection"
# A random 8-hex-digit suffix gives every test session its own data store id.
_DATA_STORE_ID = "adk-grounding-test-" + uuid.uuid4().hex[:8]
_DATA_STORE_DISPLAY_NAME = "ADK Grounding Integration Test"
_MODEL = "gemini-2.0-flash"

# Documents ingested into the data store; the test prompts ask about them.
_TEST_DOCUMENTS = (
    dict(
        id="doc-adk-overview",
        title="ADK Overview",
        content=(
            "The Agent Development Kit (ADK) is an open-source framework by"
            " Google for building AI agents. ADK supports multi-agent"
            " architectures, tool use, and integrates with Gemini models."
            " ADK was first released in April 2025."
        ),
    ),
    dict(
        id="doc-adk-tools",
        title="ADK Built-in Tools",
        content=(
            "ADK provides built-in tools including VertexAiSearchTool for"
            " grounded search, GoogleSearchTool for web search, and"
            " CodeExecutionTool for running code. The VertexAiSearchTool"
            " returns grounding metadata with citations pointing to source"
            " documents."
        ),
    ),
)
def _parent_path() -> str:
  """Return the collection resource path built from the module constants.

  The result has the form
  `projects/<_PROJECT>/locations/<_LOCATION>/collections/<_COLLECTION>`.
  """
  segments = (
      "projects",
      _PROJECT,
      "locations",
      _LOCATION,
      "collections",
      _COLLECTION,
  )
  return "/".join(segments)
def _data_store_path() -> str:
  """Return the test data store's resource path under `_parent_path()`."""
  return "/".join((_parent_path(), "dataStores", _DATA_STORE_ID))
@pytest.fixture(scope="module")
def project_id():
  """Provide the configured GCP project id, skipping when none is set."""
  if _PROJECT:
    return _PROJECT
  # pytest.skip raises, so nothing is returned on this path.
  pytest.skip("GOOGLE_CLOUD_PROJECT env var not set")
@pytest.fixture(scope="module")
def data_store_resource(project_id) -> str:
  """Create a Vertex AI Search data store with test documents.

  This is a generator fixture: it creates the data store, ingests
  `_TEST_DOCUMENTS`, yields the data store resource path to the tests, and
  finally deletes the data store.

  Setup and the `yield` are wrapped in `try/finally` so that the data store is
  also deleted when setup fails part-way (for example while ingesting
  documents). Without that, a failed setup would leave the uniquely named data
  store behind in the project, because code after `yield` never runs if the
  fixture raises before yielding.

  Args:
    project_id: Requested only for its side effect of skipping the module when
      no GCP project is configured.

  Yields:
    The full resource path of the data store (see `_data_store_path`).
  """
  from google.api_core.exceptions import AlreadyExists
  from google.cloud import discoveryengine_v1beta as discoveryengine

  ds_client = discoveryengine.DataStoreServiceClient()
  doc_client = discoveryengine.DocumentServiceClient()

  try:
    # Create data store
    try:
      request = discoveryengine.CreateDataStoreRequest(
          parent=_parent_path(),
          data_store=discoveryengine.DataStore(
              display_name=_DATA_STORE_DISPLAY_NAME,
              industry_vertical=discoveryengine.IndustryVertical.GENERIC,
              solution_types=[discoveryengine.SolutionType.SOLUTION_TYPE_SEARCH],
              content_config=discoveryengine.DataStore.ContentConfig.NO_CONTENT,
          ),
          data_store_id=_DATA_STORE_ID,
      )
      operation = ds_client.create_data_store(request=request)
      print(f"\nCreating data store '{_DATA_STORE_ID}'...")
      # Block until the long-running create operation finishes.
      operation.result(timeout=120)
      print("Data store created.")
    except AlreadyExists:
      print(f"\nData store '{_DATA_STORE_ID}' already exists, reusing.")

    # Ingest test documents
    branch = f"{_data_store_path()}/branches/default_branch"
    for doc_data in _TEST_DOCUMENTS:
      json_data = json.dumps({
          "title": doc_data["title"],
          "description": doc_data["content"],
      })
      doc = discoveryengine.Document(
          id=doc_data["id"],
          json_data=json_data,
      )
      try:
        doc_client.create_document(
            parent=branch,
            document=doc,
            document_id=doc_data["id"],
        )
        print(f" Created document: {doc_data['id']}")
      except AlreadyExists:
        # The document is already present: overwrite it with the test payload.
        doc_client.update_document(
            document=discoveryengine.Document(
                name=f"{branch}/documents/{doc_data['id']}",
                json_data=json_data,
            ),
        )
        print(f" Updated document: {doc_data['id']}")

    print("Waiting 5s for indexing...")
    time.sleep(5)

    yield _data_store_path()
  finally:
    # Cleanup — best-effort, ignore errors from Discovery Engine LRO. Runs on
    # normal teardown and when setup above raised, so nothing is leaked.
    try:
      operation = ds_client.delete_data_store(name=_data_store_path())
      operation.result(timeout=120)
      print(f"\nDeleted data store '{_DATA_STORE_ID}'.")
    except Exception as e:
      print(f"\nFailed to delete data store '{_DATA_STORE_ID}': {e}")
class TestIntegrationVertexAiSearchGrounding:
  """Integration tests hitting real Vertex AI with VertexAiSearchTool.

  Both tests run the same agent through a `Runner` and record, for each
  emitted event, whether it is partial and whether it carries
  grounding_metadata. The shared steps (agent construction, running the agent
  and collecting events, printing events) live in private helpers so the
  streaming and non-streaming tests cannot drift apart.
  """

  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
  @pytest.mark.parametrize(
      "progressive_sse, label",
      [
          (True, "Progressive SSE"),
          (False, "Non-Progressive SSE"),
      ],
  )
  @pytest.mark.asyncio
  async def test_grounding_metadata_with_sse_streaming(
      self, project_id, data_store_resource, progressive_sse, label
  ):
    """Verifies grounding_metadata in SSE streaming modes."""
    agent = self._build_agent(data_store_resource)

    with temporary_feature_override(
        FeatureName.PROGRESSIVE_SSE_STREAMING, progressive_sse
    ):
      all_events, saved_events = await self._run_agent_streaming(
          agent, project_id
      )

    self._report_events(label, all_events, saved_events)

    saved_with_grounding = [e for e in saved_events if e["has_grounding"]]
    assert (
        saved_with_grounding
    ), f"No saved (non-partial) events have grounding_metadata with {label}."

  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
  @pytest.mark.asyncio
  async def test_grounding_metadata_without_streaming(
      self, project_id, data_store_resource
  ):
    """Without streaming, grounding_metadata should always be present."""
    from google.adk.agents.run_config import StreamingMode

    agent = self._build_agent(data_store_resource)
    events = await self._collect_events(
        agent,
        StreamingMode.NONE,
        "What built-in tools does ADK provide?",
    )

    print("\n=== No Streaming ===")
    self._print_events(events)

    # Only events authored by the agent under test are considered.
    model_events = [e for e in events if e["author"] == "test_agent"]
    with_grounding = [e for e in model_events if e["has_grounding"]]
    assert (
        with_grounding
    ), "No events have grounding_metadata even without streaming."

  async def _run_agent_streaming(self, agent, project_id):
    """Run `agent` in SSE streaming mode and split out non-partial events.

    Args:
      agent: The agent to run.
      project_id: Unused here; kept so existing callers keep working.

    Returns:
      A `(all_events, saved_events)` tuple of event-summary dict lists, where
      `saved_events` holds the entries whose `partial` flag is not `True`.
    """
    from google.adk.agents.run_config import StreamingMode

    all_events = await self._collect_events(
        agent,
        StreamingMode.SSE,
        "What is ADK and when was it first released?",
    )
    saved_events = [e for e in all_events if e["partial"] is not True]
    return all_events, saved_events

  def _report_events(self, label, all_events, saved_events):
    """Print all events, the non-partial events, and a note on dropped ones."""
    print(f"\n=== {label} — All Events ===")
    self._print_events(all_events)

    print(f"\n=== {label} — Saved (non-partial) Events ===")
    self._print_events(saved_events)

    partial_with_grounding = [
        e for e in all_events if e["partial"] is True and e["has_grounding"]
    ]
    if partial_with_grounding:
      print(
          f"\n NOTE: {len(partial_with_grounding)} partial event(s)"
          " had grounding_metadata but were NOT saved to session."
      )

  @staticmethod
  def _build_agent(data_store_resource):
    """Build the LlmAgent under test, wired to the given data store."""
    from google.adk.agents.llm_agent import LlmAgent
    from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool

    return LlmAgent(
        name="test_agent",
        model=_MODEL,
        tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
        instruction="Answer questions using the search tool.",
    )

  @staticmethod
  async def _collect_events(agent, streaming_mode, prompt):
    """Run `agent` once on `prompt` and summarize every emitted event.

    Args:
      agent: The agent to run.
      streaming_mode: The `StreamingMode` placed in the `RunConfig`.
      prompt: Text of the single user message sent to the agent.

    Returns:
      A list of dicts with keys `author`, `partial`, `has_grounding` and
      `has_content`, one per event, in emission order.
    """
    from google.adk.agents.run_config import RunConfig
    from google.adk.runners import Runner
    from google.adk.sessions.in_memory_session_service import InMemorySessionService
    from google.adk.utils.context_utils import Aclosing

    session_service = InMemorySessionService()
    runner = Runner(
        app_name="test_app",
        agent=agent,
        session_service=session_service,
    )
    session = await session_service.create_session(
        app_name="test_app", user_id="test_user"
    )
    run_config = RunConfig(streaming_mode=streaming_mode)

    events = []
    async with Aclosing(
        runner.run_async(
            user_id="test_user",
            session_id=session.id,
            new_message=types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            ),
            run_config=run_config,
        )
    ) as agen:
      async for event in agen:
        events.append({
            "author": event.author,
            "partial": event.partial,
            "has_grounding": event.grounding_metadata is not None,
            "has_content": bool(event.content and event.content.parts),
        })
    return events

  @staticmethod
  def _print_events(events):
    """Print one summary line per event dict."""
    for i, e in enumerate(events):
      print(
          f" Event {i}: author={e['author']}, partial={e['partial']},"
          f" grounding={e['has_grounding']},"
          f" content={e['has_content']}"
      )
@@ -14,6 +14,8 @@
from __future__ import annotations
from google.adk.features._feature_registry import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.utils import streaming_utils
from google.genai import types
import pytest
@@ -200,3 +202,105 @@ class TestStreamingResponseAggregator:
closed_response = aggregator.close()
assert closed_response is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_id, use_progressive_sse, metadata_type",
    [
        ("grounding_default", False, "grounding"),
        ("grounding_progressive", True, "grounding"),
        ("citation_default", False, "citation"),
        ("citation_progressive", True, "citation"),
    ],
)
async def test_close_preserves_metadata(
    self, test_id, use_progressive_sse, metadata_type
):
  """close() should carry metadata into the aggregated response.

  Two streamed chunks are fed to the aggregator; the metadata under test is
  attached to the candidate(s) and must still be present on the response
  returned by close(). `test_id` only labels the parametrized case.
  """
  aggregator = streaming_utils.StreamingResponseAggregator()

  def single_candidate_response(parts, **candidate_fields):
    # Wrap `parts` in a one-candidate response, forwarding any extra
    # Candidate fields (finish_reason, metadata) exactly as given.
    return types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=parts),
                **candidate_fields,
            )
        ]
    )

  first_chunk = None
  final_chunk = None

  if metadata_type == "grounding":
    grounding = types.GroundingMetadata(
        grounding_chunks=[
            types.GroundingChunk(
                retrieved_context=types.GroundingChunkRetrievedContext(
                    uri="https://example.com/doc1",
                    title="Source",
                )
            )
        ],
    )
    # Grounding metadata is attached to both chunks.
    first_chunk = single_candidate_response(
        [types.Part(text="Hello ")],
        grounding_metadata=grounding,
    )
    final_chunk = single_candidate_response(
        [types.Part(text="World!")],
        finish_reason=types.FinishReason.STOP,
        grounding_metadata=grounding,
    )
  elif metadata_type == "citation":
    citation = types.CitationMetadata(
        citations=[
            types.Citation(
                start_index=0,
                end_index=10,
                uri="https://example.com/source",
                title="Source",
            )
        ]
    )
    # Citation metadata arrives only on the final, content-less chunk.
    first_chunk = single_candidate_response([types.Part(text="Cited text")])
    final_chunk = single_candidate_response(
        [],
        finish_reason=types.FinishReason.STOP,
        citation_metadata=citation,
    )

  async def feed_and_verify():
    for chunk in (first_chunk, final_chunk):
      async for _ in aggregator.process_response(chunk):
        pass

    aggregated = aggregator.close()
    assert aggregated is not None
    if use_progressive_sse:
      assert aggregated.partial is False

    if metadata_type == "grounding":
      assert aggregated.grounding_metadata is not None
      assert len(aggregated.grounding_metadata.grounding_chunks) == 1
    elif metadata_type == "citation":
      assert aggregated.citation_metadata is not None
      assert len(aggregated.citation_metadata.citations) == 1

  if not use_progressive_sse:
    # Default case: run with whatever the feature flag currently is.
    await feed_and_verify()
    return

  with temporary_feature_override(FeatureName.PROGRESSIVE_SSE_STREAMING, True):
    await feed_and_verify()