From 7975e8e1961c8e375e2af3506ea546580ff7e45d Mon Sep 17 00:00:00 2001
From: Google Team Member
Date: Thu, 28 Aug 2025 10:28:29 -0700
Subject: [PATCH] refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed

PiperOrigin-RevId: 800521404
---
 .../adk/models/gemini_llm_connection.py       |   2 +
 src/google/adk/models/google_llm.py           |  64 +------
 src/google/adk/utils/streaming_utils.py       | 112 +++++++++++
 tests/unittests/utils/test_streaming_utils.py | 181 ++++++++++++++++++
 4 files changed, 305 insertions(+), 54 deletions(-)
 create mode 100644 src/google/adk/utils/streaming_utils.py
 create mode 100644 tests/unittests/utils/test_streaming_utils.py

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index fd6f4a78..0a4ecbb1 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -144,6 +144,8 @@ class GeminiLlmConnection(BaseLlmConnection):
 
     text = ''
     async with Aclosing(self._gemini_session.receive()) as agen:
+      # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
+      # partial content and emit responses as needed.
       async for message in agen:
         logger.debug('Got LLM Live message: %s', message)
         if message.server_content:
diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 86515db1..be2238b4 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -33,6 +33,7 @@ from typing_extensions import override
 
 from .. import version
 from ..utils.context_utils import Aclosing
+from ..utils.streaming_utils import StreamingResponseAggregator
 from ..utils.variant_utils import GoogleLLMVariant
 from .base_llm import BaseLlm
 from .base_llm_connection import BaseLlmConnection
@@ -133,68 +134,23 @@ class Gemini(BaseLlm):
           contents=llm_request.contents,
           config=llm_request.config,
       )
-      response = None
-      thought_text = ''
-      text = ''
-      usage_metadata = None
+
       # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
       # we need to mark those text content as partial and after all partial
       # contents are sent, we send an accumulated event which contains all the
       # previous partial content. The only difference is bidi rely on
       # complete_turn flag to detect end while sse depends on finish_reason.
+      aggregator = StreamingResponseAggregator()
       async with Aclosing(responses) as agen:
         async for response in agen:
           logger.debug(_build_response_log(response))
-          llm_response = LlmResponse.create(response)
-          usage_metadata = llm_response.usage_metadata
-          if (
-              llm_response.content
-              and llm_response.content.parts
-              and llm_response.content.parts[0].text
-          ):
-            part0 = llm_response.content.parts[0]
-            if part0.thought:
-              thought_text += part0.text
-            else:
-              text += part0.text
-            llm_response.partial = True
-          elif (thought_text or text) and (
-              not llm_response.content
-              or not llm_response.content.parts
-              # don't yield the merged text event when receiving audio data
-              or not llm_response.content.parts[0].inline_data
-          ):
-            parts = []
-            if thought_text:
-              parts.append(types.Part(text=thought_text, thought=True))
-            if text:
-              parts.append(types.Part.from_text(text=text))
-            yield LlmResponse(
-                content=types.ModelContent(parts=parts),
-                usage_metadata=llm_response.usage_metadata,
-            )
-            thought_text = ''
-            text = ''
-          yield llm_response
-
-      # generate an aggregated content at the end regardless the
-      # response.candidates[0].finish_reason
-      if (text or thought_text) and response and response.candidates:
-        parts = []
-        if thought_text:
-          parts.append(types.Part(text=thought_text, thought=True))
-        if text:
-          parts.append(types.Part.from_text(text=text))
-        yield LlmResponse(
-            content=types.ModelContent(parts=parts),
-            error_code=None
-            if response.candidates[0].finish_reason == FinishReason.STOP
-            else response.candidates[0].finish_reason,
-            error_message=None
-            if response.candidates[0].finish_reason == FinishReason.STOP
-            else response.candidates[0].finish_message,
-            usage_metadata=usage_metadata,
-        )
+          async with Aclosing(
+              aggregator.process_response(response)
+          ) as aggregator_gen:
+            async for llm_response in aggregator_gen:
+              yield llm_response
+      if (close_result := aggregator.close()) is not None:
+        yield close_result
 
     else:
       response = await self.api_client.aio.models.generate_content(
diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py
new file mode 100644
index 00000000..21bcd57a
--- /dev/null
+++ b/src/google/adk/utils/streaming_utils.py
@@ -0,0 +1,112 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+from typing import Optional
+
+from google.genai import types
+
+from ..models.llm_response import LlmResponse
+
+
+class StreamingResponseAggregator:
+  """Aggregates partial streaming responses.
+
+  It aggregates content from partial responses, and generates LlmResponses for
+  individual (partial) model responses, as well as for aggregated content.
+  """
+
+  def __init__(self):
+    self._text = ''
+    self._thought_text = ''
+    self._usage_metadata = None
+    self._response = None
+
+  async def process_response(
+      self, response: types.GenerateContentResponse
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Processes a single model response.
+
+    Args:
+      response: The response to process.
+
+    Yields:
+      The generated LlmResponse(s), for the partial response, and the aggregated
+      response if needed.
+    """
+    # results = []
+    self._response = response
+    llm_response = LlmResponse.create(response)
+    self._usage_metadata = llm_response.usage_metadata
+    if (
+        llm_response.content
+        and llm_response.content.parts
+        and llm_response.content.parts[0].text
+    ):
+      part0 = llm_response.content.parts[0]
+      if part0.thought:
+        self._thought_text += part0.text
+      else:
+        self._text += part0.text
+      llm_response.partial = True
+    elif (self._thought_text or self._text) and (
+        not llm_response.content
+        or not llm_response.content.parts
+        # don't yield the merged text event when receiving audio data
+        or not llm_response.content.parts[0].inline_data
+    ):
+      parts = []
+      if self._thought_text:
+        parts.append(types.Part(text=self._thought_text, thought=True))
+      if self._text:
+        parts.append(types.Part.from_text(text=self._text))
+      yield LlmResponse(
+          content=types.ModelContent(parts=parts),
+          usage_metadata=llm_response.usage_metadata,
+      )
+      self._thought_text = ''
+      self._text = ''
+    yield llm_response
+
+  def close(self) -> Optional[LlmResponse]:
+    """Generate an aggregated response at the end, if needed.
+
+    This should be called after all the model responses are processed.
+
+    Returns:
+      The aggregated LlmResponse.
+    """
+    if (
+        (self._text or self._thought_text)
+        and self._response
+        and self._response.candidates
+    ):
+      parts = []
+      if self._thought_text:
+        parts.append(types.Part(text=self._thought_text, thought=True))
+      if self._text:
+        parts.append(types.Part.from_text(text=self._text))
+      candidate = self._response.candidates[0]
+      return LlmResponse(
+          content=types.ModelContent(parts=parts),
+          error_code=None
+          if candidate.finish_reason == types.FinishReason.STOP
+          else candidate.finish_reason,
+          error_message=None
+          if candidate.finish_reason == types.FinishReason.STOP
+          else candidate.finish_message,
+          usage_metadata=self._usage_metadata,
+      )
diff --git a/tests/unittests/utils/test_streaming_utils.py b/tests/unittests/utils/test_streaming_utils.py
new file mode 100644
index 00000000..057c05e9
--- /dev/null
+++ b/tests/unittests/utils/test_streaming_utils.py
@@ -0,0 +1,181 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.utils import streaming_utils
+from google.genai import types
+import pytest
+
+
+class TestStreamingResponseAggregator:
+
+  @pytest.mark.asyncio
+  async def test_process_response_with_text(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(parts=[types.Part(text="Hello")])
+            )
+        ]
+    )
+    results = []
+    async for r in aggregator.process_response(response):
+      results.append(r)
+    assert len(results) == 1
+    assert results[0].content.parts[0].text == "Hello"
+    assert results[0].partial
+
+  @pytest.mark.asyncio
+  async def test_process_response_with_thought(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(
+                    parts=[types.Part(text="Thinking...", thought=True)]
+                )
+            )
+        ]
+    )
+    results = []
+    async for r in aggregator.process_response(response):
+      results.append(r)
+    assert len(results) == 1
+    assert results[0].content.parts[0].text == "Thinking..."
+    assert results[0].content.parts[0].thought
+    assert results[0].partial
+
+  @pytest.mark.asyncio
+  async def test_process_response_multiple(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response1 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(parts=[types.Part(text="Hello ")])
+            )
+        ]
+    )
+    response2 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(parts=[types.Part(text="World!")])
+            )
+        ]
+    )
+    async for _ in aggregator.process_response(response1):
+      pass
+    results = []
+    async for r in aggregator.process_response(response2):
+      results.append(r)
+    assert len(results) == 1
+    assert results[0].content.parts[0].text == "World!"
+
+    closed_response = aggregator.close()
+    assert closed_response is not None
+    assert closed_response.content.parts[0].text == "Hello World!"
+
+  @pytest.mark.asyncio
+  async def test_process_response_interleaved_thought_and_text(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response1 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(
+                    parts=[types.Part(text="I am thinking...", thought=True)]
+                )
+            )
+        ]
+    )
+    response2 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(
+                    parts=[types.Part(text="Okay, I have a result.")]
+                )
+            )
+        ]
+    )
+    response3 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(
+                    parts=[types.Part(text=" The result is 42.")]
+                )
+            )
+        ]
+    )
+
+    async for _ in aggregator.process_response(response1):
+      pass
+    async for _ in aggregator.process_response(response2):
+      pass
+    async for _ in aggregator.process_response(response3):
+      pass
+
+    closed_response = aggregator.close()
+    assert closed_response is not None
+    assert len(closed_response.content.parts) == 2
+    assert closed_response.content.parts[0].text == "I am thinking..."
+    assert closed_response.content.parts[0].thought
+    assert (
+        closed_response.content.parts[1].text
+        == "Okay, I have a result. The result is 42."
+    )
+    assert not closed_response.content.parts[1].thought
+
+  def test_close_with_no_responses(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    closed_response = aggregator.close()
+    assert closed_response is None
+
+  @pytest.mark.asyncio
+  async def test_close_with_finish_reason(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(parts=[types.Part(text="Hello")]),
+                finish_reason=types.FinishReason.STOP,
+            )
+        ]
+    )
+    async for _ in aggregator.process_response(response):
+      pass
+    closed_response = aggregator.close()
+    assert closed_response is not None
+    assert closed_response.content.parts[0].text == "Hello"
+    assert closed_response.error_code is None
+    assert closed_response.error_message is None
+
+  @pytest.mark.asyncio
+  async def test_close_with_error(self):
+    aggregator = streaming_utils.StreamingResponseAggregator()
+    response = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=types.Content(parts=[types.Part(text="Error")]),
+                finish_reason=types.FinishReason.RECITATION,
+                finish_message="Recitation error",
+            )
+        ]
+    )
+    async for _ in aggregator.process_response(response):
+      pass
+    closed_response = aggregator.close()
+    assert closed_response is not None
+    assert closed_response.content.parts[0].text == "Error"
+    assert closed_response.error_code == types.FinishReason.RECITATION
+    assert closed_response.error_message == "Recitation error"