You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed
PiperOrigin-RevId: 800521404
This commit is contained in:
committed by
Copybara-Service
parent
3bc2d77b4d
commit
7975e8e196
@@ -144,6 +144,8 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
|
||||
text = ''
|
||||
async with Aclosing(self._gemini_session.receive()) as agen:
|
||||
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
|
||||
# partial content and emit responses as needed.
|
||||
async for message in agen:
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.server_content:
|
||||
|
||||
@@ -33,6 +33,7 @@ from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.streaming_utils import StreamingResponseAggregator
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
from .base_llm import BaseLlm
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
@@ -133,68 +134,23 @@ class Gemini(BaseLlm):
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
response = None
|
||||
thought_text = ''
|
||||
text = ''
|
||||
usage_metadata = None
|
||||
|
||||
# For SSE, similar to bidi (see the receive method in gemini_llm_connection.py),
|
||||
# we need to mark those text content as partial and after all partial
|
||||
# contents are sent, we send an accumulated event which contains all the
|
||||
# previous partial content. The only difference is that bidi relies on the
|
||||
# complete_turn flag to detect end while sse depends on finish_reason.
|
||||
aggregator = StreamingResponseAggregator()
|
||||
async with Aclosing(responses) as agen:
|
||||
async for response in agen:
|
||||
logger.debug(_build_response_log(response))
|
||||
llm_response = LlmResponse.create(response)
|
||||
usage_metadata = llm_response.usage_metadata
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
part0 = llm_response.content.parts[0]
|
||||
if part0.thought:
|
||||
thought_text += part0.text
|
||||
else:
|
||||
text += part0.text
|
||||
llm_response.partial = True
|
||||
elif (thought_text or text) and (
|
||||
not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
# don't yield the merged text event when receiving audio data
|
||||
or not llm_response.content.parts[0].inline_data
|
||||
):
|
||||
parts = []
|
||||
if thought_text:
|
||||
parts.append(types.Part(text=thought_text, thought=True))
|
||||
if text:
|
||||
parts.append(types.Part.from_text(text=text))
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(parts=parts),
|
||||
usage_metadata=llm_response.usage_metadata,
|
||||
)
|
||||
thought_text = ''
|
||||
text = ''
|
||||
yield llm_response
|
||||
|
||||
# generate an aggregated content at the end regardless of the
|
||||
# response.candidates[0].finish_reason
|
||||
if (text or thought_text) and response and response.candidates:
|
||||
parts = []
|
||||
if thought_text:
|
||||
parts.append(types.Part(text=thought_text, thought=True))
|
||||
if text:
|
||||
parts.append(types.Part.from_text(text=text))
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(parts=parts),
|
||||
error_code=None
|
||||
if response.candidates[0].finish_reason == FinishReason.STOP
|
||||
else response.candidates[0].finish_reason,
|
||||
error_message=None
|
||||
if response.candidates[0].finish_reason == FinishReason.STOP
|
||||
else response.candidates[0].finish_message,
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
async with Aclosing(
|
||||
aggregator.process_response(response)
|
||||
) as aggregator_gen:
|
||||
async for llm_response in aggregator_gen:
|
||||
yield llm_response
|
||||
if (close_result := aggregator.close()) is not None:
|
||||
yield close_result
|
||||
|
||||
else:
|
||||
response = await self.api_client.aio.models.generate_content(
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..models.llm_response import LlmResponse
|
||||
|
||||
|
||||
class StreamingResponseAggregator:
  """Aggregates partial streaming responses.

  Text carried by the first part of each processed response is accumulated
  into two buffers: one for thought text and one for regular text. For every
  processed response the aggregator yields an LlmResponse for that response,
  and, when a response without leading text arrives while text is buffered,
  it first yields an LlmResponse holding the merged buffered content.

  Typical usage:
    aggregator = StreamingResponseAggregator()
    for each streamed response:
      async for llm_response in aggregator.process_response(response): ...
    final = aggregator.close()
  """

  def __init__(self):
    # Regular (non-thought) text accumulated since the last flush.
    self._text = ''
    # Thought text accumulated since the last flush.
    self._thought_text = ''
    # Usage metadata of the most recently processed response.
    self._usage_metadata = None
    # The most recently processed raw model response.
    self._response = None

  def _aggregated_parts(self) -> list[types.Part]:
    """Builds the parts for the buffered content, thought part first."""
    parts = []
    if self._thought_text:
      parts.append(types.Part(text=self._thought_text, thought=True))
    if self._text:
      parts.append(types.Part.from_text(text=self._text))
    return parts

  async def process_response(
      self, response: types.GenerateContentResponse
  ) -> AsyncGenerator[LlmResponse, None]:
    """Processes a single model response.

    Args:
      response: The response to process.

    Yields:
      The generated LlmResponse(s), for the partial response, and the aggregated
      response if needed.
    """
    self._response = response
    llm_response = LlmResponse.create(response)
    self._usage_metadata = llm_response.usage_metadata

    # Only the first part of the response is inspected.
    content = llm_response.content
    first_part = content.parts[0] if content and content.parts else None

    if first_part is not None and first_part.text:
      # A text chunk: buffer it and mark the response as partial.
      if first_part.thought:
        self._thought_text += first_part.text
      else:
        self._text += first_part.text
      llm_response.partial = True
    elif (self._thought_text or self._text) and (
        first_part is None
        # don't yield the merged text event when receiving audio data
        or not first_part.inline_data
    ):
      # A non-text response while text is buffered: flush the merged content
      # before passing the current response through.
      yield LlmResponse(
          content=types.ModelContent(parts=self._aggregated_parts()),
          usage_metadata=llm_response.usage_metadata,
      )
      self._thought_text = ''
      self._text = ''

    yield llm_response

  def close(self) -> Optional[LlmResponse]:
    """Generate an aggregated response at the end, if needed.

    This should be called after all the model responses are processed.

    Returns:
      The aggregated LlmResponse, or None when no text is buffered or the last
      processed response is missing or has no candidates. When the first
      candidate's finish reason is not STOP, the finish reason and finish
      message are reported as error_code and error_message.
    """
    if not (self._text or self._thought_text):
      return None
    if not self._response or not self._response.candidates:
      return None

    candidate = self._response.candidates[0]
    stopped = candidate.finish_reason == types.FinishReason.STOP
    return LlmResponse(
        content=types.ModelContent(parts=self._aggregated_parts()),
        error_code=None if stopped else candidate.finish_reason,
        error_message=None if stopped else candidate.finish_message,
        usage_metadata=self._usage_metadata,
    )
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.utils import streaming_utils
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStreamingResponseAggregator:
  """Unit tests for streaming_utils.StreamingResponseAggregator."""

  @staticmethod
  def _response_with(part, **candidate_fields):
    """Wraps a single part into a one-candidate GenerateContentResponse."""
    candidate = types.Candidate(
        content=types.Content(parts=[part]), **candidate_fields
    )
    return types.GenerateContentResponse(candidates=[candidate])

  @staticmethod
  async def _drain(aggregator, response):
    """Feeds one response to the aggregator and returns everything yielded."""
    return [item async for item in aggregator.process_response(response)]

  @pytest.mark.asyncio
  async def test_process_response_with_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._response_with(types.Part(text="Hello"))

    results = await self._drain(aggregator, response)

    assert len(results) == 1
    only = results[0]
    assert only.content.parts[0].text == "Hello"
    assert only.partial

  @pytest.mark.asyncio
  async def test_process_response_with_thought(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._response_with(
        types.Part(text="Thinking...", thought=True)
    )

    results = await self._drain(aggregator, response)

    assert len(results) == 1
    first_part = results[0].content.parts[0]
    assert first_part.text == "Thinking..."
    assert first_part.thought
    assert results[0].partial

  @pytest.mark.asyncio
  async def test_process_response_multiple(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    hello = self._response_with(types.Part(text="Hello "))
    world = self._response_with(types.Part(text="World!"))

    await self._drain(aggregator, hello)
    results = await self._drain(aggregator, world)

    # Each chunk is passed through on its own ...
    assert len(results) == 1
    assert results[0].content.parts[0].text == "World!"

    # ... while close() reports the concatenation of all chunks.
    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Hello World!"

  @pytest.mark.asyncio
  async def test_process_response_interleaved_thought_and_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    stream = [
        self._response_with(
            types.Part(text="I am thinking...", thought=True)
        ),
        self._response_with(types.Part(text="Okay, I have a result.")),
        self._response_with(types.Part(text=" The result is 42.")),
    ]

    for chunk in stream:
      await self._drain(aggregator, chunk)

    closed_response = aggregator.close()
    assert closed_response is not None
    assert len(closed_response.content.parts) == 2

    thought_part, text_part = closed_response.content.parts
    assert thought_part.text == "I am thinking..."
    assert thought_part.thought
    assert text_part.text == "Okay, I have a result. The result is 42."
    assert not text_part.thought

  def test_close_with_no_responses(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    closed_response = aggregator.close()
    assert closed_response is None

  @pytest.mark.asyncio
  async def test_close_with_finish_reason(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._response_with(
        types.Part(text="Hello"),
        finish_reason=types.FinishReason.STOP,
    )

    await self._drain(aggregator, response)

    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Hello"
    # A STOP finish reason is not reported as an error.
    assert closed_response.error_code is None
    assert closed_response.error_message is None

  @pytest.mark.asyncio
  async def test_close_with_error(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = self._response_with(
        types.Part(text="Error"),
        finish_reason=types.FinishReason.RECITATION,
        finish_message="Recitation error",
    )

    await self._drain(aggregator, response)

    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Error"
    # A non-STOP finish reason surfaces as error code and message.
    assert closed_response.error_code == types.FinishReason.RECITATION
    assert closed_response.error_message == "Recitation error"
|
||||
Reference in New Issue
Block a user