adk-python (mirror of https://github.com/encounter/adk-python.git)
feat: Persist user input content to session in live mode
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 859207592
Commit a04828dd8a (parent 5d941460b6), committed by Copybara-Service
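
Judging from the diff below, user text sent through send_content in live mode was forwarded to the model but never written to the session, so it was missing from conversation history; this change appends it as a user Event, matching non-live behavior. A minimal caller-side sketch of the resulting behavior (import paths and runner wiring are assumptions, not part of this diff):

# Sketch only: assumes ADK's public LiveRequestQueue export and
# google.genai types; runner setup is omitted.
from google.adk.agents import LiveRequestQueue
from google.genai import types

queue = LiveRequestQueue()
queue.send_content(
    types.Content(role='user', parts=[types.Part(text='Hello!')])
)
# While runner.run_live(...) drains this queue, the text above is now also
# appended to the session as an Event(author='user', ...), just as it
# already was in non-live (run_async) mode.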
@@ -272,6 +272,25 @@ class BaseLlmFlow(ABC):
         await llm_connection.send_realtime(live_request.blob)
 
       if live_request.content:
+        content = live_request.content
+        # Persist user text content to session (similar to non-live mode)
+        # Skip function responses - they are already handled separately
+        is_function_response = content.parts and any(
+            part.function_response for part in content.parts
+        )
+        if not is_function_response:
+          if not content.role:
+            content.role = 'user'
+          user_content_event = Event(
+              id=Event.new_id(),
+              invocation_id=invocation_context.invocation_id,
+              author='user',
+              content=content,
+          )
+          await invocation_context.session_service.append_event(
+              session=invocation_context.session,
+              event=user_content_event,
+          )
         await llm_connection.send_content(live_request.content)
 
   async def _receive_from_model(
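
A side note on the is_function_response guard above: it only asks whether any part of the content carries a function_response. A self-contained sketch of that predicate (should_persist is a hypothetical helper name, not in the diff):

# Standalone sketch of the guard; only google.genai types are used.
from google.genai import types

def should_persist(content: types.Content) -> bool:
  # Function responses are skipped; per the diff they are handled separately.
  is_function_response = bool(content.parts) and any(
      part.function_response for part in content.parts
  )
  return not is_function_response

text_msg = types.Content(parts=[types.Part(text='hi')])
fn_msg = types.Content(parts=[
    types.Part(function_response=types.FunctionResponse(name='f', response={}))
])
assert should_persist(text_msg)
assert not should_persist(fn_msg)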
@@ -391,8 +410,8 @@ class BaseLlmFlow(ABC):
         current_invocation=True, current_branch=True
     )
 
-    # Long-running tool calls should have been handled before this point.
-    # If there are still long-running tool calls, it means the agent is paused
+    # Long running tool calls should have been handled before this point.
+    # If there are still long running tool calls, it means the agent is paused
     # before, and its branch hasn't been resumed yet.
     if (
         invocation_context.is_resumable
@@ -1120,3 +1120,89 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
   assert (
       function_response_found
   ), 'Buffered function_response event was not yielded.'
+
+
+def test_live_streaming_text_content_persisted_in_session():
+  """Test that user text content sent via send_content is persisted in session."""
+  response1 = LlmResponse(
+      content=types.Content(
+          role='model', parts=[types.Part(text='Hello! How can I help you?')]
+      ),
+      turn_complete=True,
+  )
+
+  mock_model = testing_utils.MockModel.create([response1])
+
+  root_agent = Agent(
+      name='root_agent',
+      model=mock_model,
+      tools=[],
+  )
+
+  class CustomTestRunner(testing_utils.InMemoryRunner):
+
+    def run_live_and_get_session(
+        self,
+        live_request_queue: LiveRequestQueue,
+        run_config: testing_utils.RunConfig = None,
+    ) -> tuple[list[testing_utils.Event], testing_utils.Session]:
+      collected_responses = []
+
+      async def consume_responses(session: testing_utils.Session):
+        run_res = self.runner.run_live(
+            session=session,
+            live_request_queue=live_request_queue,
+            run_config=run_config or testing_utils.RunConfig(),
+        )
+        async for response in run_res:
+          collected_responses.append(response)
+          if len(collected_responses) >= 1:
+            return
+
+      try:
+        session = self.session
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+          loop.run_until_complete(
+              asyncio.wait_for(consume_responses(session), timeout=5.0)
+          )
+        finally:
+          loop.close()
+      except (asyncio.TimeoutError, asyncio.CancelledError):
+        pass
+
+      # Get the updated session
+      updated_session = self.runner.session_service.get_session_sync(
+          app_name=self.app_name,
+          user_id=session.user_id,
+          session_id=session.id,
+      )
+      return collected_responses, updated_session
+
+  runner = CustomTestRunner(root_agent=root_agent)
+  live_request_queue = LiveRequestQueue()
+
+  # Send text content (not audio blob)
+  user_text = 'Hello, this is a test message'
+  live_request_queue.send_content(
+      types.Content(role='user', parts=[types.Part(text=user_text)])
+  )
+
+  res_events, session = runner.run_live_and_get_session(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, got None.'
+
+  # Check that user text content was persisted in the session
+  user_content_found = False
+  for event in session.events:
+    if event.author == 'user' and event.content:
+      for part in event.content.parts:
+        if part.text and user_text in part.text:
+          user_content_found = True
+          break
+
+  assert user_content_found, (
+      f'Expected user text content "{user_text}" to be persisted in session. '
+      f'Session events: {[e.content for e in session.events]}'
+  )
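
Note that the helper above ends the live turn by timing out after five seconds and swallowing the TimeoutError. A hedged alternative sketch, assuming LiveRequestQueue exposes a close() method that signals run_live to finish (as in the public ADK API):

# Hypothetical variant of the driver: close the queue so run_live ends on
# its own instead of being cancelled by a 5 s timeout.
queue = LiveRequestQueue()
queue.send_content(
    types.Content(role='user', parts=[types.Part(text='Hello, this is a test message')])
)
queue.close()  # enqueues a close signal; run_live drains the queue, then finishes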
@@ -409,6 +409,10 @@ class MockLlmConnection(BaseLlmConnection):
   async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     """Yield each of the pre-defined LlmResponses."""
     for response in self.llm_responses:
+      # Yield control to allow other tasks (like send_task) to run first.
+      # This ensures user content gets persisted before the mock response
+      # is yielded.
+      await asyncio.sleep(0)
       yield response
 
   async def close(self):
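
The added await asyncio.sleep(0) is the standard cooperative-scheduling trick: it suspends the current coroutine for exactly one event-loop pass, so any already-scheduled task (here, the send side that persists user content) gets to run first. A minimal, self-contained demonstration (all names illustrative, not from the diff):

# Demonstrates asyncio.sleep(0) as a cooperative yield point.
import asyncio

async def main():
  order = []

  async def sender():
    order.append('sent')  # stands in for "persist user content"

  task = asyncio.ensure_future(sender())  # scheduled, but has not run yet
  await asyncio.sleep(0)  # yield once: the loop runs `sender` now
  order.append('received')  # stands in for "yield the mock response"
  await task
  assert order == ['sent', 'received']

asyncio.run(main())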