You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: propagate grounding and citation metadata in streaming responses
Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 868324488
This commit is contained in:
committed by
Copybara-Service
parent
6ee5126d1c
commit
e6da417292
@@ -36,6 +36,8 @@ class StreamingResponseAggregator:
|
||||
self._text = ''
|
||||
self._thought_text = ''
|
||||
self._usage_metadata = None
|
||||
self._grounding_metadata: Optional[types.GroundingMetadata] = None
|
||||
self._citation_metadata: Optional[types.CitationMetadata] = None
|
||||
self._response = None
|
||||
|
||||
# For progressive SSE streaming mode: accumulate parts in order
|
||||
@@ -251,6 +253,10 @@ class StreamingResponseAggregator:
|
||||
self._response = response
|
||||
llm_response = LlmResponse.create(response)
|
||||
self._usage_metadata = llm_response.usage_metadata
|
||||
if llm_response.grounding_metadata:
|
||||
self._grounding_metadata = llm_response.grounding_metadata
|
||||
if llm_response.citation_metadata:
|
||||
self._citation_metadata = llm_response.citation_metadata
|
||||
|
||||
# ========== Progressive SSE Streaming (new feature) ==========
|
||||
# Save finish_reason for final aggregation
|
||||
@@ -347,6 +353,8 @@ class StreamingResponseAggregator:
|
||||
|
||||
return LlmResponse(
|
||||
content=types.ModelContent(parts=final_parts),
|
||||
grounding_metadata=self._grounding_metadata,
|
||||
citation_metadata=self._citation_metadata,
|
||||
error_code=None
|
||||
if finish_reason == types.FinishReason.STOP
|
||||
else finish_reason,
|
||||
@@ -374,6 +382,8 @@ class StreamingResponseAggregator:
|
||||
candidate = self._response.candidates[0]
|
||||
return LlmResponse(
|
||||
content=types.ModelContent(parts=parts),
|
||||
grounding_metadata=self._grounding_metadata,
|
||||
citation_metadata=self._citation_metadata,
|
||||
error_code=None
|
||||
if candidate.finish_reason == types.FinishReason.STOP
|
||||
else candidate.finish_reason,
|
||||
|
||||
@@ -0,0 +1,334 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for grounding metadata preservation in SSE streaming.
|
||||
|
||||
Verifies that grounding_metadata from VertexAiSearchTool reaches the final
|
||||
non-partial event in both progressive and non-progressive SSE streaming modes.
|
||||
|
||||
Prerequisites:
|
||||
- GOOGLE_CLOUD_PROJECT env var set to a GCP project with Vertex AI enabled
|
||||
- Discovery Engine API enabled (discoveryengine.googleapis.com)
|
||||
- Authenticated via `gcloud auth application-default login`
|
||||
|
||||
Usage:
|
||||
GOOGLE_CLOUD_PROJECT=my-project pytest
|
||||
tests/integration/test_vertex_ai_search_grounding_streaming.py -v -s
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from google.adk.features._feature_registry import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
|
||||
_LOCATION = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
|
||||
_COLLECTION = "default_collection"
|
||||
_DATA_STORE_ID = f"adk-grounding-test-{uuid.uuid4().hex[:8]}"
|
||||
_DATA_STORE_DISPLAY_NAME = "ADK Grounding Integration Test"
|
||||
_MODEL = "gemini-2.0-flash"
|
||||
|
||||
_TEST_DOCUMENTS = (
|
||||
{
|
||||
"id": "doc-adk-overview",
|
||||
"title": "ADK Overview",
|
||||
"content": (
|
||||
"The Agent Development Kit (ADK) is an open-source framework by"
|
||||
" Google for building AI agents. ADK supports multi-agent"
|
||||
" architectures, tool use, and integrates with Gemini models."
|
||||
" ADK was first released in April 2025."
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "doc-adk-tools",
|
||||
"title": "ADK Built-in Tools",
|
||||
"content": (
|
||||
"ADK provides built-in tools including VertexAiSearchTool for"
|
||||
" grounded search, GoogleSearchTool for web search, and"
|
||||
" CodeExecutionTool for running code. The VertexAiSearchTool"
|
||||
" returns grounding metadata with citations pointing to source"
|
||||
" documents."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parent_path() -> str:
|
||||
return f"projects/{_PROJECT}/locations/{_LOCATION}/collections/{_COLLECTION}"
|
||||
|
||||
|
||||
def _data_store_path() -> str:
|
||||
return f"{_parent_path()}/dataStores/{_DATA_STORE_ID}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def project_id():
|
||||
if not _PROJECT:
|
||||
pytest.skip("GOOGLE_CLOUD_PROJECT env var not set")
|
||||
return _PROJECT
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data_store_resource(project_id) -> str:
|
||||
"""Create a Vertex AI Search data store with test documents."""
|
||||
from google.api_core.exceptions import AlreadyExists
|
||||
from google.cloud import discoveryengine_v1beta as discoveryengine
|
||||
|
||||
ds_client = discoveryengine.DataStoreServiceClient()
|
||||
doc_client = discoveryengine.DocumentServiceClient()
|
||||
|
||||
# Create data store
|
||||
try:
|
||||
request = discoveryengine.CreateDataStoreRequest(
|
||||
parent=_parent_path(),
|
||||
data_store=discoveryengine.DataStore(
|
||||
display_name=_DATA_STORE_DISPLAY_NAME,
|
||||
industry_vertical=discoveryengine.IndustryVertical.GENERIC,
|
||||
solution_types=[discoveryengine.SolutionType.SOLUTION_TYPE_SEARCH],
|
||||
content_config=discoveryengine.DataStore.ContentConfig.NO_CONTENT,
|
||||
),
|
||||
data_store_id=_DATA_STORE_ID,
|
||||
)
|
||||
operation = ds_client.create_data_store(request=request)
|
||||
print(f"\nCreating data store '{_DATA_STORE_ID}'...")
|
||||
operation.result(timeout=120)
|
||||
print("Data store created.")
|
||||
except AlreadyExists:
|
||||
print(f"\nData store '{_DATA_STORE_ID}' already exists, reusing.")
|
||||
|
||||
# Ingest test documents
|
||||
branch = f"{_data_store_path()}/branches/default_branch"
|
||||
for doc_data in _TEST_DOCUMENTS:
|
||||
json_data = json.dumps({
|
||||
"title": doc_data["title"],
|
||||
"description": doc_data["content"],
|
||||
})
|
||||
doc = discoveryengine.Document(
|
||||
id=doc_data["id"],
|
||||
json_data=json_data,
|
||||
)
|
||||
try:
|
||||
doc_client.create_document(
|
||||
parent=branch,
|
||||
document=doc,
|
||||
document_id=doc_data["id"],
|
||||
)
|
||||
print(f" Created document: {doc_data['id']}")
|
||||
except AlreadyExists:
|
||||
doc_client.update_document(
|
||||
document=discoveryengine.Document(
|
||||
name=f"{branch}/documents/{doc_data['id']}",
|
||||
json_data=json_data,
|
||||
),
|
||||
)
|
||||
print(f" Updated document: {doc_data['id']}")
|
||||
|
||||
print("Waiting 5s for indexing...")
|
||||
time.sleep(5)
|
||||
|
||||
yield _data_store_path()
|
||||
|
||||
# Cleanup — best-effort, ignore errors from Discovery Engine LRO
|
||||
try:
|
||||
operation = ds_client.delete_data_store(name=_data_store_path())
|
||||
operation.result(timeout=120)
|
||||
print(f"\nDeleted data store '{_DATA_STORE_ID}'.")
|
||||
except Exception as e:
|
||||
print(f"\nFailed to delete data store '{_DATA_STORE_ID}': {e}")
|
||||
|
||||
|
||||
class TestIntegrationVertexAiSearchGrounding:
|
||||
"""Integration tests hitting real Vertex AI with VertexAiSearchTool."""
|
||||
|
||||
@pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
|
||||
@pytest.mark.parametrize(
|
||||
"progressive_sse, label",
|
||||
[
|
||||
(True, "Progressive SSE"),
|
||||
(False, "Non-Progressive SSE"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_metadata_with_sse_streaming(
|
||||
self, project_id, data_store_resource, progressive_sse, label
|
||||
):
|
||||
"""Verifies grounding_metadata in SSE streaming modes."""
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
|
||||
|
||||
agent = LlmAgent(
|
||||
name="test_agent",
|
||||
model=_MODEL,
|
||||
tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
|
||||
instruction="Answer questions using the search tool.",
|
||||
)
|
||||
|
||||
with temporary_feature_override(
|
||||
FeatureName.PROGRESSIVE_SSE_STREAMING, progressive_sse
|
||||
):
|
||||
all_events, saved_events = await self._run_agent_streaming(
|
||||
agent, project_id
|
||||
)
|
||||
|
||||
self._report_events(label, all_events, saved_events)
|
||||
|
||||
saved_with_grounding = [e for e in saved_events if e["has_grounding"]]
|
||||
assert (
|
||||
saved_with_grounding
|
||||
), f"No saved (non-partial) events have grounding_metadata with {label}."
|
||||
|
||||
@pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_metadata_without_streaming(
|
||||
self, project_id, data_store_resource
|
||||
):
|
||||
"""Without streaming, grounding_metadata should always be present."""
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.agents.run_config import StreamingMode
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
|
||||
from google.adk.utils.context_utils import Aclosing
|
||||
|
||||
agent = LlmAgent(
|
||||
name="test_agent",
|
||||
model=_MODEL,
|
||||
tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
|
||||
instruction="Answer questions using the search tool.",
|
||||
)
|
||||
|
||||
session_service = InMemorySessionService()
|
||||
runner = Runner(
|
||||
app_name="test_app",
|
||||
agent=agent,
|
||||
session_service=session_service,
|
||||
)
|
||||
session = await session_service.create_session(
|
||||
app_name="test_app", user_id="test_user"
|
||||
)
|
||||
|
||||
run_config = RunConfig(streaming_mode=StreamingMode.NONE)
|
||||
events = []
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id="test_user",
|
||||
session_id=session.id,
|
||||
new_message=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text="What built-in tools does ADK provide?"
|
||||
)
|
||||
],
|
||||
),
|
||||
run_config=run_config,
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
events.append({
|
||||
"author": event.author,
|
||||
"partial": event.partial,
|
||||
"has_grounding": event.grounding_metadata is not None,
|
||||
"has_content": bool(event.content and event.content.parts),
|
||||
})
|
||||
|
||||
print("\n=== No Streaming ===")
|
||||
for i, e in enumerate(events):
|
||||
print(
|
||||
f" Event {i}: author={e['author']}, partial={e['partial']},"
|
||||
f" grounding={e['has_grounding']}, content={e['has_content']}"
|
||||
)
|
||||
|
||||
model_events = [e for e in events if e["author"] == "test_agent"]
|
||||
with_grounding = [e for e in model_events if e["has_grounding"]]
|
||||
assert (
|
||||
with_grounding
|
||||
), "No events have grounding_metadata even without streaming."
|
||||
|
||||
async def _run_agent_streaming(self, agent, project_id):
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.agents.run_config import StreamingMode
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.utils.context_utils import Aclosing
|
||||
|
||||
session_service = InMemorySessionService()
|
||||
runner = Runner(
|
||||
app_name="test_app",
|
||||
agent=agent,
|
||||
session_service=session_service,
|
||||
)
|
||||
session = await session_service.create_session(
|
||||
app_name="test_app", user_id="test_user"
|
||||
)
|
||||
|
||||
run_config = RunConfig(streaming_mode=StreamingMode.SSE)
|
||||
all_events = []
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id="test_user",
|
||||
session_id=session.id,
|
||||
new_message=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text="What is ADK and when was it first released?"
|
||||
)
|
||||
],
|
||||
),
|
||||
run_config=run_config,
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
all_events.append({
|
||||
"author": event.author,
|
||||
"partial": event.partial,
|
||||
"has_grounding": event.grounding_metadata is not None,
|
||||
"has_content": bool(event.content and event.content.parts),
|
||||
})
|
||||
|
||||
saved_events = [e for e in all_events if e["partial"] is not True]
|
||||
return all_events, saved_events
|
||||
|
||||
def _report_events(self, label, all_events, saved_events):
|
||||
print(f"\n=== {label} — All Events ===")
|
||||
for i, e in enumerate(all_events):
|
||||
print(
|
||||
f" Event {i}: author={e['author']}, partial={e['partial']},"
|
||||
f" grounding={e['has_grounding']},"
|
||||
f" content={e['has_content']}"
|
||||
)
|
||||
print(f"\n=== {label} — Saved (non-partial) Events ===")
|
||||
for i, e in enumerate(saved_events):
|
||||
print(
|
||||
f" Event {i}: author={e['author']}, partial={e['partial']},"
|
||||
f" grounding={e['has_grounding']},"
|
||||
f" content={e['has_content']}"
|
||||
)
|
||||
partial_with_grounding = [
|
||||
e for e in all_events if e["partial"] is True and e["has_grounding"]
|
||||
]
|
||||
if partial_with_grounding:
|
||||
print(
|
||||
f"\n NOTE: {len(partial_with_grounding)} partial event(s)"
|
||||
" had grounding_metadata but were NOT saved to session."
|
||||
)
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.features._feature_registry import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.utils import streaming_utils
|
||||
from google.genai import types
|
||||
import pytest
|
||||
@@ -200,3 +202,105 @@ class TestStreamingResponseAggregator:
|
||||
|
||||
closed_response = aggregator.close()
|
||||
assert closed_response is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_id, use_progressive_sse, metadata_type",
|
||||
[
|
||||
("grounding_default", False, "grounding"),
|
||||
("grounding_progressive", True, "grounding"),
|
||||
("citation_default", False, "citation"),
|
||||
("citation_progressive", True, "citation"),
|
||||
],
|
||||
)
|
||||
async def test_close_preserves_metadata(
|
||||
self, test_id, use_progressive_sse, metadata_type
|
||||
):
|
||||
"""close() should carry metadata into the aggregated response."""
|
||||
aggregator = streaming_utils.StreamingResponseAggregator()
|
||||
|
||||
metadata = None
|
||||
response1 = None
|
||||
response2 = None
|
||||
|
||||
if metadata_type == "grounding":
|
||||
metadata = types.GroundingMetadata(
|
||||
grounding_chunks=[
|
||||
types.GroundingChunk(
|
||||
retrieved_context=types.GroundingChunkRetrievedContext(
|
||||
uri="https://example.com/doc1",
|
||||
title="Source",
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
response1 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text="Hello ")]),
|
||||
grounding_metadata=metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
response2 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text="World!")]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
grounding_metadata=metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
elif metadata_type == "citation":
|
||||
metadata = types.CitationMetadata(
|
||||
citations=[
|
||||
types.Citation(
|
||||
start_index=0,
|
||||
end_index=10,
|
||||
uri="https://example.com/source",
|
||||
title="Source",
|
||||
)
|
||||
]
|
||||
)
|
||||
response1 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text="Cited text")]),
|
||||
)
|
||||
]
|
||||
)
|
||||
response2 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
citation_metadata=metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def run_test():
|
||||
async for _ in aggregator.process_response(response1):
|
||||
pass
|
||||
async for _ in aggregator.process_response(response2):
|
||||
pass
|
||||
|
||||
closed_response = aggregator.close()
|
||||
assert closed_response is not None
|
||||
if use_progressive_sse:
|
||||
assert closed_response.partial is False
|
||||
|
||||
if metadata_type == "grounding":
|
||||
assert closed_response.grounding_metadata is not None
|
||||
assert len(closed_response.grounding_metadata.grounding_chunks) == 1
|
||||
elif metadata_type == "citation":
|
||||
assert closed_response.citation_metadata is not None
|
||||
assert len(closed_response.citation_metadata.citations) == 1
|
||||
|
||||
if use_progressive_sse:
|
||||
with temporary_feature_override(
|
||||
FeatureName.PROGRESSIVE_SSE_STREAMING, True
|
||||
):
|
||||
await run_test()
|
||||
else:
|
||||
await run_test()
|
||||
|
||||
Reference in New Issue
Block a user