feat: Implement Live Session Resumption
The previous implementation didn't pass the actual resumption handle to the server. Now the handle is cached and passed along when a reconnection happens.
To enable:

    run_config = RunConfig(
        session_resumption=types.SessionResumptionConfig(transparent=True)
    )
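
For reference, a minimal sketch of how this config plugs into a live run. The `runner`, `session`, and `live_request_queue` objects below are assumptions borrowed from the ADK streaming quickstart, not part of this commit:

    from google.adk.agents.run_config import RunConfig
    from google.genai import types

    # Transparent resumption: the server streams updated handles during the
    # session; on reconnect, the cached handle restores the session.
    run_config = RunConfig(
        session_resumption=types.SessionResumptionConfig(transparent=True)
    )

    async def main():
        # `runner`, `session`, and `live_request_queue` are assumed to be set
        # up as in the streaming quickstart (Runner, Session, LiveRequestQueue).
        async for event in runner.run_live(
            session=session,
            live_request_queue=live_request_queue,
            run_config=run_config,
        ):
            print(event)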
PiperOrigin-RevId: 791308462
Commit: 71fbc9275b
Parent: 423542a43f
Committed by: Copybara-Service
@@ -100,8 +100,8 @@ def get_current_weather(location: str):
 root_agent = Agent(
     # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
-    # model='gemini-live-2.5-flash-preview-native-audio', # for Vertex project
-    model="gemini-live-2.5-flash-preview", # for AI studio key
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="root_agent",
     instruction="""
       You are a helpful assistant that can check time, roll dice and check if numbers are prime.
@@ -121,7 +121,9 @@ def stop_streaming(function_name: str):
 root_agent = Agent(
-    model="gemini-live-2.5-flash-preview",
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="video_streaming_agent",
     instruction="""
       You are a monitoring agent. You can do video monitoring and stock price monitoring
@@ -217,8 +217,9 @@ import asyncio
 # Create the agent with tool callbacks
 root_agent = Agent(
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
     model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
-    # model="gemini-2.0-flash-live-001", # for AI studio key
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="tool_callbacks_agent",
     description=(
         "Live streaming agent that demonstrates tool callbacks functionality. "
@@ -153,6 +153,9 @@ class InvocationContext(BaseModel):
   transcription_cache: Optional[list[TranscriptionEntry]] = None
   """Caches necessary data, audio or contents, that are needed by transcription."""
 
+  live_session_resumption_handle: Optional[str] = None
+  """The handle for live session resumption."""
+
   run_config: Optional[RunConfig] = None
   """Configurations for live agents under this invocation."""
@@ -25,6 +25,7 @@ from typing import Optional
 from typing import TYPE_CHECKING
 
+from google.genai import types
 from websockets.exceptions import ConnectionClosed
 from websockets.exceptions import ConnectionClosedOK
 
 from . import functions
@@ -86,80 +87,115 @@ class BaseLlmFlow(ABC):
         invocation_context.agent.name,
         llm_request,
     )
-    async with llm.connect(llm_request) as llm_connection:
-      if llm_request.contents:
-        # Sends the conversation history to the model.
-        with tracer.start_as_current_span('send_data'):
-          if invocation_context.transcription_cache:
-            from . import audio_transcriber
-
-            audio_transcriber = audio_transcriber.AudioTranscriber(
-                init_client=True
-                if invocation_context.run_config.input_audio_transcription
-                is None
-                else False
-            )
-            contents = audio_transcriber.transcribe_file(invocation_context)
-            logger.debug('Sending history to model: %s', contents)
-            await llm_connection.send_history(contents)
-            invocation_context.transcription_cache = None
-            trace_send_data(invocation_context, event_id, contents)
-          else:
-            await llm_connection.send_history(llm_request.contents)
-            trace_send_data(invocation_context, event_id, llm_request.contents)
-
-      send_task = asyncio.create_task(
-          self._send_to_model(llm_connection, invocation_context)
-      )
-
+    attempt = 1
+    while True:
       try:
-        async for event in self._receive_from_model(
-            llm_connection,
-            event_id,
-            invocation_context,
-            llm_request,
-        ):
-          # Empty event means the queue is closed.
-          if not event:
-            break
-          logger.debug('Receive new event: %s', event)
-          yield event
-          # send back the function response
-          if event.get_function_responses():
-            logger.debug('Sending back last function response event: %s', event)
-            invocation_context.live_request_queue.send_content(event.content)
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'transfer_to_agent'
-          ):
-            await asyncio.sleep(1)
-            # cancel the tasks that belong to the closed connection.
-            send_task.cancel()
-            await llm_connection.close()
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'task_completed'
-          ):
-            # this is used for sequential agent to signal the end of the agent.
-            await asyncio.sleep(1)
-            # cancel the tasks that belong to the closed connection.
-            send_task.cancel()
-            return
-      finally:
-        # Clean up
-        if not send_task.done():
-          send_task.cancel()
-        try:
-          await send_task
-        except asyncio.CancelledError:
-          pass
+        # On subsequent attempts, use the saved handle to reconnect.
+        if invocation_context.live_session_resumption_handle:
+          logger.info('Attempting to reconnect (Attempt %s)...', attempt)
+          attempt += 1
+          if not llm_request.live_connect_config:
+            llm_request.live_connect_config = types.LiveConnectConfig()
+          llm_request.live_connect_config.session_resumption.handle = (
+              invocation_context.live_session_resumption_handle
+          )
+          llm_request.live_connect_config.session_resumption.transparent = True
+
+        logger.info(
+            'Establishing live connection for agent: %s',
+            invocation_context.agent.name,
+        )
+        async with llm.connect(llm_request) as llm_connection:
+          if llm_request.contents:
+            # Sends the conversation history to the model.
+            with tracer.start_as_current_span('send_data'):
+              if invocation_context.transcription_cache:
+                from . import audio_transcriber
+
+                audio_transcriber = audio_transcriber.AudioTranscriber(
+                    init_client=True
+                    if invocation_context.run_config.input_audio_transcription
+                    is None
+                    else False
+                )
+                contents = audio_transcriber.transcribe_file(invocation_context)
+                logger.debug('Sending history to model: %s', contents)
+                await llm_connection.send_history(contents)
+                invocation_context.transcription_cache = None
+                trace_send_data(invocation_context, event_id, contents)
+              else:
+                await llm_connection.send_history(llm_request.contents)
+                trace_send_data(
+                    invocation_context, event_id, llm_request.contents
+                )
+
+          send_task = asyncio.create_task(
+              self._send_to_model(llm_connection, invocation_context)
+          )
+
+          try:
+            async for event in self._receive_from_model(
+                llm_connection,
+                event_id,
+                invocation_context,
+                llm_request,
+            ):
+              # Empty event means the queue is closed.
+              if not event:
+                break
+              logger.debug('Receive new event: %s', event)
+              yield event
+              # send back the function response
+              if event.get_function_responses():
+                logger.debug(
+                    'Sending back last function response event: %s', event
+                )
+                invocation_context.live_request_queue.send_content(
+                    event.content
+                )
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'transfer_to_agent'
+              ):
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                await llm_connection.close()
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'task_completed'
+              ):
+                # this is used for sequential agent to signal the end of the agent.
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                return
+          finally:
+            # Clean up
+            if not send_task.done():
+              send_task.cancel()
+            try:
+              await send_task
+            except asyncio.CancelledError:
+              pass
+      except (ConnectionClosed, ConnectionClosedOK) as e:
+        # When the session times out, the server just closes the connection
+        # without throwing an exception, so this handles the bad cases.
+        logger.error(f'Connection closed: {e}.')
+        raise
+      except Exception as e:
+        logger.error(
+            f'An unexpected error occurred in live flow: {e}', exc_info=True
+        )
+        raise
 
   async def _send_to_model(
       self,
@@ -246,6 +282,14 @@ class BaseLlmFlow(ABC):
     try:
       while True:
         async for llm_response in llm_connection.receive():
+          if llm_response.live_session_resumption_update:
+            logger.info(
+                'Update session resumption handle:'
+                f' {llm_response.live_session_resumption_update}.'
+            )
+            invocation_context.live_session_resumption_handle = (
+                llm_response.live_session_resumption_update.new_handle
+            )
           model_response_event = Event(
               id=Event.new_id(),
               invocation_id=invocation_context.invocation_id,
@@ -219,6 +219,13 @@ class GeminiLlmConnection(BaseLlmConnection):
             for function_call in message.tool_call.function_calls
         ]
         yield LlmResponse(content=types.Content(role='model', parts=parts))
+      if message.session_resumption_update:
+        logger.info('Received session resumption message: %s', message)
+        yield (
+            LlmResponse(
+                live_session_resumption_update=message.session_resumption_update
+            )
+        )
 
   async def close(self):
     """Closes the llm server connection."""
@@ -289,6 +289,7 @@ class Gemini(BaseLlm):
           ],
       )
       llm_request.live_connect_config.tools = llm_request.config.tools
+      logger.info('Connecting to live with llm_request: %s', llm_request)
       async with self._live_api_client.aio.live.connect(
           model=llm_request.model, config=llm_request.live_connect_config
       ) as live_session:
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
   usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
   """The usage metadata of the LlmResponse"""
 
+  live_session_resumption_update: Optional[
+      types.LiveServerSessionResumptionUpdate
+  ] = None
+  """The session resumption update of the LlmResponse"""
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
@@ -0,0 +1,194 @@ (new test file; all lines added)
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextlib
from typing import AsyncGenerator

from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import Agent
from google.adk.models.llm_response import LlmResponse
from google.genai import types
import pytest
from typing_extensions import override  # <-- FIX: Add this import
from websockets import frames  # <-- FIX 1: Import the frames module
from websockets.exceptions import ConnectionClosed

from .. import testing_utils


def test_live_streaming_multi_agent_single_tool():
  """Test live streaming with multi-agent delegation for a single tool call."""
  # --- 1. Mock LLM Responses ---

  # Mock response for the root_agent to delegate the task to the roll_agent.
  # FIX: Use from_function_call to represent delegation to a sub-agent.
  delegation_to_roll_agent = types.Part.from_function_call(
      name='transfer_to_agent', args={'agent_name': 'roll_agent'}
  )

  root_response1 = LlmResponse(
      content=types.Content(role='model', parts=[delegation_to_roll_agent]),
      turn_complete=False,
  )
  root_response2 = LlmResponse(turn_complete=True)
  mock_root_model = testing_utils.MockModel.create(
      [root_response1, root_response2]
  )

  # Mock response for the roll_agent to call its `roll_die` tool.
  function_call = types.Part.from_function_call(
      name='roll_die', args={'sides': 20}
  )
  roll_agent_response1 = LlmResponse(
      content=types.Content(role='model', parts=[function_call]),
      turn_complete=False,
  )
  roll_agent_response2 = LlmResponse(turn_complete=True)
  mock_roll_model = testing_utils.MockModel.create(
      [roll_agent_response1, roll_agent_response2]
  )

  # --- 2. Mock Tools and Agents ---

  def roll_die(sides: int) -> int:
    """Rolls a die and returns a fixed result for testing."""
    return 15

  mock_roll_sub_agent = Agent(
      name='roll_agent',
      model=mock_roll_model,
      tools=[roll_die],
  )

  main_agent = Agent(
      name='root_agent',
      model=mock_root_model,
      sub_agents=[mock_roll_sub_agent],
  )

  # --- 3. Test Runner Setup ---
  class CustomTestRunner(testing_utils.InMemoryRunner):

    def run_live(
        self,
        live_request_queue: LiveRequestQueue,
        run_config: testing_utils.RunConfig = None,
    ) -> list[testing_utils.Event]:
      collected_responses = []

      async def consume_responses(session: testing_utils.Session):
        run_res = self.runner.run_live(
            session=session,
            live_request_queue=live_request_queue,
            run_config=run_config or testing_utils.RunConfig(),
        )
        async for response in run_res:
          collected_responses.append(response)
          if len(collected_responses) >= 5:
            return

      try:
        session = self.session
        asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
      except (asyncio.TimeoutError, asyncio.CancelledError):
        pass
      return collected_responses

  runner = CustomTestRunner(root_agent=main_agent)
  live_request_queue = LiveRequestQueue()
  live_request_queue.send_realtime(
      blob=types.Blob(data=b'Roll a 20-sided die', mime_type='audio/pcm')
  )

  # --- 4. Run and Assert ---
  res_events = runner.run_live(live_request_queue)

  assert res_events is not None, 'Expected a list of events, but got None.'
  assert len(res_events) >= 1, 'Expected at least one event.'

  delegation_found = False
  tool_call_found = False
  tool_response_found = False

  for event in res_events:
    if event.content and event.content.parts:
      for part in event.content.parts:
        if part.function_call:
          # FIX: Check for the function call that represents delegation.
          if part.function_call.name == 'transfer_to_agent':
            delegation_found = True
            assert part.function_call.args == {'agent_name': 'roll_agent'}

          # Check for the function call made by the roll_agent.
          if part.function_call.name == 'roll_die':
            tool_call_found = True
            assert part.function_call.args['sides'] == 20

        # Check for the result from the executed function.
        if part.function_response and part.function_response.name == 'roll_die':
          tool_response_found = True
          assert part.function_response.response['result'] == 15

  assert delegation_found, 'A function_call event for delegation was not found.'
  assert tool_call_found, 'A function_call event for roll_die was not found.'
  assert tool_response_found, 'A function_response for roll_die was not found.'


def test_live_streaming_connection_error_on_connect():
  """
  Tests that the runner correctly handles a ConnectionClosed exception
  raised from the model's `connect` method during a live run.
  """

  # 1. Create a mock model that fails during the connection phase.
  class MockModelThatFailsToConnect(testing_utils.MockModel):

    @contextlib.asynccontextmanager
    @override
    async def connect(self, llm_request: testing_utils.LlmRequest):
      """Override connect to simulate an immediate connection failure."""

      # FIX 2: Create a proper `Close` frame object first.
      close_frame = frames.Close(
          1007,
          'gemini-live-2.5-flash-preview is not supported in the live api.',
      )

      # FIX 3: Pass the frame object to the `rcvd` parameter of the exception.
      raise ConnectionClosed(rcvd=close_frame, sent=None)

      yield  # pragma: no cover

  # 2. Instantiate the custom mock model.
  mock_model = MockModelThatFailsToConnect(responses=[])

  # 3. Set up the agent and runner.
  agent = Agent(name='test_agent_for_connection_failure', model=mock_model)
  runner = testing_utils.InMemoryRunner(root_agent=agent)
  live_request_queue = LiveRequestQueue()
  live_request_queue.send_realtime(
      blob=types.Blob(data=b'Initial audio chunk', mime_type='audio/pcm')
  )

  # 4. Assert that `run_live` raises `ConnectionClosed`.
  with pytest.raises(ConnectionClosed) as excinfo:
    runner.run_live(live_request_queue)

  # 5. Verify the details of the exception. The `code` and `reason` are
  # attributes of the received frame (`rcvd`), not the exception itself.
  assert excinfo.value.rcvd.code == 1007
  assert (
      'is not supported in the live api' in excinfo.value.rcvd.reason
  ), 'The exception reason should match the simulated server error.'