From a04828dd8a848482acbd48acc7da432d0d2cb0aa Mon Sep 17 00:00:00 2001
From: "Xiang (Sean) Zhou"
Date: Wed, 21 Jan 2026 12:10:37 -0800
Subject: [PATCH] feat: Persist user input content to session in live mode

Co-authored-by: Xiang (Sean) Zhou
PiperOrigin-RevId: 859207592
---
 .../adk/flows/llm_flows/base_llm_flow.py     | 23 ++++-
 tests/unittests/streaming/test_streaming.py  | 86 +++++++++++++++++++
 tests/unittests/testing_utils.py             |  4 +
 3 files changed, 111 insertions(+), 2 deletions(-)

diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index d57e31ff..759ac532 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -272,6 +272,25 @@ class BaseLlmFlow(ABC):
         await llm_connection.send_realtime(live_request.blob)

       if live_request.content:
+        content = live_request.content
+        # Persist user text content to session (similar to non-live mode)
+        # Skip function responses - they are already handled separately
+        is_function_response = content.parts and any(
+            part.function_response for part in content.parts
+        )
+        if not is_function_response:
+          if not content.role:
+            content.role = 'user'
+          user_content_event = Event(
+              id=Event.new_id(),
+              invocation_id=invocation_context.invocation_id,
+              author='user',
+              content=content,
+          )
+          await invocation_context.session_service.append_event(
+              session=invocation_context.session,
+              event=user_content_event,
+          )
         await llm_connection.send_content(live_request.content)

   async def _receive_from_model(
@@ -391,8 +410,8 @@ class BaseLlmFlow(ABC):
         current_invocation=True, current_branch=True
     )

-    # Long-running tool calls should have been handled before this point.
-    # If there are still long-running tool calls, it means the agent is paused
+    # Long running tool calls should have been handled before this point.
+    # If there are still long running tool calls, it means the agent is paused
     # before, and its branch hasn't been resumed yet.
     if (
         invocation_context.is_resumable
diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py
index 5ee4721c..e697eacd 100644
--- a/tests/unittests/streaming/test_streaming.py
+++ b/tests/unittests/streaming/test_streaming.py
@@ -1120,3 +1120,89 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
   assert (
       function_response_found
   ), 'Buffered function_response event was not yielded.'
+
+
+def test_live_streaming_text_content_persisted_in_session():
+  """Test that user text content sent via send_content is persisted in session."""
+  response1 = LlmResponse(
+      content=types.Content(
+          role='model', parts=[types.Part(text='Hello! How can I help you?')]
+      ),
+      turn_complete=True,
+  )
+
+  mock_model = testing_utils.MockModel.create([response1])
+
+  root_agent = Agent(
+      name='root_agent',
+      model=mock_model,
+      tools=[],
+  )
+
+  class CustomTestRunner(testing_utils.InMemoryRunner):
+
+    def run_live_and_get_session(
+        self,
+        live_request_queue: LiveRequestQueue,
+        run_config: testing_utils.RunConfig = None,
+    ) -> tuple[list[testing_utils.Event], testing_utils.Session]:
+      collected_responses = []
+
+      async def consume_responses(session: testing_utils.Session):
+        run_res = self.runner.run_live(
+            session=session,
+            live_request_queue=live_request_queue,
+            run_config=run_config or testing_utils.RunConfig(),
+        )
+        async for response in run_res:
+          collected_responses.append(response)
+          if len(collected_responses) >= 1:
+            return
+
+      try:
+        session = self.session
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+          loop.run_until_complete(
+              asyncio.wait_for(consume_responses(session), timeout=5.0)
+          )
+        finally:
+          loop.close()
+      except (asyncio.TimeoutError, asyncio.CancelledError):
+        pass
+
+      # Get the updated session
+      updated_session = self.runner.session_service.get_session_sync(
+          app_name=self.app_name,
+          user_id=session.user_id,
+          session_id=session.id,
+      )
+      return collected_responses, updated_session
+
+  runner = CustomTestRunner(root_agent=root_agent)
+  live_request_queue = LiveRequestQueue()
+
+  # Send text content (not audio blob)
+  user_text = 'Hello, this is a test message'
+  live_request_queue.send_content(
+      types.Content(role='user', parts=[types.Part(text=user_text)])
+  )
+
+  res_events, session = runner.run_live_and_get_session(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, got None.'
+
+  # Check that user text content was persisted in the session
+  user_content_found = False
+  for event in session.events:
+    if event.author == 'user' and event.content:
+      for part in event.content.parts:
+        if part.text and user_text in part.text:
+          user_content_found = True
+          break
+
+  assert user_content_found, (
+      f'Expected user text content "{user_text}" to be persisted in session. '
+      f'Session events: {[e.content for e in session.events]}'
+  )
diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py
index f76668b1..4f9b8636 100644
--- a/tests/unittests/testing_utils.py
+++ b/tests/unittests/testing_utils.py
@@ -409,6 +409,10 @@ class MockLlmConnection(BaseLlmConnection):
   async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     """Yield each of the pre-defined LlmResponses."""
     for response in self.llm_responses:
+      # Yield control to allow other tasks (like send_task) to run first.
+      # This ensures user content gets persisted before the mock response
+      # is yielded.
+      await asyncio.sleep(0)
       yield response

   async def close(self):