You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add progressive SSE streaming feature
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 833483804
This commit is contained in:
committed by
Copybara-Service
parent
0ec01956e8
commit
a5ac1d5e14
@@ -24,8 +24,9 @@ from ..utils.env_utils import is_env_enabled
|
||||
class FeatureName(str, Enum):
|
||||
"""Feature names."""
|
||||
|
||||
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
|
||||
COMPUTER_USE = "COMPUTER_USE"
|
||||
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
|
||||
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
|
||||
|
||||
|
||||
class FeatureStage(Enum):
|
||||
@@ -58,11 +59,14 @@ class FeatureConfig:
|
||||
|
||||
# Central registry: FeatureName -> FeatureConfig
|
||||
_FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
|
||||
FeatureName.COMPUTER_USE: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig(
|
||||
FeatureStage.WIP, default_on=False
|
||||
),
|
||||
FeatureName.COMPUTER_USE: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
|
||||
FeatureStage.WIP, default_on=False
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ from ...agents.readonly_context import ReadonlyContext
|
||||
from ...agents.run_config import StreamingMode
|
||||
from ...agents.transcription_entry import TranscriptionEntry
|
||||
from ...events.event import Event
|
||||
from ...features import FeatureName
|
||||
from ...features import is_feature_enabled
|
||||
from ...models.base_llm_connection import BaseLlmConnection
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...models.llm_response import LlmResponse
|
||||
@@ -525,6 +527,16 @@ class BaseLlmFlow(ABC):
|
||||
|
||||
# Handles function calls.
|
||||
if model_response_event.get_function_calls():
|
||||
|
||||
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
|
||||
# In progressive SSE streaming mode stage 1, we skip partial FC events
|
||||
# Only execute FCs in the final aggregated event (partial=False)
|
||||
if (
|
||||
invocation_context.run_config.streaming_mode == StreamingMode.SSE
|
||||
and model_response_event.partial
|
||||
):
|
||||
return
|
||||
|
||||
async with Aclosing(
|
||||
self._postprocess_handle_function_calls_async(
|
||||
invocation_context, model_response_event, llm_request
|
||||
|
||||
@@ -19,6 +19,8 @@ from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..features import FeatureName
|
||||
from ..features import is_feature_enabled
|
||||
from ..models.llm_response import LlmResponse
|
||||
|
||||
|
||||
@@ -35,6 +37,30 @@ class StreamingResponseAggregator:
|
||||
self._usage_metadata = None
|
||||
self._response = None
|
||||
|
||||
# For progressive SSE streaming mode: accumulate parts in order
|
||||
self._parts_sequence: list[types.Part] = []
|
||||
self._current_text_buffer: str = ''
|
||||
self._current_text_is_thought: Optional[bool] = None
|
||||
self._finish_reason: Optional[types.FinishReason] = None
|
||||
|
||||
def _flush_text_buffer_to_sequence(self):
|
||||
"""Flush current text buffer to parts sequence.
|
||||
|
||||
This helper is used in progressive SSE mode to maintain part ordering.
|
||||
It only merges consecutive text parts of the same type (thought or regular).
|
||||
"""
|
||||
if self._current_text_buffer:
|
||||
if self._current_text_is_thought:
|
||||
self._parts_sequence.append(
|
||||
types.Part(text=self._current_text_buffer, thought=True)
|
||||
)
|
||||
else:
|
||||
self._parts_sequence.append(
|
||||
types.Part.from_text(text=self._current_text_buffer)
|
||||
)
|
||||
self._current_text_buffer = ''
|
||||
self._current_text_is_thought = None
|
||||
|
||||
async def process_response(
|
||||
self, response: types.GenerateContentResponse
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
@@ -51,6 +77,42 @@ class StreamingResponseAggregator:
|
||||
self._response = response
|
||||
llm_response = LlmResponse.create(response)
|
||||
self._usage_metadata = llm_response.usage_metadata
|
||||
|
||||
# ========== Progressive SSE Streaming (new feature) ==========
|
||||
# Save finish_reason for final aggregation
|
||||
if llm_response.finish_reason:
|
||||
self._finish_reason = llm_response.finish_reason
|
||||
|
||||
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
|
||||
# Accumulate parts while preserving their order
|
||||
# Only merge consecutive text parts of the same type (thought or regular)
|
||||
if llm_response.content and llm_response.content.parts:
|
||||
for part in llm_response.content.parts:
|
||||
if part.text:
|
||||
# Check if we need to flush the current buffer first
|
||||
# (when text type changes from thought to regular or vice versa)
|
||||
if (
|
||||
self._current_text_buffer
|
||||
and part.thought != self._current_text_is_thought
|
||||
):
|
||||
self._flush_text_buffer_to_sequence()
|
||||
|
||||
# Accumulate text to buffer
|
||||
if not self._current_text_buffer:
|
||||
self._current_text_is_thought = part.thought
|
||||
self._current_text_buffer += part.text
|
||||
else:
|
||||
# Non-text part (function_call, bytes, etc.)
|
||||
# Flush any buffered text first, then add the non-text part
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._parts_sequence.append(part)
|
||||
|
||||
# Mark ALL intermediate chunks as partial
|
||||
llm_response.partial = True
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
# ========== Non-Progressive SSE Streaming (old behavior) ==========
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
@@ -89,6 +151,36 @@ class StreamingResponseAggregator:
|
||||
Returns:
|
||||
The aggregated LlmResponse.
|
||||
"""
|
||||
# ========== Progressive SSE Streaming (new feature) ==========
|
||||
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
|
||||
# Always generate final aggregated response in progressive mode
|
||||
if self._response and self._response.candidates:
|
||||
# Flush any remaining text buffer to complete the sequence
|
||||
self._flush_text_buffer_to_sequence()
|
||||
|
||||
# Use the parts sequence which preserves original ordering
|
||||
final_parts = self._parts_sequence
|
||||
|
||||
if final_parts:
|
||||
candidate = self._response.candidates[0]
|
||||
finish_reason = self._finish_reason or candidate.finish_reason
|
||||
|
||||
return LlmResponse(
|
||||
content=types.ModelContent(parts=final_parts),
|
||||
error_code=None
|
||||
if finish_reason == types.FinishReason.STOP
|
||||
else finish_reason,
|
||||
error_message=None
|
||||
if finish_reason == types.FinishReason.STOP
|
||||
else candidate.finish_message,
|
||||
usage_metadata=self._usage_metadata,
|
||||
finish_reason=finish_reason,
|
||||
partial=False,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# ========== Non-Progressive SSE Streaming (old behavior) ==========
|
||||
if (
|
||||
(self._text or self._thought_text)
|
||||
and self._response
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for Progressive SSE Streaming Stage 1 implementation."""
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.agents.run_config import StreamingMode
|
||||
from google.adk.models.base_llm import BaseLlm
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.runners import InMemoryRunner
|
||||
from google.adk.utils.streaming_utils import StreamingResponseAggregator
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_env(monkeypatch):
|
||||
monkeypatch.setenv("ADK_ENABLE_PROGRESSIVE_SSE_STREAMING", "1")
|
||||
yield
|
||||
monkeypatch.delenv("ADK_ENABLE_PROGRESSIVE_SSE_STREAMING")
|
||||
|
||||
|
||||
def get_weather(location: str) -> dict[str, Any]:
|
||||
"""Mock weather function for testing.
|
||||
|
||||
Args:
|
||||
location: The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the weather information.
|
||||
"""
|
||||
return {
|
||||
"temperature": 22,
|
||||
"condition": "sunny",
|
||||
"location": location,
|
||||
}
|
||||
|
||||
|
||||
class StreamingMockModel(BaseLlm):
|
||||
"""A mock model that properly streams multiple chunks in a single call."""
|
||||
|
||||
model: str = "streaming-mock"
|
||||
stream_chunks: list[LlmResponse] = []
|
||||
call_count: int = 0
|
||||
|
||||
@classmethod
|
||||
def supported_models(cls) -> list[str]:
|
||||
return ["streaming-mock"]
|
||||
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Yield all chunks in a single streaming call."""
|
||||
self.call_count += 1
|
||||
|
||||
# Only stream on the first call
|
||||
if self.call_count > 1:
|
||||
# On subsequent calls, return a simple final response
|
||||
yield LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="Task completed.")],
|
||||
),
|
||||
partial=False,
|
||||
)
|
||||
return
|
||||
|
||||
aggregator = StreamingResponseAggregator()
|
||||
|
||||
# Process each chunk through the aggregator
|
||||
for chunk in self.stream_chunks:
|
||||
# Convert LlmResponse to types.GenerateContentResponse
|
||||
# Since we don't have the full response object, we'll simulate it
|
||||
async for processed_chunk in aggregator.process_response(
|
||||
self._llm_response_to_generate_content_response(chunk)
|
||||
):
|
||||
yield processed_chunk
|
||||
|
||||
# Call close() to get the final aggregated response
|
||||
if final_response := aggregator.close():
|
||||
yield final_response
|
||||
|
||||
def _llm_response_to_generate_content_response(
|
||||
self, llm_response: LlmResponse
|
||||
) -> types.GenerateContentResponse:
|
||||
"""Convert LlmResponse to GenerateContentResponse for aggregator."""
|
||||
# Create a minimal GenerateContentResponse that the aggregator can process
|
||||
candidates = []
|
||||
if llm_response.content:
|
||||
candidates.append(
|
||||
types.Candidate(
|
||||
content=llm_response.content,
|
||||
finish_reason=llm_response.finish_reason,
|
||||
finish_message=llm_response.error_message,
|
||||
)
|
||||
)
|
||||
|
||||
return types.GenerateContentResponse(
|
||||
candidates=candidates,
|
||||
usage_metadata=llm_response.usage_metadata,
|
||||
)
|
||||
|
||||
|
||||
def test_progressive_sse_streaming_function_calls():
|
||||
"""Test that function calls are buffered and executed in parallel."""
|
||||
|
||||
# Setup: Create mock responses simulating streaming chunks
|
||||
response1 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part.from_text(text="Checking weather...")]
|
||||
),
|
||||
)
|
||||
|
||||
response2 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name="get_weather", args={"location": "Tokyo"}
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response3 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name="get_weather", args={"location": "New York"}
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
|
||||
# Create a streaming mock that yields all chunks in one call
|
||||
mock_model = StreamingMockModel(
|
||||
stream_chunks=[response1, response2, response3]
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="weather_agent",
|
||||
model=mock_model,
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
run_config = RunConfig(streaming_mode=StreamingMode.SSE)
|
||||
|
||||
# Use the real InMemoryRunner to get access to run_config parameter
|
||||
runner = InMemoryRunner(agent=agent)
|
||||
|
||||
# Create session manually
|
||||
session = runner.session_service.create_session_sync(
|
||||
app_name=runner.app_name, user_id="test_user"
|
||||
)
|
||||
|
||||
events = []
|
||||
for event in runner.run(
|
||||
user_id="test_user",
|
||||
session_id=session.id,
|
||||
new_message=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="What is the weather?")],
|
||||
),
|
||||
run_config=run_config,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Verify event structure (Stage 1 expectations)
|
||||
# Expected events:
|
||||
# 0-2: Partial events (text + 2 FCs) - not executed
|
||||
# 3: Final aggregated model event (text + 2 FCs) - partial=False
|
||||
# 4: Aggregated function response (both get_weather results executed in
|
||||
# parallel)
|
||||
# 5: Final model response after FCs
|
||||
assert len(events) == 6
|
||||
|
||||
assert events[0].partial
|
||||
assert events[0].content.parts[0].text == "Checking weather..."
|
||||
|
||||
assert events[1].partial
|
||||
assert events[1].content.parts[0].function_call.name == "get_weather"
|
||||
assert events[1].content.parts[0].function_call.args["location"] == "Tokyo"
|
||||
|
||||
assert events[2].partial
|
||||
assert events[2].content.parts[0].function_call.name == "get_weather"
|
||||
assert events[2].content.parts[0].function_call.args["location"] == "New York"
|
||||
|
||||
assert not events[3].partial
|
||||
assert events[3].content.parts[0].text == "Checking weather..."
|
||||
assert events[3].content.parts[1].function_call.name == "get_weather"
|
||||
assert events[3].content.parts[1].function_call.args["location"] == "Tokyo"
|
||||
assert events[3].content.parts[2].function_call.name == "get_weather"
|
||||
assert events[3].content.parts[2].function_call.args["location"] == "New York"
|
||||
|
||||
assert not events[4].partial
|
||||
assert events[4].content.parts[0].function_response.name == "get_weather"
|
||||
assert (
|
||||
events[4].content.parts[0].function_response.response["location"]
|
||||
== "Tokyo"
|
||||
)
|
||||
assert events[4].content.parts[1].function_response.name == "get_weather"
|
||||
assert (
|
||||
events[4].content.parts[1].function_response.response["location"]
|
||||
== "New York"
|
||||
)
|
||||
|
||||
assert not events[5].partial
|
||||
assert events[5].content.parts[0].text == "Task completed."
|
||||
|
||||
|
||||
def test_progressive_sse_preserves_part_ordering():
|
||||
"""Test that part ordering is preserved, especially for thought parts.
|
||||
|
||||
This test verifies that when the model outputs:
|
||||
- chunk1(thought1_1)
|
||||
- chunk2(thought1_2)
|
||||
- chunk3(text1_1)
|
||||
- chunk4(text1_2)
|
||||
- chunk5(FC1)
|
||||
- chunk6(thought2_1)
|
||||
- chunk7(thought2_2)
|
||||
- chunk8(FC2)
|
||||
|
||||
The final aggregated output should be:
|
||||
- Part(thought1) # thought1_1 + thought1_2 merged
|
||||
- Part(text1) # text1_1 + text1_2 merged
|
||||
- Part(FC1)
|
||||
- Part(thought2) # thought2_1 + thought2_2 merged
|
||||
- Part(FC2)
|
||||
"""
|
||||
|
||||
# Create streaming chunks that test the ordering requirement
|
||||
chunk1 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part(text="Initial thought part 1. ", thought=True)],
|
||||
)
|
||||
)
|
||||
|
||||
chunk2 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part(text="Initial thought part 2.", thought=True)],
|
||||
)
|
||||
)
|
||||
|
||||
chunk3 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="Let me check Tokyo. ")],
|
||||
)
|
||||
)
|
||||
|
||||
chunk4 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part.from_text(text="And New York too.")]
|
||||
)
|
||||
)
|
||||
|
||||
chunk5 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name="get_weather", args={"location": "Tokyo"}
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
chunk6 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
text="Now processing second thought part 1. ", thought=True
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
chunk7 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part(text="Second thought part 2.", thought=True)],
|
||||
)
|
||||
)
|
||||
|
||||
chunk8 = LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name="get_weather", args={"location": "New York"}
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
|
||||
mock_model = StreamingMockModel(
|
||||
stream_chunks=[
|
||||
chunk1,
|
||||
chunk2,
|
||||
chunk3,
|
||||
chunk4,
|
||||
chunk5,
|
||||
chunk6,
|
||||
chunk7,
|
||||
chunk8,
|
||||
]
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name="ordering_test_agent",
|
||||
model=mock_model,
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
run_config = RunConfig(streaming_mode=StreamingMode.SSE)
|
||||
|
||||
# Use the real InMemoryRunner to get access to run_config parameter
|
||||
runner = InMemoryRunner(agent=agent)
|
||||
|
||||
# Create session manually
|
||||
session = runner.session_service.create_session_sync(
|
||||
app_name=runner.app_name, user_id="test_user"
|
||||
)
|
||||
|
||||
events = []
|
||||
for event in runner.run(
|
||||
user_id="test_user",
|
||||
session_id=session.id,
|
||||
new_message=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="What is the weather?")],
|
||||
),
|
||||
run_config=run_config,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Find the final aggregated model event (partial=False, from model)
|
||||
aggregated_event = None
|
||||
for event in events:
|
||||
if (
|
||||
not event.partial
|
||||
and event.author == "ordering_test_agent"
|
||||
and event.content
|
||||
and len(event.content.parts) > 2
|
||||
):
|
||||
aggregated_event = event
|
||||
break
|
||||
|
||||
assert aggregated_event is not None, "Should find an aggregated model event"
|
||||
|
||||
# Verify the part ordering
|
||||
parts = aggregated_event.content.parts
|
||||
assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}"
|
||||
|
||||
# Part 0: First thought (merged from chunk1 + chunk2)
|
||||
assert parts[0].thought
|
||||
assert parts[0].text == "Initial thought part 1. Initial thought part 2."
|
||||
|
||||
# Part 1: Regular text (merged from chunk3 + chunk4)
|
||||
assert not parts[1].thought
|
||||
assert parts[1].text == "Let me check Tokyo. And New York too."
|
||||
|
||||
# Part 2: First function call (from chunk5)
|
||||
assert parts[2].function_call.name == "get_weather"
|
||||
assert parts[2].function_call.args["location"] == "Tokyo"
|
||||
|
||||
# Part 3: Second thought (merged from chunk6 + chunk7)
|
||||
assert parts[3].thought
|
||||
assert (
|
||||
parts[3].text
|
||||
== "Now processing second thought part 1. Second thought part 2."
|
||||
)
|
||||
|
||||
# Part 4: Second function call (from chunk8)
|
||||
assert parts[4].function_call.name == "get_weather"
|
||||
assert parts[4].function_call.args["location"] == "New York"
|
||||
Reference in New Issue
Block a user