feat: Implement Live Session Resumption

The previous implementation didn't pass the actual resumption handle back to the server. We now cache the handle on the invocation context and pass it along when a reconnection happens.

To enable:
    run_config = RunConfig(
        session_resumption=types.SessionResumptionConfig(transparent=True)
    )
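
For context, a minimal end-to-end sketch of feeding this config into a live run. The runner and session wiring below is illustrative rather than part of this change; run_live's keyword arguments match the tests added in this commit:

    from google.adk.agents.live_request_queue import LiveRequestQueue
    from google.adk.agents.run_config import RunConfig
    from google.genai import types

    run_config = RunConfig(
        session_resumption=types.SessionResumptionConfig(transparent=True)
    )
    live_request_queue = LiveRequestQueue()
    # `runner` and `session` are assumed to already exist (a Runner and its
    # Session); run_live yields events as the live session progresses.
    async for event in runner.run_live(
        session=session,
        live_request_queue=live_request_queue,
        run_config=run_config,
    ):
        print(event)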

PiperOrigin-RevId: 791308462
Author: Hangfei Lin
Date: 2025-08-05 11:50:14 -07:00
Committed by: Copybara-Service
Parent: 423542a43f
Commit: 71fbc9275b
9 changed files with 333 additions and 76 deletions
@@ -100,8 +100,8 @@ def get_current_weather(location: str):
root_agent = Agent(
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
-# model='gemini-live-2.5-flash-preview-native-audio', # for Vertex project
-model="gemini-live-2.5-flash-preview", # for AI studio key
+model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+# model="gemini-live-2.5-flash-preview", # for AI studio key
name="root_agent",
instruction="""
You are a helpful assistant that can check time, roll dice and check if numbers are prime.
@@ -121,7 +121,9 @@ def stop_streaming(function_name: str):
root_agent = Agent(
model="gemini-live-2.5-flash-preview",
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
# model="gemini-live-2.5-flash-preview", # for AI studio key
name="video_streaming_agent",
instruction="""
You are a monitoring agent. You can do video monitoring and stock price monitoring
@@ -217,8 +217,9 @@ import asyncio
# Create the agent with tool callbacks
root_agent = Agent(
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
# model="gemini-2.0-flash-live-001", # for AI studio key
# model="gemini-live-2.5-flash-preview", # for AI studio key
name="tool_callbacks_agent",
description=(
"Live streaming agent that demonstrates tool callbacks functionality. "
@@ -153,6 +153,9 @@ class InvocationContext(BaseModel):
transcription_cache: Optional[list[TranscriptionEntry]] = None
"""Caches necessary data, audio or contents, that are needed by transcription."""
+live_session_resumption_handle: Optional[str] = None
+"""The handle for live session resumption."""
run_config: Optional[RunConfig] = None
"""Configurations for live agents under this invocation."""
@@ -25,6 +25,7 @@ from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from websockets.exceptions import ConnectionClosed
+from websockets.exceptions import ConnectionClosedOK
from . import functions
@@ -86,80 +87,115 @@ class BaseLlmFlow(ABC):
invocation_context.agent.name,
llm_request,
)
async with llm.connect(llm_request) as llm_connection:
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):
if invocation_context.transcription_cache:
from . import audio_transcriber
audio_transcriber = audio_transcriber.AudioTranscriber(
init_client=True
if invocation_context.run_config.input_audio_transcription
is None
else False
)
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(invocation_context, event_id, llm_request.contents)
send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)
attempt = 1
while True:
try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
# Empty event means the queue is closed.
if not event:
break
logger.debug('Receive new event: %s', event)
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'transfer_to_agent'
):
await asyncio.sleep(1)
# cancel the tasks that belong to the closed connection.
send_task.cancel()
await llm_connection.close()
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'task_completed'
):
# This is used by a sequential agent to signal the end of the agent.
await asyncio.sleep(1)
# cancel the tasks that belong to the closed connection.
send_task.cancel()
return
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
# On subsequent attempts, use the saved handle to reconnect.
if invocation_context.live_session_resumption_handle:
logger.info('Attempting to reconnect (Attempt %s)...', attempt)
attempt += 1
if not llm_request.live_connect_config:
llm_request.live_connect_config = types.LiveConnectConfig()
llm_request.live_connect_config.session_resumption.handle = (
invocation_context.live_session_resumption_handle
)
llm_request.live_connect_config.session_resumption.transparent = True
logger.info(
'Establishing live connection for agent: %s',
invocation_context.agent.name,
)
async with llm.connect(llm_request) as llm_connection:
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):
if invocation_context.transcription_cache:
from . import audio_transcriber
audio_transcriber = audio_transcriber.AudioTranscriber(
init_client=True
if invocation_context.run_config.input_audio_transcription
is None
else False
)
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(
invocation_context, event_id, llm_request.contents
)
send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)
try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
# Empty event means the queue is closed.
if not event:
break
logger.debug('Receive new event: %s', event)
yield event
# send back the function response
if event.get_function_responses():
logger.debug(
'Sending back last function response event: %s', event
)
invocation_context.live_request_queue.send_content(
event.content
)
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'transfer_to_agent'
):
await asyncio.sleep(1)
# cancel the tasks that belong to the closed connection.
send_task.cancel()
await llm_connection.close()
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'task_completed'
):
# This is used by a sequential agent to signal the end of the agent.
await asyncio.sleep(1)
# cancel the tasks that belong to the closed connection.
send_task.cancel()
return
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
except (ConnectionClosed, ConnectionClosedOK) as e:
# When the session times out, the connection just closes without raising
# an exception, so this branch only handles abnormal closes.
logger.error(f'Connection closed: {e}.')
raise
except Exception as e:
logger.error(
f'An unexpected error occurred in live flow: {e}', exc_info=True
)
raise
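
Distilled to its control flow, the reconnect pattern above is a loop that reuses the last handle the server issued; the sketch below is a condensed paraphrase of the diff, with the connection-handling body elided:

    # Condensed paraphrase; `llm`, `llm_request`, and `invocation_context`
    # are the objects already in scope in the diff above.
    attempt = 1
    while True:
        if invocation_context.live_session_resumption_handle:
            # Reuse the handle issued by the server on the prior connection.
            llm_request.live_connect_config.session_resumption.handle = (
                invocation_context.live_session_resumption_handle
            )
            attempt += 1
        try:
            async with llm.connect(llm_request) as llm_connection:
                ...  # send history, pump events, return on task_completed
        except (ConnectionClosed, ConnectionClosedOK):
            # Timeouts close the socket silently; an exception here is an
            # abnormal close, so it is logged and re-raised.
            raise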
async def _send_to_model(
self,
@@ -246,6 +282,14 @@ class BaseLlmFlow(ABC):
try:
while True:
async for llm_response in llm_connection.receive():
+if llm_response.live_session_resumption_update:
+logger.info(
+'Update session resumption handle:'
+f' {llm_response.live_session_resumption_update}.'
+)
+invocation_context.live_session_resumption_handle = (
+llm_response.live_session_resumption_update.new_handle
+)
model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
@@ -219,6 +219,13 @@ class GeminiLlmConnection(BaseLlmConnection):
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
+if message.session_resumption_update:
+logger.info('Received session resumption message: %s', message)
+yield (
+LlmResponse(
+live_session_resumption_update=message.session_resumption_update
+)
+)
async def close(self):
"""Closes the llm server connection."""
@@ -289,6 +289,7 @@ class Gemini(BaseLlm):
],
)
llm_request.live_connect_config.tools = llm_request.config.tools
+logger.info('Connecting to live with llm_request:%s', llm_request)
async with self._live_api_client.aio.live.connect(
model=llm_request.model, config=llm_request.live_connect_config
) as live_session:
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
"""The usage metadata of the LlmResponse"""
+live_session_resumption_update: Optional[
+types.LiveServerSessionResumptionUpdate
+] = None
+"""The session resumption update of the LlmResponse"""
@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
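
A quick construction example for the new field; the handle value is made up, and `resumable` is assumed from the google-genai LiveServerSessionResumptionUpdate type:

    from google.adk.models.llm_response import LlmResponse
    from google.genai import types

    update = types.LiveServerSessionResumptionUpdate(
        new_handle='example-handle-123', resumable=True
    )
    resp = LlmResponse(live_session_resumption_update=update)
    assert resp.live_session_resumption_update.new_handle == 'example-handle-123'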
@@ -0,0 +1,194 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
from typing import AsyncGenerator
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import Agent
from google.adk.models.llm_response import LlmResponse
from google.genai import types
import pytest
from typing_extensions import override  # needed to mark the connect() override
from websockets import frames  # used to build a Close frame in the failure test
from websockets.exceptions import ConnectionClosed
from .. import testing_utils
def test_live_streaming_multi_agent_single_tool():
"""Test live streaming with multi-agent delegation for a single tool call."""
# --- 1. Mock LLM Responses ---
# Mock response for the root_agent to delegate the task to the roll_agent.
# Use from_function_call to represent delegation to a sub-agent.
delegation_to_roll_agent = types.Part.from_function_call(
name='transfer_to_agent', args={'agent_name': 'roll_agent'}
)
root_response1 = LlmResponse(
content=types.Content(role='model', parts=[delegation_to_roll_agent]),
turn_complete=False,
)
root_response2 = LlmResponse(turn_complete=True)
mock_root_model = testing_utils.MockModel.create(
[root_response1, root_response2]
)
# Mock response for the roll_agent to call its `roll_die` tool.
function_call = types.Part.from_function_call(
name='roll_die', args={'sides': 20}
)
roll_agent_response1 = LlmResponse(
content=types.Content(role='model', parts=[function_call]),
turn_complete=False,
)
roll_agent_response2 = LlmResponse(turn_complete=True)
mock_roll_model = testing_utils.MockModel.create(
[roll_agent_response1, roll_agent_response2]
)
# --- 2. Mock Tools and Agents ---
def roll_die(sides: int) -> int:
"""Rolls a die and returns a fixed result for testing."""
return 15
mock_roll_sub_agent = Agent(
name='roll_agent',
model=mock_roll_model,
tools=[roll_die],
)
main_agent = Agent(
name='root_agent',
model=mock_root_model,
sub_agents=[mock_roll_sub_agent],
)
# --- 3. Test Runner Setup ---
class CustomTestRunner(testing_utils.InMemoryRunner):
def run_live(
self,
live_request_queue: LiveRequestQueue,
run_config: testing_utils.RunConfig = None,
) -> list[testing_utils.Event]:
collected_responses = []
async def consume_responses(session: testing_utils.Session):
run_res = self.runner.run_live(
session=session,
live_request_queue=live_request_queue,
run_config=run_config or testing_utils.RunConfig(),
)
async for response in run_res:
collected_responses.append(response)
if len(collected_responses) >= 5:
return
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
return collected_responses
runner = CustomTestRunner(root_agent=main_agent)
live_request_queue = LiveRequestQueue()
live_request_queue.send_realtime(
blob=types.Blob(data=b'Roll a 20-sided die', mime_type='audio/pcm')
)
# --- 4. Run and Assert ---
res_events = runner.run_live(live_request_queue)
assert res_events is not None, 'Expected a list of events, but got None.'
assert len(res_events) >= 1, 'Expected at least one event.'
delegation_found = False
tool_call_found = False
tool_response_found = False
for event in res_events:
if event.content and event.content.parts:
for part in event.content.parts:
if part.function_call:
# Check for the function call that represents delegation.
if part.function_call.name == 'transfer_to_agent':
delegation_found = True
assert part.function_call.args == {'agent_name': 'roll_agent'}
# Check for the function call made by the roll_agent.
if part.function_call.name == 'roll_die':
tool_call_found = True
assert part.function_call.args['sides'] == 20
# Check for the result from the executed function.
if part.function_response and part.function_response.name == 'roll_die':
tool_response_found = True
assert part.function_response.response['result'] == 15
assert delegation_found, 'A function_call event for delegation was not found.'
assert tool_call_found, 'A function_call event for roll_die was not found.'
assert tool_response_found, 'A function_response for roll_die was not found.'
def test_live_streaming_connection_error_on_connect():
"""
Tests that the runner correctly handles a ConnectionClosed exception
raised from the model's `connect` method during a live run.
"""
# 1. Create a mock model that fails during the connection phase.
class MockModelThatFailsToConnect(testing_utils.MockModel):
@contextlib.asynccontextmanager
@override
async def connect(self, llm_request: testing_utils.LlmRequest):
"""Override connect to simulate an immediate connection failure."""
# Build a proper `Close` frame object first.
close_frame = frames.Close(
1007,
'gemini-live-2.5-flash-preview is not supported in the live api.',
)
# Pass the frame object to the `rcvd` parameter of the exception.
raise ConnectionClosed(rcvd=close_frame, sent=None)
yield # pragma: no cover
# 2. Instantiate the custom mock model.
mock_model = MockModelThatFailsToConnect(responses=[])
# 3. Set up the agent and runner.
agent = Agent(name='test_agent_for_connection_failure', model=mock_model)
runner = testing_utils.InMemoryRunner(root_agent=agent)
live_request_queue = LiveRequestQueue()
live_request_queue.send_realtime(
blob=types.Blob(data=b'Initial audio chunk', mime_type='audio/pcm')
)
# 4. Assert that `run_live` raises `ConnectionClosed`.
with pytest.raises(ConnectionClosed) as excinfo:
runner.run_live(live_request_queue)
# 5. Verify the details of the exception. The `code` and `reason` are
# attributes of the received frame (`rcvd`), not the exception itself.
assert excinfo.value.rcvd.code == 1007
assert (
'is not supported in the live api' in excinfo.value.rcvd.reason
), 'The exception reason should match the simulated server error.'