refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed

PiperOrigin-RevId: 800521404
This commit is contained in:
Google Team Member
2025-08-28 10:28:29 -07:00
committed by Copybara-Service
parent 3bc2d77b4d
commit 7975e8e196
4 changed files with 305 additions and 54 deletions
@@ -144,6 +144,8 @@ class GeminiLlmConnection(BaseLlmConnection):
text = ''
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
async for message in agen:
logger.debug('Got LLM Live message: %s', message)
if message.server_content:
+10 -54
View File
@@ -33,6 +33,7 @@ from typing_extensions import override
from .. import version
from ..utils.context_utils import Aclosing
from ..utils.streaming_utils import StreamingResponseAggregator
from ..utils.variant_utils import GoogleLLMVariant
from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection
@@ -133,68 +134,23 @@ class Gemini(BaseLlm):
contents=llm_request.contents,
config=llm_request.config,
)
response = None
thought_text = ''
text = ''
usage_metadata = None
# For SSE, similar to bidi (see the receive method in
# gemini_llm_connection.py), we need to mark text content as partial, and
# after all partial contents are sent, we send an accumulated event which
# contains all the previous partial content. The only difference is that bidi
# relies on the complete_turn flag to detect the end, while SSE depends on
# finish_reason.
aggregator = StreamingResponseAggregator()
async with Aclosing(responses) as agen:
async for response in agen:
logger.debug(_build_response_log(response))
llm_response = LlmResponse.create(response)
usage_metadata = llm_response.usage_metadata
if (
llm_response.content
and llm_response.content.parts
and llm_response.content.parts[0].text
):
part0 = llm_response.content.parts[0]
if part0.thought:
thought_text += part0.text
else:
text += part0.text
llm_response.partial = True
elif (thought_text or text) and (
not llm_response.content
or not llm_response.content.parts
# don't yield the merged text event when receiving audio data
or not llm_response.content.parts[0].inline_data
):
parts = []
if thought_text:
parts.append(types.Part(text=thought_text, thought=True))
if text:
parts.append(types.Part.from_text(text=text))
yield LlmResponse(
content=types.ModelContent(parts=parts),
usage_metadata=llm_response.usage_metadata,
)
thought_text = ''
text = ''
yield llm_response
# generate an aggregated content at the end regardless the
# response.candidates[0].finish_reason
if (text or thought_text) and response and response.candidates:
parts = []
if thought_text:
parts.append(types.Part(text=thought_text, thought=True))
if text:
parts.append(types.Part.from_text(text=text))
yield LlmResponse(
content=types.ModelContent(parts=parts),
error_code=None
if response.candidates[0].finish_reason == FinishReason.STOP
else response.candidates[0].finish_reason,
error_message=None
if response.candidates[0].finish_reason == FinishReason.STOP
else response.candidates[0].finish_message,
usage_metadata=usage_metadata,
)
async with Aclosing(
aggregator.process_response(response)
) as aggregator_gen:
async for llm_response in aggregator_gen:
yield llm_response
if (close_result := aggregator.close()) is not None:
yield close_result
else:
response = await self.api_client.aio.models.generate_content(
+112
View File
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import AsyncGenerator
from typing import Optional
from google.genai import types
from ..models.llm_response import LlmResponse
class StreamingResponseAggregator:
  """Aggregates partial streaming responses.

  It aggregates content from partial responses, and generates LlmResponses for
  individual (partial) model responses, as well as for aggregated content.

  Only the first part of each response's content is inspected: if it carries
  text, that text is appended to either the thought buffer or the regular text
  buffer, depending on the part's `thought` flag.
  """

  def __init__(self):
    # Regular (non-thought) text accumulated since the last aggregated emit.
    self._text = ''
    # Thought text accumulated since the last aggregated emit.
    self._thought_text = ''
    # Usage metadata of the most recently processed response.
    self._usage_metadata = None
    # The most recently processed raw model response.
    self._response = None

  def _build_aggregated_parts(self) -> list[types.Part]:
    """Builds the parts for an aggregated response from the buffered text.

    Returns:
      A list holding the thought part (if any thought text is buffered)
      followed by the regular text part (if any regular text is buffered).
    """
    parts = []
    if self._thought_text:
      parts.append(types.Part(text=self._thought_text, thought=True))
    if self._text:
      parts.append(types.Part.from_text(text=self._text))
    return parts

  async def process_response(
      self, response: types.GenerateContentResponse
  ) -> AsyncGenerator[LlmResponse, None]:
    """Processes a single model response.

    Args:
      response: The response to process.

    Yields:
      The generated LlmResponse(s), for the partial response, and the aggregated
      response if needed. When an aggregated response is produced, it is
      yielded before the LlmResponse created from `response`.
    """
    self._response = response
    llm_response = LlmResponse.create(response)
    self._usage_metadata = llm_response.usage_metadata

    if (
        llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      # A text chunk: buffer it and flag the per-chunk response as partial.
      part0 = llm_response.content.parts[0]
      if part0.thought:
        self._thought_text += part0.text
      else:
        self._text += part0.text
      llm_response.partial = True
    elif (self._thought_text or self._text) and (
        not llm_response.content
        or not llm_response.content.parts
        # don't yield the merged text event when receiving audio data
        or not llm_response.content.parts[0].inline_data
    ):
      # A non-text response arrived while text is buffered: emit the merged
      # text first, then clear the buffers so it is not emitted again.
      yield LlmResponse(
          content=types.ModelContent(parts=self._build_aggregated_parts()),
          usage_metadata=llm_response.usage_metadata,
      )
      self._thought_text = ''
      self._text = ''

    yield llm_response

  def close(self) -> Optional[LlmResponse]:
    """Generate an aggregated response at the end, if needed.

    This should be called after all the model responses are processed. The
    buffered text is not cleared by this call.

    Returns:
      The aggregated LlmResponse, or None when no text is buffered or when
      the last processed response is missing or has no candidates. The
      error_code and error_message are None when the first candidate of the
      last response finished with FinishReason.STOP; otherwise they carry that
      candidate's finish_reason and finish_message.
    """
    if not (self._text or self._thought_text):
      return None
    if not self._response or not self._response.candidates:
      return None

    candidate = self._response.candidates[0]
    stopped_normally = candidate.finish_reason == types.FinishReason.STOP
    return LlmResponse(
        content=types.ModelContent(parts=self._build_aggregated_parts()),
        error_code=None if stopped_normally else candidate.finish_reason,
        error_message=None if stopped_normally else candidate.finish_message,
        usage_metadata=self._usage_metadata,
    )
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.utils import streaming_utils
from google.genai import types
import pytest
class TestStreamingResponseAggregator:
  """Behavioral tests for streaming_utils.StreamingResponseAggregator."""

  @staticmethod
  def _single_part_response(text, thought=None, **candidate_fields):
    """Builds a one-candidate response whose content is a single text part.

    `thought` is only forwarded to the part when it is given, and any extra
    keyword arguments are forwarded to the candidate.
    """
    part_fields = {"text": text}
    if thought is not None:
      part_fields["thought"] = thought
    content = types.Content(parts=[types.Part(**part_fields)])
    candidate = types.Candidate(content=content, **candidate_fields)
    return types.GenerateContentResponse(candidates=[candidate])

  @staticmethod
  async def _feed(aggregator, response):
    """Drains process_response() for one response; returns what it yielded."""
    return [item async for item in aggregator.process_response(response)]

  @pytest.mark.asyncio
  async def test_process_response_with_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()

    emitted = await self._feed(
        aggregator, self._single_part_response("Hello")
    )

    assert len(emitted) == 1
    assert emitted[0].content.parts[0].text == "Hello"
    assert emitted[0].partial

  @pytest.mark.asyncio
  async def test_process_response_with_thought(self):
    aggregator = streaming_utils.StreamingResponseAggregator()

    emitted = await self._feed(
        aggregator, self._single_part_response("Thinking...", thought=True)
    )

    assert len(emitted) == 1
    first_part = emitted[0].content.parts[0]
    assert first_part.text == "Thinking..."
    assert first_part.thought
    assert emitted[0].partial

  @pytest.mark.asyncio
  async def test_process_response_multiple(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    first = self._single_part_response("Hello ")
    second = self._single_part_response("World!")

    # The first chunk is only fed to prime the buffer; its output is ignored.
    await self._feed(aggregator, first)
    emitted = await self._feed(aggregator, second)

    assert len(emitted) == 1
    assert emitted[0].content.parts[0].text == "World!"

    final = aggregator.close()
    assert final is not None
    assert final.content.parts[0].text == "Hello World!"

  @pytest.mark.asyncio
  async def test_process_response_interleaved_thought_and_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    chunks = (
        self._single_part_response("I am thinking...", thought=True),
        self._single_part_response("Okay, I have a result."),
        self._single_part_response(" The result is 42."),
    )

    for chunk in chunks:
      await self._feed(aggregator, chunk)

    final = aggregator.close()
    assert final is not None
    assert len(final.content.parts) == 2

    thought_part, text_part = final.content.parts
    assert thought_part.text == "I am thinking..."
    assert thought_part.thought
    assert text_part.text == "Okay, I have a result. The result is 42."
    assert not text_part.thought

  def test_close_with_no_responses(self):
    aggregator = streaming_utils.StreamingResponseAggregator()

    assert aggregator.close() is None

  @pytest.mark.asyncio
  async def test_close_with_finish_reason(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._single_part_response(
        "Hello", finish_reason=types.FinishReason.STOP
    )

    await self._feed(aggregator, response)

    final = aggregator.close()
    assert final is not None
    assert final.content.parts[0].text == "Hello"
    # A normal STOP must not be surfaced as an error.
    assert final.error_code is None
    assert final.error_message is None

  @pytest.mark.asyncio
  async def test_close_with_error(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._single_part_response(
        "Error",
        finish_reason=types.FinishReason.RECITATION,
        finish_message="Recitation error",
    )

    await self._feed(aggregator, response)

    final = aggregator.close()
    assert final is not None
    assert final.content.parts[0].text == "Error"
    # A non-STOP finish reason is reported through the error fields.
    assert final.error_code == types.FinishReason.RECITATION
    assert final.error_message == "Recitation error"