feat: Persist user input content to session in live mode

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 859207592
This commit is contained in:
Xiang (Sean) Zhou
2026-01-21 12:10:37 -08:00
committed by Copybara-Service
parent 5d941460b6
commit a04828dd8a
3 changed files with 111 additions and 2 deletions
@@ -272,6 +272,25 @@ class BaseLlmFlow(ABC):
await llm_connection.send_realtime(live_request.blob)
if live_request.content:
content = live_request.content
# Persist user text content to session (similar to non-live mode)
# Skip function responses - they are already handled separately
is_function_response = content.parts and any(
part.function_response for part in content.parts
)
if not is_function_response:
if not content.role:
content.role = 'user'
user_content_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author='user',
content=content,
)
await invocation_context.session_service.append_event(
session=invocation_context.session,
event=user_content_event,
)
await llm_connection.send_content(live_request.content)
async def _receive_from_model(
@@ -391,8 +410,8 @@ class BaseLlmFlow(ABC):
current_invocation=True, current_branch=True
)
# Long-running tool calls should have been handled before this point.
# If there are still long-running tool calls, it means the agent is paused
# Long running tool calls should have been handled before this point.
# If there are still long running tool calls, it means the agent is paused
# before, and its branch hasn't been resumed yet.
if (
invocation_context.is_resumable
@@ -1120,3 +1120,89 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
assert (
function_response_found
), 'Buffered function_response event was not yielded.'
def test_live_streaming_text_content_persisted_in_session():
"""Test that user text content sent via send_content is persisted in session."""
response1 = LlmResponse(
content=types.Content(
role='model', parts=[types.Part(text='Hello! How can I help you?')]
),
turn_complete=True,
)
mock_model = testing_utils.MockModel.create([response1])
root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[],
)
class CustomTestRunner(testing_utils.InMemoryRunner):
def run_live_and_get_session(
self,
live_request_queue: LiveRequestQueue,
run_config: testing_utils.RunConfig = None,
) -> tuple[list[testing_utils.Event], testing_utils.Session]:
collected_responses = []
async def consume_responses(session: testing_utils.Session):
run_res = self.runner.run_live(
session=session,
live_request_queue=live_request_queue,
run_config=run_config or testing_utils.RunConfig(),
)
async for response in run_res:
collected_responses.append(response)
if len(collected_responses) >= 1:
return
try:
session = self.session
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
# Get the updated session
updated_session = self.runner.session_service.get_session_sync(
app_name=self.app_name,
user_id=session.user_id,
session_id=session.id,
)
return collected_responses, updated_session
runner = CustomTestRunner(root_agent=root_agent)
live_request_queue = LiveRequestQueue()
# Send text content (not audio blob)
user_text = 'Hello, this is a test message'
live_request_queue.send_content(
types.Content(role='user', parts=[types.Part(text=user_text)])
)
res_events, session = runner.run_live_and_get_session(live_request_queue)
assert res_events is not None, 'Expected a list of events, got None.'
# Check that user text content was persisted in the session
user_content_found = False
for event in session.events:
if event.author == 'user' and event.content:
for part in event.content.parts:
if part.text and user_text in part.text:
user_content_found = True
break
assert user_content_found, (
f'Expected user text content "{user_text}" to be persisted in session. '
f'Session events: {[e.content for e in session.events]}'
)
+4
View File
@@ -409,6 +409,10 @@ class MockLlmConnection(BaseLlmConnection):
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""Yield each of the pre-defined LlmResponses."""
for response in self.llm_responses:
# Yield control to allow other tasks (like send_task) to run first.
# This ensures user content gets persisted before the mock response
# is yielded.
await asyncio.sleep(0)
yield response
async def close(self):