You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Improve handling of partial and complete transcriptions in live calls
In `gemini_llm_connection.py`, accumulate partial transcription texts and emit `LlmResponse` with `partial=True` for each chunk. When the transcription is marked as `finished`, emit a final `LlmResponse` with the full accumulated text and `partial=False`. In `runners.py`, modify `_should_append_to_history` to only add transcription events to the history when they are fully finished, preventing partial transcriptions from being added. Co-authored-by: Hangfei Lin <hangfei@google.com> PiperOrigin-RevId: 829029715
This commit is contained in:
committed by
Copybara-Service
parent
44d45fe9cd
commit
1819ecb4b8
@@ -587,21 +587,19 @@ class BaseLlmFlow(ABC):
|
||||
|
||||
# Handle transcription events ONCE per llm_response, outside the event loop
|
||||
if llm_response.input_transcription:
|
||||
input_transcription_event = (
|
||||
await self.transcription_manager.handle_input_transcription(
|
||||
invocation_context, llm_response.input_transcription
|
||||
)
|
||||
model_response_event.input_transcription = (
|
||||
llm_response.input_transcription
|
||||
)
|
||||
yield input_transcription_event
|
||||
model_response_event.partial = llm_response.partial
|
||||
yield model_response_event
|
||||
return
|
||||
|
||||
if llm_response.output_transcription:
|
||||
output_transcription_event = (
|
||||
await self.transcription_manager.handle_output_transcription(
|
||||
invocation_context, llm_response.output_transcription
|
||||
)
|
||||
model_response_event.output_transcription = (
|
||||
llm_response.output_transcription
|
||||
)
|
||||
yield output_transcription_event
|
||||
model_response_event.partial = llm_response.partial
|
||||
yield model_response_event
|
||||
return
|
||||
|
||||
# Flush audio caches based on control events using configurable settings
|
||||
|
||||
@@ -35,6 +35,8 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
|
||||
def __init__(self, gemini_session: live.AsyncSession):
|
||||
self._gemini_session = gemini_session
|
||||
self._input_transcription_text: str = ''
|
||||
self._output_transcription_text: str = ''
|
||||
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the gemini model.
|
||||
@@ -166,15 +168,49 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
text = ''
|
||||
yield llm_response
|
||||
if message.server_content.input_transcription:
|
||||
llm_response = LlmResponse(
|
||||
input_transcription=message.server_content.input_transcription,
|
||||
)
|
||||
yield llm_response
|
||||
if message.server_content.input_transcription.text:
|
||||
self._input_transcription_text += (
|
||||
message.server_content.input_transcription.text
|
||||
)
|
||||
yield LlmResponse(
|
||||
input_transcription=types.Transcription(
|
||||
text=message.server_content.input_transcription.text,
|
||||
finished=False,
|
||||
),
|
||||
partial=True,
|
||||
)
|
||||
# finished=True and partial transcription may happen in the same
|
||||
# message.
|
||||
if message.server_content.input_transcription.finished:
|
||||
yield LlmResponse(
|
||||
input_transcription=types.Transcription(
|
||||
text=self._input_transcription_text,
|
||||
finished=True,
|
||||
),
|
||||
partial=False,
|
||||
)
|
||||
self._input_transcription_text = ''
|
||||
if message.server_content.output_transcription:
|
||||
llm_response = LlmResponse(
|
||||
output_transcription=message.server_content.output_transcription
|
||||
)
|
||||
yield llm_response
|
||||
if message.server_content.output_transcription.text:
|
||||
self._output_transcription_text += (
|
||||
message.server_content.output_transcription.text
|
||||
)
|
||||
yield LlmResponse(
|
||||
output_transcription=types.Transcription(
|
||||
text=message.server_content.output_transcription.text,
|
||||
finished=False,
|
||||
),
|
||||
partial=True,
|
||||
)
|
||||
if message.server_content.output_transcription.finished:
|
||||
yield LlmResponse(
|
||||
output_transcription=types.Transcription(
|
||||
text=self._output_transcription_text,
|
||||
finished=True,
|
||||
),
|
||||
partial=False,
|
||||
)
|
||||
self._output_transcription_text = ''
|
||||
if message.server_content.turn_complete:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
@@ -188,10 +224,12 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
# in case it's an interrupted message, we merge the previous partial
|
||||
# text. Otherwise we don't merge, because content can be None when the model
|
||||
# safety threshold is triggered
|
||||
if message.server_content.interrupted and text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.server_content.interrupted:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
else:
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.tool_call:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
|
||||
@@ -588,6 +588,10 @@ class Runner:
|
||||
# Don't append audio response from model in live mode to session.
|
||||
# The data is appended to artifacts with a reference in file_data in the
|
||||
# event.
|
||||
# We should append non-partial events only. For example, non-finished (partial)
|
||||
# transcription events should not be appended.
|
||||
# Function call and function response events should be appended.
|
||||
# Other control events should be appended.
|
||||
if is_live_call and contents._is_live_model_audio_event(event):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -219,3 +219,135 @@ async def test_receive_usage_metadata_and_server_content(
|
||||
)
|
||||
assert usage_response.usage_metadata == expected_usage
|
||||
assert content_response.content == mock_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_receive_handles_input_transcription_fragments(
    gemini_connection, mock_gemini_session
):
  """Verify receive() accumulates input transcription fragments.

  Each partial fragment must be surfaced immediately with partial=True, and
  the finishing message must yield the concatenated text with partial=False.
  """

  def _input_transcription_message(text, finished):
    # Build a minimal live-API message that carries only an input
    # transcription; every other field is explicitly disabled so receive()
    # takes no other branch.
    msg = mock.Mock()
    msg.usage_metadata = None
    msg.server_content = mock.Mock()
    msg.server_content.model_turn = None
    msg.server_content.interrupted = False
    msg.server_content.input_transcription = types.Transcription(
        text=text, finished=finished
    )
    msg.server_content.output_transcription = None
    msg.server_content.turn_complete = False
    msg.tool_call = None
    msg.session_resumption_update = None
    return msg

  messages = [
      _input_transcription_message('Hello', False),
      _input_transcription_message(' world', False),
      # The closing message has no text of its own — only the finished flag.
      _input_transcription_message(None, True),
  ]

  async def mock_receive_generator():
    for message in messages:
      yield message

  mock_gemini_session.receive = mock.Mock(
      return_value=mock_receive_generator()
  )

  responses = [resp async for resp in gemini_connection.receive()]

  assert len(responses) == 3
  # Each fragment is echoed as-is with partial=True.
  assert responses[0].input_transcription.text == 'Hello'
  assert responses[0].input_transcription.finished is False
  assert responses[0].partial is True
  assert responses[1].input_transcription.text == ' world'
  assert responses[1].input_transcription.finished is False
  assert responses[1].partial is True
  # The finishing response carries the full accumulated text.
  assert responses[2].input_transcription.text == 'Hello world'
  assert responses[2].input_transcription.finished is True
  assert responses[2].partial is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_receive_handles_output_transcription_fragments(
    gemini_connection, mock_gemini_session
):
  """Verify receive() accumulates output transcription fragments.

  Each partial fragment must be surfaced immediately with partial=True, and
  the finishing message must yield the concatenated text with partial=False.
  """

  def _output_transcription_message(text, finished):
    # Build a minimal live-API message that carries only an output
    # transcription; every other field is explicitly disabled so receive()
    # takes no other branch.
    msg = mock.Mock()
    msg.usage_metadata = None
    msg.server_content = mock.Mock()
    msg.server_content.model_turn = None
    msg.server_content.interrupted = False
    msg.server_content.input_transcription = None
    msg.server_content.output_transcription = types.Transcription(
        text=text, finished=finished
    )
    msg.server_content.turn_complete = False
    msg.tool_call = None
    msg.session_resumption_update = None
    return msg

  messages = [
      _output_transcription_message('How can', False),
      _output_transcription_message(' I help?', False),
      # The closing message has no text of its own — only the finished flag.
      _output_transcription_message(None, True),
  ]

  async def mock_receive_generator():
    for message in messages:
      yield message

  mock_gemini_session.receive = mock.Mock(
      return_value=mock_receive_generator()
  )

  responses = [resp async for resp in gemini_connection.receive()]

  assert len(responses) == 3
  # Each fragment is echoed as-is with partial=True.
  assert responses[0].output_transcription.text == 'How can'
  assert responses[0].output_transcription.finished is False
  assert responses[0].partial is True
  assert responses[1].output_transcription.text == ' I help?'
  assert responses[1].output_transcription.finished is False
  assert responses[1].partial is True
  # The finishing response carries the full accumulated text.
  assert responses[2].output_transcription.text == 'How can I help?'
  assert responses[2].output_transcription.finished is True
  assert responses[2].partial is False
|
||||
|
||||
@@ -775,5 +775,89 @@ class TestRunnerCacheConfig:
|
||||
assert str(runner.context_cache_config) == expected_str
|
||||
|
||||
|
||||
class TestRunnerShouldAppendEvent:
  """Tests for the Runner._should_append_event history-filtering method."""

  def setup_method(self):
    """Create a Runner wired to in-memory services for each test."""
    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()
    self.root_agent = MockLlmAgent("root_agent")
    self.runner = Runner(
        app_name="test_app",
        agent=self.root_agent,
        session_service=self.session_service,
        artifact_service=self.artifact_service,
    )

  def _transcription_event(self, *, author, finished, is_input, text):
    # Build an Event carrying only an input or output transcription.
    transcription = types.Transcription(text=text, finished=finished)
    if is_input:
      return Event(
          invocation_id="inv1",
          author=author,
          input_transcription=transcription,
      )
    return Event(
        invocation_id="inv1",
        author=author,
        output_transcription=transcription,
    )

  def _model_audio_event(self):
    # An Event whose only content is raw PCM audio emitted by the model.
    return Event(
        invocation_id="inv1",
        author="model",
        content=types.Content(
            parts=[
                types.Part(
                    inline_data=types.Blob(data=b"123", mime_type="audio/pcm")
                )
            ]
        ),
    )

  def test_should_append_event_finished_input_transcription(self):
    event = self._transcription_event(
        author="user", finished=True, is_input=True, text="hello"
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_unfinished_input_transcription(self):
    # NOTE(review): unfinished transcriptions are still appended by
    # _should_append_event itself — confirm partial filtering happens in the
    # history-append path rather than here.
    event = self._transcription_event(
        author="user", finished=False, is_input=True, text="hello"
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_finished_output_transcription(self):
    event = self._transcription_event(
        author="model", finished=True, is_input=False, text="world"
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_unfinished_output_transcription(self):
    event = self._transcription_event(
        author="model", finished=False, is_input=False, text="world"
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_not_append_event_live_model_audio(self):
    # Live-mode model audio is stored via artifacts, not session history.
    event = self._model_audio_event()
    assert self.runner._should_append_event(event, is_live_call=True) is False

  def test_should_append_event_non_live_model_audio(self):
    # Outside live calls the same audio event is kept in history.
    event = self._model_audio_event()
    assert self.runner._should_append_event(event, is_live_call=False) is True

  def test_should_append_event_other_event(self):
    event = Event(
        invocation_id="inv1",
        author="model",
        content=types.Content(parts=[types.Part(text="text")]),
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
  pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user