diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 433d7f52..fbbf2c00 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -587,21 +587,19 @@ class BaseLlmFlow(ABC): # Handle transcription events ONCE per llm_response, outside the event loop if llm_response.input_transcription: - input_transcription_event = ( - await self.transcription_manager.handle_input_transcription( - invocation_context, llm_response.input_transcription - ) + model_response_event.input_transcription = ( + llm_response.input_transcription ) - yield input_transcription_event + model_response_event.partial = llm_response.partial + yield model_response_event return if llm_response.output_transcription: - output_transcription_event = ( - await self.transcription_manager.handle_output_transcription( - invocation_context, llm_response.output_transcription - ) + model_response_event.output_transcription = ( + llm_response.output_transcription ) - yield output_transcription_event + model_response_event.partial = llm_response.partial + yield model_response_event return # Flush audio caches based on control events using configurable settings diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 028cdf0d..3ffdc102 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -35,6 +35,8 @@ class GeminiLlmConnection(BaseLlmConnection): def __init__(self, gemini_session: live.AsyncSession): self._gemini_session = gemini_session + self._input_transcription_text: str = '' + self._output_transcription_text: str = '' async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. 
@@ -166,15 +168,49 @@ class GeminiLlmConnection(BaseLlmConnection): text = '' yield llm_response if message.server_content.input_transcription: - llm_response = LlmResponse( - input_transcription=message.server_content.input_transcription, - ) - yield llm_response + if message.server_content.input_transcription.text: + self._input_transcription_text += ( + message.server_content.input_transcription.text + ) + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=False, + ), + partial=True, + ) + # finished=True and partial transcription may happen in the same + # message. + if message.server_content.input_transcription.finished: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + ) + self._input_transcription_text = '' if message.server_content.output_transcription: - llm_response = LlmResponse( - output_transcription=message.server_content.output_transcription - ) - yield llm_response + if message.server_content.output_transcription.text: + self._output_transcription_text += ( + message.server_content.output_transcription.text + ) + yield LlmResponse( + output_transcription=types.Transcription( + text=message.server_content.output_transcription.text, + finished=False, + ), + partial=True, + ) + if message.server_content.output_transcription.finished: + yield LlmResponse( + output_transcription=types.Transcription( + text=self._output_transcription_text, + finished=True, + ), + partial=False, + ) + self._output_transcription_text = '' if message.server_content.turn_complete: if text: yield self.__build_full_text_response(text) @@ -188,10 +224,12 @@ class GeminiLlmConnection(BaseLlmConnection): # in case it's an interrupted message, we merge the previous partial # text. Other we don't merge. 
because content can be none when model # safety threshold is triggered - if message.server_content.interrupted and text: - yield self.__build_full_text_response(text) - text = '' - yield LlmResponse(interrupted=message.server_content.interrupted) + if message.server_content.interrupted: + if text: + yield self.__build_full_text_response(text) + text = '' + else: + yield LlmResponse(interrupted=message.server_content.interrupted) if message.tool_call: if text: yield self.__build_full_text_response(text) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 84e5e3fa..1b624158 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -588,6 +588,10 @@ class Runner: # Don't append audio response from model in live mode to session. # The data is appended to artifacts with a reference in file_data in the # event. + # We should append non-partial events only. For example, non-finished (partial) + # transcription events should not be appended. + # Function call and function response events should be appended. + # Other control events should be appended. 
if is_live_call and contents._is_live_model_audio_event(event): return False return True diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 23e8697f..e706a972 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -219,3 +219,135 @@ async def test_receive_usage_metadata_and_server_content( ) assert usage_response.usage_metadata == expected_usage assert content_response.content == mock_content + + +@pytest.mark.asyncio +async def test_receive_handles_input_transcription_fragments( + gemini_connection, mock_gemini_session +): + """Test receive handles input transcription fragments correctly.""" + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = types.Transcription( + text=' world', finished=False + ) + message2.server_content.output_transcription = None + message2.server_content.turn_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = types.Transcription( + text=None, finished=True + ) + 
message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 3 + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].input_transcription.text == ' world' + assert responses[1].input_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello world' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + + +@pytest.mark.asyncio +async def test_receive_handles_output_transcription_fragments( + gemini_connection, mock_gemini_session +): + """Test receive handles output transcription fragments correctly.""" + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = None + message1.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message1.server_content.turn_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text=' I help?', finished=False + ) + 
message2.server_content.turn_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = None + message3.server_content.output_transcription = types.Transcription( + text=None, finished=True + ) + message3.server_content.turn_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 3 + assert responses[0].output_transcription.text == 'How can' + assert responses[0].output_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == ' I help?' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].output_transcription.text == 'How can I help?' 
+ assert responses[2].output_transcription.finished is True + assert responses[2].partial is False diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index d78061dd..cbea0d5a 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -775,5 +775,89 @@ class TestRunnerCacheConfig: assert str(runner.context_cache_config) == expected_str +class TestRunnerShouldAppendEvent: + """Tests for Runner._should_append_event method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_service = InMemorySessionService() + self.artifact_service = InMemoryArtifactService() + self.root_agent = MockLlmAgent("root_agent") + self.runner = Runner( + app_name="test_app", + agent=self.root_agent, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + def test_should_append_event_finished_input_transcription(self): + event = Event( + invocation_id="inv1", + author="user", + input_transcription=types.Transcription(text="hello", finished=True), + ) + assert self.runner._should_append_event(event, is_live_call=True) is True + + def test_should_append_event_unfinished_input_transcription(self): + event = Event( + invocation_id="inv1", + author="user", + input_transcription=types.Transcription(text="hello", finished=False), + ) + assert self.runner._should_append_event(event, is_live_call=True) is True + + def test_should_append_event_finished_output_transcription(self): + event = Event( + invocation_id="inv1", + author="model", + output_transcription=types.Transcription(text="world", finished=True), + ) + assert self.runner._should_append_event(event, is_live_call=True) is True + + def test_should_append_event_unfinished_output_transcription(self): + event = Event( + invocation_id="inv1", + author="model", + output_transcription=types.Transcription(text="world", finished=False), + ) + assert self.runner._should_append_event(event, is_live_call=True) is True + + def 
test_should_not_append_event_live_model_audio(self): + event = Event( + invocation_id="inv1", + author="model", + content=types.Content( + parts=[ + types.Part( + inline_data=types.Blob(data=b"123", mime_type="audio/pcm") + ) + ] + ), + ) + assert self.runner._should_append_event(event, is_live_call=True) is False + + def test_should_append_event_non_live_model_audio(self): + event = Event( + invocation_id="inv1", + author="model", + content=types.Content( + parts=[ + types.Part( + inline_data=types.Blob(data=b"123", mime_type="audio/pcm") + ) + ] + ), + ) + assert self.runner._should_append_event(event, is_live_call=False) is True + + def test_should_append_event_other_event(self): + event = Event( + invocation_id="inv1", + author="model", + content=types.Content(parts=[types.Part(text="text")]), + ) + assert self.runner._should_append_event(event, is_live_call=True) is True + + if __name__ == "__main__": pytest.main([__file__])