fix: Improve handling of partial and complete transcriptions in live calls

In `gemini_llm_connection.py`, accumulate partial transcription texts and emit `LlmResponse` with `partial=True` for each chunk. When the transcription is marked as `finished`, emit a final `LlmResponse` with the full accumulated text and `partial=False`.
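
For illustration, here is a minimal, hypothetical sketch of that accumulate-and-flush pattern. It is not the ADK implementation: the `Transcription` and `LlmResponse` dataclasses below are simplified stand-ins for the `google.genai`/ADK types, with field names mirroring the diff further down.

```python
# Hedged sketch: simplified stand-ins for the real google.genai / ADK types.
from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class Transcription:
  text: Optional[str] = None
  finished: bool = False


@dataclass
class LlmResponse:
  input_transcription: Optional[Transcription] = None
  partial: bool = False


@dataclass
class TranscriptionAccumulator:
  """Accumulates fragments and flushes the merged text on `finished`."""

  chunks: List[str] = field(default_factory=list)

  def on_fragment(self, fragment: Transcription) -> Iterator[LlmResponse]:
    if fragment.text:
      self.chunks.append(fragment.text)  # remember this chunk
      # Emit each chunk immediately as a partial response.
      yield LlmResponse(
          input_transcription=Transcription(text=fragment.text, finished=False),
          partial=True,
      )
    # finished=True can arrive in the same message as a trailing text chunk,
    # so this is a separate `if`, not an `elif`.
    if fragment.finished:
      yield LlmResponse(
          input_transcription=Transcription(
              text=''.join(self.chunks), finished=True
          ),
          partial=False,
      )
      self.chunks.clear()  # reset for the next utterance


acc = TranscriptionAccumulator()
fragments = [
    Transcription(text='Hello', finished=False),
    Transcription(text=' world', finished=False),
    Transcription(text=None, finished=True),
]
responses = [r for f in fragments for r in acc.on_fragment(f)]
assert responses[-1].input_transcription.text == 'Hello world'
assert responses[-1].partial is False
```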

In `runners.py`, modify `_should_append_event` to add transcription events to the session history only when they are fully finished, preventing partial transcriptions from being appended.
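
For illustration, a hedged sketch of that filtering logic. The `Event` shape and the `_is_live_model_audio_event` helper below are simplified assumptions for this sketch, not the real `runners.py` code (the actual helper inspects `inline_data` on content parts):

```python
# Hedged sketch of the history filter; Event and the audio helper are
# simplified assumptions, not the ADK implementation.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Event:
  partial: bool = False
  mime_type: Optional[str] = None  # e.g. 'audio/pcm' for live model audio


def _is_live_model_audio_event(event: Event) -> bool:
  # Simplified stand-in for the real helper in the contents module.
  return event.mime_type is not None and event.mime_type.startswith('audio/')


def _should_append_event(event: Event, is_live_call: bool) -> bool:
  """Returns False for events that should not be persisted to the session."""
  # Partial events (e.g. unfinished transcription chunks emitted with
  # partial=True) are superseded by the final merged event.
  if event.partial:
    return False
  # In live mode, raw model audio is stored in artifacts, not in history.
  if is_live_call and _is_live_model_audio_event(event):
    return False
  return True


assert _should_append_event(Event(partial=True), is_live_call=True) is False
assert _should_append_event(Event(mime_type='audio/pcm'), is_live_call=True) is False
assert _should_append_event(Event(mime_type='audio/pcm'), is_live_call=False) is True
assert _should_append_event(Event(), is_live_call=True) is True
```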

Co-authored-by: Hangfei Lin <hangfei@google.com>
PiperOrigin-RevId: 829029715
commit 1819ecb4b8 (parent 44d45fe9cd)
Author: Hangfei Lin
Committed-by: Copybara-Service
Date: 2025-11-06 11:08:47 -08:00

5 changed files with 278 additions and 22 deletions
@@ -587,21 +587,19 @@ class BaseLlmFlow(ABC):
     # Handle transcription events ONCE per llm_response, outside the event loop
     if llm_response.input_transcription:
-      input_transcription_event = (
-          await self.transcription_manager.handle_input_transcription(
-              invocation_context, llm_response.input_transcription
-          )
-      )
-      yield input_transcription_event
+      model_response_event.input_transcription = (
+          llm_response.input_transcription
+      )
+      model_response_event.partial = llm_response.partial
+      yield model_response_event
       return
     if llm_response.output_transcription:
-      output_transcription_event = (
-          await self.transcription_manager.handle_output_transcription(
-              invocation_context, llm_response.output_transcription
-          )
-      )
-      yield output_transcription_event
+      model_response_event.output_transcription = (
+          llm_response.output_transcription
+      )
+      model_response_event.partial = llm_response.partial
+      yield model_response_event
       return
     # Flush audio caches based on control events using configurable settings
gemini_llm_connection.py (+50 -12)
@@ -35,6 +35,8 @@ class GeminiLlmConnection(BaseLlmConnection):
   def __init__(self, gemini_session: live.AsyncSession):
     self._gemini_session = gemini_session
+    self._input_transcription_text: str = ''
+    self._output_transcription_text: str = ''

   async def send_history(self, history: list[types.Content]):
     """Sends the conversation history to the gemini model.
@@ -166,15 +168,49 @@ class GeminiLlmConnection(BaseLlmConnection):
                 text = ''
           yield llm_response
         if message.server_content.input_transcription:
-          llm_response = LlmResponse(
-              input_transcription=message.server_content.input_transcription,
-          )
-          yield llm_response
+          if message.server_content.input_transcription.text:
+            self._input_transcription_text += (
+                message.server_content.input_transcription.text
+            )
+            yield LlmResponse(
+                input_transcription=types.Transcription(
+                    text=message.server_content.input_transcription.text,
+                    finished=False,
+                ),
+                partial=True,
+            )
+          # finished=True and partial transcription may happen in the same
+          # message.
+          if message.server_content.input_transcription.finished:
+            yield LlmResponse(
+                input_transcription=types.Transcription(
+                    text=self._input_transcription_text,
+                    finished=True,
+                ),
+                partial=False,
+            )
+            self._input_transcription_text = ''
         if message.server_content.output_transcription:
-          llm_response = LlmResponse(
-              output_transcription=message.server_content.output_transcription
-          )
-          yield llm_response
+          if message.server_content.output_transcription.text:
+            self._output_transcription_text += (
+                message.server_content.output_transcription.text
+            )
+            yield LlmResponse(
+                output_transcription=types.Transcription(
+                    text=message.server_content.output_transcription.text,
+                    finished=False,
+                ),
+                partial=True,
+            )
+          if message.server_content.output_transcription.finished:
+            yield LlmResponse(
+                output_transcription=types.Transcription(
+                    text=self._output_transcription_text,
+                    finished=True,
+                ),
+                partial=False,
+            )
+            self._output_transcription_text = ''
         if message.server_content.turn_complete:
           if text:
             yield self.__build_full_text_response(text)
@@ -188,10 +224,12 @@ class GeminiLlmConnection(BaseLlmConnection):
         # in case it's an interrupted message, we merge the previous partial
         # text. Other we don't merge. because content can be none when model
         # safety threshold is triggered
-        if message.server_content.interrupted and text:
-          yield self.__build_full_text_response(text)
-          text = ''
-        yield LlmResponse(interrupted=message.server_content.interrupted)
+        if message.server_content.interrupted:
+          if text:
+            yield self.__build_full_text_response(text)
+            text = ''
+          else:
+            yield LlmResponse(interrupted=message.server_content.interrupted)
       if message.tool_call:
         if text:
           yield self.__build_full_text_response(text)
runners.py (+4)
@@ -588,6 +588,10 @@ class Runner:
       # Don't append audio response from model in live mode to session.
      # The data is appended to artifacts with a reference in file_data in the
      # event.
+      # We should append non-partial events only. For example, non-finished
+      # (partial) transcription events should not be appended.
+      # Function call and function response events should be appended.
+      # Other control events should be appended.
       if is_live_call and contents._is_live_model_audio_event(event):
         return False
       return True
@@ -219,3 +219,135 @@ async def test_receive_usage_metadata_and_server_content(
   )
   assert usage_response.usage_metadata == expected_usage
   assert content_response.content == mock_content
+
+
+@pytest.mark.asyncio
+async def test_receive_handles_input_transcription_fragments(
+    gemini_connection, mock_gemini_session
+):
+  """Test receive handles input transcription fragments correctly."""
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock.Mock()
+  message1.server_content.model_turn = None
+  message1.server_content.interrupted = False
+  message1.server_content.input_transcription = types.Transcription(
+      text='Hello', finished=False
+  )
+  message1.server_content.output_transcription = None
+  message1.server_content.turn_complete = False
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock.Mock()
+  message2.server_content.model_turn = None
+  message2.server_content.interrupted = False
+  message2.server_content.input_transcription = types.Transcription(
+      text=' world', finished=False
+  )
+  message2.server_content.output_transcription = None
+  message2.server_content.turn_complete = False
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock.Mock()
+  message3.server_content.model_turn = None
+  message3.server_content.interrupted = False
+  message3.server_content.input_transcription = types.Transcription(
+      text=None, finished=True
+  )
+  message3.server_content.output_transcription = None
+  message3.server_content.turn_complete = False
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 3
+  assert responses[0].input_transcription.text == 'Hello'
+  assert responses[0].input_transcription.finished is False
+  assert responses[0].partial is True
+  assert responses[1].input_transcription.text == ' world'
+  assert responses[1].input_transcription.finished is False
+  assert responses[1].partial is True
+  assert responses[2].input_transcription.text == 'Hello world'
+  assert responses[2].input_transcription.finished is True
+  assert responses[2].partial is False
+
+
+@pytest.mark.asyncio
+async def test_receive_handles_output_transcription_fragments(
+    gemini_connection, mock_gemini_session
+):
+  """Test receive handles output transcription fragments correctly."""
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock.Mock()
+  message1.server_content.model_turn = None
+  message1.server_content.interrupted = False
+  message1.server_content.input_transcription = None
+  message1.server_content.output_transcription = types.Transcription(
+      text='How can', finished=False
+  )
+  message1.server_content.turn_complete = False
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock.Mock()
+  message2.server_content.model_turn = None
+  message2.server_content.interrupted = False
+  message2.server_content.input_transcription = None
+  message2.server_content.output_transcription = types.Transcription(
+      text=' I help?', finished=False
+  )
+  message2.server_content.turn_complete = False
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock.Mock()
+  message3.server_content.model_turn = None
+  message3.server_content.interrupted = False
+  message3.server_content.input_transcription = None
+  message3.server_content.output_transcription = types.Transcription(
+      text=None, finished=True
+  )
+  message3.server_content.turn_complete = False
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 3
+  assert responses[0].output_transcription.text == 'How can'
+  assert responses[0].output_transcription.finished is False
+  assert responses[0].partial is True
+  assert responses[1].output_transcription.text == ' I help?'
+  assert responses[1].output_transcription.finished is False
+  assert responses[1].partial is True
+  assert responses[2].output_transcription.text == 'How can I help?'
+  assert responses[2].output_transcription.finished is True
+  assert responses[2].partial is False
@@ -775,5 +775,89 @@ class TestRunnerCacheConfig:
     assert str(runner.context_cache_config) == expected_str
+
+
+class TestRunnerShouldAppendEvent:
+  """Tests for Runner._should_append_event method."""
+
+  def setup_method(self):
+    """Set up test fixtures."""
+    self.session_service = InMemorySessionService()
+    self.artifact_service = InMemoryArtifactService()
+    self.root_agent = MockLlmAgent("root_agent")
+    self.runner = Runner(
+        app_name="test_app",
+        agent=self.root_agent,
+        session_service=self.session_service,
+        artifact_service=self.artifact_service,
+    )
+
+  def test_should_append_event_finished_input_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="user",
+        input_transcription=types.Transcription(text="hello", finished=True),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_unfinished_input_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="user",
+        input_transcription=types.Transcription(text="hello", finished=False),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_finished_output_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        output_transcription=types.Transcription(text="world", finished=True),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_unfinished_output_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        output_transcription=types.Transcription(text="world", finished=False),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_not_append_event_live_model_audio(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(
+            parts=[
+                types.Part(
+                    inline_data=types.Blob(data=b"123", mime_type="audio/pcm")
+                )
+            ]
+        ),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is False
+
+  def test_should_append_event_non_live_model_audio(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(
+            parts=[
+                types.Part(
+                    inline_data=types.Blob(data=b"123", mime_type="audio/pcm")
+                )
+            ]
+        ),
+    )
+    assert self.runner._should_append_event(event, is_live_call=False) is True
+
+  def test_should_append_event_other_event(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(parts=[types.Part(text="text")]),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+

 if __name__ == "__main__":
   pytest.main([__file__])