From 71fbc9275b3d74700ec410cb4155ba0cb18580b7 Mon Sep 17 00:00:00 2001
From: Hangfei Lin
Date: Tue, 5 Aug 2025 11:50:14 -0700
Subject: [PATCH] feat: Implement Live Session Resumption

The previous implementation didn't pass the actual resumption handle to
the server. We now cache the handle and pass it along when a
reconnection happens.

To enable:

  run_config = RunConfig(
      session_resumption=types.SessionResumptionConfig(transparent=True)
  )

PiperOrigin-RevId: 791308462
---
 .../live_bidi_streaming_multi_agent/agent.py  |   4 +-
 .../live_bidi_streaming_tools_agent/agent.py  |   4 +-
 .../live_tool_callbacks_agent/agent.py        |   3 +-
 src/google/adk/agents/invocation_context.py   |   3 +
 .../adk/flows/llm_flows/base_llm_flow.py      | 188 ++++++++++-------
 .../adk/models/gemini_llm_connection.py       |   7 +
 src/google/adk/models/google_llm.py           |   1 +
 src/google/adk/models/llm_response.py         |   5 +
 .../streaming/test_multi_agent_streaming.py   | 194 ++++++++++++++++++
 9 files changed, 333 insertions(+), 76 deletions(-)
 create mode 100644 tests/unittests/streaming/test_multi_agent_streaming.py

diff --git a/contributing/samples/live_bidi_streaming_multi_agent/agent.py b/contributing/samples/live_bidi_streaming_multi_agent/agent.py
index ac50eb7a..413e33a7 100644
--- a/contributing/samples/live_bidi_streaming_multi_agent/agent.py
+++ b/contributing/samples/live_bidi_streaming_multi_agent/agent.py
@@ -100,8 +100,8 @@ def get_current_weather(location: str):
 
 root_agent = Agent(
     # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
-    # model='gemini-live-2.5-flash-preview-native-audio', # for Vertex project
-    model="gemini-live-2.5-flash-preview", # for AI studio key
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="root_agent",
     instruction="""
       You are a helpful assistant that can check time, roll dice and check if numbers are prime.
diff --git a/contributing/samples/live_bidi_streaming_tools_agent/agent.py b/contributing/samples/live_bidi_streaming_tools_agent/agent.py
index cdb09217..c5565186 100644
--- a/contributing/samples/live_bidi_streaming_tools_agent/agent.py
+++ b/contributing/samples/live_bidi_streaming_tools_agent/agent.py
@@ -121,7 +121,9 @@ def stop_streaming(function_name: str):
 
 root_agent = Agent(
-    model="gemini-live-2.5-flash-preview",
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="video_streaming_agent",
     instruction="""
       You are a monitoring agent. You can do video monitoring and stock price monitoring
diff --git a/contributing/samples/live_tool_callbacks_agent/agent.py b/contributing/samples/live_tool_callbacks_agent/agent.py
index 3f540b97..95af9d8f 100644
--- a/contributing/samples/live_tool_callbacks_agent/agent.py
+++ b/contributing/samples/live_tool_callbacks_agent/agent.py
@@ -217,8 +217,9 @@ import asyncio
 
 # Create the agent with tool callbacks
 root_agent = Agent(
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
     model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
-    # model="gemini-2.0-flash-live-001", # for AI studio key
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="tool_callbacks_agent",
     description=(
         "Live streaming agent that demonstrates tool callbacks functionality. "
" diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 033d51a6..66c61ed6 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -153,6 +153,9 @@ class InvocationContext(BaseModel): transcription_cache: Optional[list[TranscriptionEntry]] = None """Caches necessary data, audio or contents, that are needed by transcription.""" + live_session_resumption_handle: Optional[str] = None + """The handle for live session resumption.""" + run_config: Optional[RunConfig] = None """Configurations for live agents under this invocation.""" diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b3886671..0a1cdb91 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -25,6 +25,7 @@ from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK from . import functions @@ -86,80 +87,115 @@ class BaseLlmFlow(ABC): invocation_context.agent.name, llm_request, ) - async with llm.connect(llm_request) as llm_connection: - if llm_request.contents: - # Sends the conversation history to the model. - with tracer.start_as_current_span('send_data'): - - if invocation_context.transcription_cache: - from . import audio_transcriber - - audio_transcriber = audio_transcriber.AudioTranscriber( - init_client=True - if invocation_context.run_config.input_audio_transcription - is None - else False - ) - contents = audio_transcriber.transcribe_file(invocation_context) - logger.debug('Sending history to model: %s', contents) - await llm_connection.send_history(contents) - invocation_context.transcription_cache = None - trace_send_data(invocation_context, event_id, contents) - else: - await llm_connection.send_history(llm_request.contents) - trace_send_data(invocation_context, event_id, llm_request.contents) - - send_task = asyncio.create_task( - self._send_to_model(llm_connection, invocation_context) - ) + attempt = 1 + while True: try: - async for event in self._receive_from_model( - llm_connection, - event_id, - invocation_context, - llm_request, - ): - # Empty event means the queue is closed. - if not event: - break - logger.debug('Receive new event: %s', event) - yield event - # send back the function response - if event.get_function_responses(): - logger.debug('Sending back last function response event: %s', event) - invocation_context.live_request_queue.send_content(event.content) - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'transfer_to_agent' - ): - await asyncio.sleep(1) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - await llm_connection.close() - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'task_completed' - ): - # this is used for sequential agent to signal the end of the agent. - await asyncio.sleep(1) - # cancel the tasks that belongs to the closed connection. 
-            send_task.cancel()
-            return
-      finally:
-        # Clean up
-        if not send_task.done():
-          send_task.cancel()
-          try:
-            await send_task
-          except asyncio.CancelledError:
-            pass
+        # On subsequent attempts, use the saved handle to reconnect.
+        if invocation_context.live_session_resumption_handle:
+          logger.info('Attempting to reconnect (Attempt %s)...', attempt)
+          attempt += 1
+          if not llm_request.live_connect_config:
+            llm_request.live_connect_config = types.LiveConnectConfig()
+          # Guard against a missing session resumption config before
+          # setting the cached handle on it.
+          if not llm_request.live_connect_config.session_resumption:
+            llm_request.live_connect_config.session_resumption = (
+                types.SessionResumptionConfig()
+            )
+          llm_request.live_connect_config.session_resumption.handle = (
+              invocation_context.live_session_resumption_handle
+          )
+          llm_request.live_connect_config.session_resumption.transparent = True
+
+        logger.info(
+            'Establishing live connection for agent: %s',
+            invocation_context.agent.name,
+        )
+        async with llm.connect(llm_request) as llm_connection:
+          if llm_request.contents:
+            # Sends the conversation history to the model.
+            with tracer.start_as_current_span('send_data'):
+
+              if invocation_context.transcription_cache:
+                from . import audio_transcriber
+
+                audio_transcriber = audio_transcriber.AudioTranscriber(
+                    init_client=True
+                    if invocation_context.run_config.input_audio_transcription
+                    is None
+                    else False
+                )
+                contents = audio_transcriber.transcribe_file(invocation_context)
+                logger.debug('Sending history to model: %s', contents)
+                await llm_connection.send_history(contents)
+                invocation_context.transcription_cache = None
+                trace_send_data(invocation_context, event_id, contents)
+              else:
+                await llm_connection.send_history(llm_request.contents)
+                trace_send_data(
+                    invocation_context, event_id, llm_request.contents
+                )
+
+          send_task = asyncio.create_task(
+              self._send_to_model(llm_connection, invocation_context)
+          )
+
+          try:
+            async for event in self._receive_from_model(
+                llm_connection,
+                event_id,
+                invocation_context,
+                llm_request,
+            ):
+              # Empty event means the queue is closed.
+              if not event:
+                break
+              logger.debug('Receive new event: %s', event)
+              yield event
+              # send back the function response
+              if event.get_function_responses():
+                logger.debug(
+                    'Sending back last function response event: %s', event
+                )
+                invocation_context.live_request_queue.send_content(
+                    event.content
+                )
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'transfer_to_agent'
+              ):
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                await llm_connection.close()
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'task_completed'
+              ):
+                # This is used for a sequential agent to signal the end of
+                # the agent.
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                return
+          finally:
+            # Clean up
+            if not send_task.done():
+              send_task.cancel()
+              try:
+                await send_task
+              except asyncio.CancelledError:
+                pass
+      except (ConnectionClosed, ConnectionClosedOK) as e:
+        # A session timeout normally just closes the connection without
+        # raising, so reaching this handler indicates an abnormal closure.
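+        # Note that a clean server-side close does not raise at all: the
+        # `async with` block simply exits, and the `while True` loop above
+        # reconnects using the cached resumption handle.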
+        logger.error(f'Connection closed: {e}.')
+        raise
+      except Exception as e:
+        logger.error(
+            f'An unexpected error occurred in live flow: {e}', exc_info=True
+        )
+        raise
 
   async def _send_to_model(
       self,
@@ -246,6 +282,14 @@
     try:
       while True:
         async for llm_response in llm_connection.receive():
+          if llm_response.live_session_resumption_update:
+            logger.info(
+                'Updated session resumption handle:'
+                f' {llm_response.live_session_resumption_update}.'
+            )
+            invocation_context.live_session_resumption_handle = (
+                llm_response.live_session_resumption_update.new_handle
+            )
           model_response_event = Event(
               id=Event.new_id(),
               invocation_id=invocation_context.invocation_id,
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 3a902c56..3b46c91a 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -219,6 +219,13 @@
             for function_call in message.tool_call.function_calls
         ]
         yield LlmResponse(content=types.Content(role='model', parts=parts))
+      if message.session_resumption_update:
+        logger.info('Received session resumption message: %s', message)
+        yield (
+            LlmResponse(
+                live_session_resumption_update=message.session_resumption_update
+            )
+        )
 
   async def close(self):
     """Closes the llm server connection."""
diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index a68af629..b1cad1c5 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -289,6 +289,7 @@
         ],
     )
     llm_request.live_connect_config.tools = llm_request.config.tools
+    logger.info('Connecting to live with llm_request: %s', llm_request)
     async with self._live_api_client.aio.live.connect(
         model=llm_request.model, config=llm_request.live_connect_config
     ) as live_session:
diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py
index 6539ff1a..2f39ad42 100644
--- a/src/google/adk/models/llm_response.py
+++ b/src/google/adk/models/llm_response.py
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
   usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
   """The usage metadata of the LlmResponse"""
 
+  live_session_resumption_update: Optional[
+      types.LiveServerSessionResumptionUpdate
+  ] = None
+  """The session resumption update of the LlmResponse."""
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
diff --git a/tests/unittests/streaming/test_multi_agent_streaming.py b/tests/unittests/streaming/test_multi_agent_streaming.py
new file mode 100644
index 00000000..f7f9cb0d
--- /dev/null
+++ b/tests/unittests/streaming/test_multi_agent_streaming.py
@@ -0,0 +1,194 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
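+
+"""Tests for multi-agent live streaming and live connection failures."""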
+
+import asyncio
+import contextlib
+from typing import Optional
+
+from google.adk.agents.live_request_queue import LiveRequestQueue
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+import pytest
+from typing_extensions import override
+from websockets import frames
+from websockets.exceptions import ConnectionClosed
+
+from .. import testing_utils
+
+
+def test_live_streaming_multi_agent_single_tool():
+  """Test live streaming with multi-agent delegation for a single tool call."""
+  # --- 1. Mock LLM responses ---
+
+  # Mock response for the root_agent to delegate the task to the roll_agent.
+  # Delegation to a sub-agent is represented as a `transfer_to_agent`
+  # function call.
+  delegation_to_roll_agent = types.Part.from_function_call(
+      name='transfer_to_agent', args={'agent_name': 'roll_agent'}
+  )
+
+  root_response1 = LlmResponse(
+      content=types.Content(role='model', parts=[delegation_to_roll_agent]),
+      turn_complete=False,
+  )
+  root_response2 = LlmResponse(turn_complete=True)
+  mock_root_model = testing_utils.MockModel.create(
+      [root_response1, root_response2]
+  )
+
+  # Mock response for the roll_agent to call its `roll_die` tool.
+  function_call = types.Part.from_function_call(
+      name='roll_die', args={'sides': 20}
+  )
+  roll_agent_response1 = LlmResponse(
+      content=types.Content(role='model', parts=[function_call]),
+      turn_complete=False,
+  )
+  roll_agent_response2 = LlmResponse(turn_complete=True)
+  mock_roll_model = testing_utils.MockModel.create(
+      [roll_agent_response1, roll_agent_response2]
+  )
+
+  # --- 2. Mock tools and agents ---
+
+  def roll_die(sides: int) -> int:
+    """Rolls a die and returns a fixed result for testing."""
+    return 15
+
+  mock_roll_sub_agent = Agent(
+      name='roll_agent',
+      model=mock_roll_model,
+      tools=[roll_die],
+  )
+
+  main_agent = Agent(
+      name='root_agent',
+      model=mock_root_model,
+      sub_agents=[mock_roll_sub_agent],
+  )
+
+  # --- 3. Test runner setup ---
+  class CustomTestRunner(testing_utils.InMemoryRunner):
+
+    def run_live(
+        self,
+        live_request_queue: LiveRequestQueue,
+        run_config: Optional[testing_utils.RunConfig] = None,
+    ) -> list[testing_utils.Event]:
+      collected_responses = []
+
+      async def consume_responses(session: testing_utils.Session):
+        run_res = self.runner.run_live(
+            session=session,
+            live_request_queue=live_request_queue,
+            run_config=run_config or testing_utils.RunConfig(),
+        )
+        async for response in run_res:
+          collected_responses.append(response)
+          if len(collected_responses) >= 5:
+            return
+
+      try:
+        session = self.session
+        asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
+      except (asyncio.TimeoutError, asyncio.CancelledError):
+        pass
+      return collected_responses
+
+  runner = CustomTestRunner(root_agent=main_agent)
+  live_request_queue = LiveRequestQueue()
+  live_request_queue.send_realtime(
+      blob=types.Blob(data=b'Roll a 20-sided die', mime_type='audio/pcm')
+  )
+
+  # --- 4. Run and assert ---
+  res_events = runner.run_live(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, but got None.'
+  assert len(res_events) >= 1, 'Expected at least one event.'
+
+  delegation_found = False
+  tool_call_found = False
+  tool_response_found = False
+
+  for event in res_events:
+    if event.content and event.content.parts:
+      for part in event.content.parts:
+        if part.function_call:
+          # Check for the function call that represents delegation.
+          if part.function_call.name == 'transfer_to_agent':
+            delegation_found = True
+            assert part.function_call.args == {'agent_name': 'roll_agent'}
+
+          # Check for the function call made by the roll_agent.
+          if part.function_call.name == 'roll_die':
+            tool_call_found = True
+            assert part.function_call.args['sides'] == 20
+
+        # Check for the result from the executed function.
+        if part.function_response and part.function_response.name == 'roll_die':
+          tool_response_found = True
+          assert part.function_response.response['result'] == 15
+
+  assert delegation_found, 'A function_call event for delegation was not found.'
+  assert tool_call_found, 'A function_call event for roll_die was not found.'
+  assert tool_response_found, 'A function_response for roll_die was not found.'
+
+
+def test_live_streaming_connection_error_on_connect():
+  """Tests that the runner correctly handles a ConnectionClosed exception
+  raised from the model's `connect` method during a live run.
+  """
+
+  # 1. Create a mock model that fails during the connection phase.
+  class MockModelThatFailsToConnect(testing_utils.MockModel):
+
+    @contextlib.asynccontextmanager
+    @override
+    async def connect(self, llm_request: testing_utils.LlmRequest):
+      """Overrides connect to simulate an immediate connection failure."""
+
+      # ConnectionClosed expects `Close` frame objects rather than raw
+      # codes, so build a proper frame first.
+      close_frame = frames.Close(
+          1007,
+          'gemini-live-2.5-flash-preview is not supported in the live api.',
+      )
+
+      # Pass the frame object to the `rcvd` parameter of the exception.
+      raise ConnectionClosed(rcvd=close_frame, sent=None)
+
+      yield  # pragma: no cover
+
+  # 2. Instantiate the custom mock model.
+  mock_model = MockModelThatFailsToConnect(responses=[])
+
+  # 3. Set up the agent and runner.
+  agent = Agent(name='test_agent_for_connection_failure', model=mock_model)
+  runner = testing_utils.InMemoryRunner(root_agent=agent)
+  live_request_queue = LiveRequestQueue()
+  live_request_queue.send_realtime(
+      blob=types.Blob(data=b'Initial audio chunk', mime_type='audio/pcm')
+  )
+
+  # 4. Assert that `run_live` raises `ConnectionClosed`.
+  with pytest.raises(ConnectionClosed) as excinfo:
+    runner.run_live(live_request_queue)
+
+  # 5. Verify the details of the exception. The `code` and `reason` are
+  # attributes of the received frame (`rcvd`), not of the exception itself.
+  assert excinfo.value.rcvd.code == 1007
+  assert (
+      'is not supported in the live api' in excinfo.value.rcvd.reason
+  ), 'The exception reason should match the simulated server error.'
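
For reference, a minimal usage sketch of the feature this patch enables, assuming an already-constructed Runner and Session as in testing_utils.InMemoryRunner above; `runner` and `session` are placeholders, and the RunConfig import path follows the ADK package layout:

from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.run_config import RunConfig
from google.genai import types


async def stream_with_resumption(runner, session):
  # Transparent session resumption: the server periodically sends new
  # resumption handles; the live flow caches the latest one and reuses it
  # when it reconnects after the live socket closes.
  run_config = RunConfig(
      session_resumption=types.SessionResumptionConfig(transparent=True)
  )
  live_request_queue = LiveRequestQueue()
  live_request_queue.send_realtime(
      blob=types.Blob(data=b'<pcm bytes>', mime_type='audio/pcm')
  )
  async for event in runner.run_live(
      session=session,
      live_request_queue=live_request_queue,
      run_config=run_config,
  ):
    print(event)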