chore: Lazily register all streaming tools

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 869480442
This commit is contained in:
Xiang (Sean) Zhou
2026-02-12 18:50:07 -08:00
committed by Copybara-Service
parent 5269a6b1d6
commit ede925b502
4 changed files with 73 additions and 172 deletions
+14 -4
View File
@@ -825,11 +825,21 @@ async def _process_function_live_helper(
run_tool_and_update_queue(tool, function_args, tool_context)
)
# The tool is already registered in active_streaming_tools by
# runners.py at startup (all async-generator tools are registered
# there). Just attach the background task.
async with streaming_lock:
invocation_context.active_streaming_tools[tool.name].task = task
if invocation_context.active_streaming_tools is None:
invocation_context.active_streaming_tools = {}
if tool.name in invocation_context.active_streaming_tools:
invocation_context.active_streaming_tools[tool.name].task = task
else:
# Register the streaming tool lazily when the model calls it.
# For input-streaming tools (those with `input_stream:
# LiveRequestQueue`), _call_live will set .stream to a new
# LiveRequestQueue so _send_to_model starts duplicating data.
invocation_context.active_streaming_tools[tool.name] = (
ActiveStreamingTool(task=task)
)
logger.debug('Lazily registered streaming tool: %s', tool.name)
# Immediately return a pending response.
# This is required by current live model.
-37
View File
@@ -32,7 +32,6 @@ from google.adk.apps.compaction import _run_compaction_for_sliding_window
from google.adk.artifacts import artifact_util
from google.genai import types
from .agents.active_streaming_tool import ActiveStreamingTool
from .agents.base_agent import BaseAgent
from .agents.base_agent import BaseAgentState
from .agents.context_cache_config import ContextCacheConfig
@@ -1012,42 +1011,6 @@ class Runner:
root_agent = self.agent
invocation_context.agent = self._find_agent_to_run(session, root_agent)
# Pre-processing for live streaming tools
# Inspect the tool's parameters to find if it uses LiveRequestQueue
invocation_context.active_streaming_tools = {}
# For shell agents, there is no canonical_tools method so we should skip.
if hasattr(invocation_context.agent, 'canonical_tools'):
import inspect
# Use canonical_tools to get properly wrapped BaseTool instances
canonical_tools = await invocation_context.agent.canonical_tools(
invocation_context
)
# Register all async-generator tools as streaming tools.
# A streaming tool is any tool whose underlying function is an
# async generator (i.e. uses `yield`). There are two sub-types:
# 1. Input-streaming tools: accept an `input_stream:
# LiveRequestQueue` parameter to consume the live audio/video
# stream. The stream is created lazily in `_call_live` when
# the model actually calls the tool.
# 2. Output-streaming tools: async generators that yield results
# over time but don't consume the live stream. They are run
# as background tasks when called.
# Both types are registered here with `stream=None`. The
# distinction between them is made at call time.
for tool in canonical_tools:
callable_to_inspect = tool.func if hasattr(tool, 'func') else tool
if not callable(callable_to_inspect):
continue
if inspect.isasyncgenfunction(callable_to_inspect):
if not invocation_context.active_streaming_tools:
invocation_context.active_streaming_tools = {}
logger.debug('Register streaming tool: %s', tool.name)
active_streaming_tool = ActiveStreamingTool()
invocation_context.active_streaming_tools[tool.name] = (
active_streaming_tool
)
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
async for event in agen:
+59 -47
View File
@@ -1322,10 +1322,10 @@ class _LiveTestRunner(testing_utils.InMemoryRunner):
return collected
def test_input_streaming_tool_stream_is_none_before_model_calls():
"""Test that input-streaming tools have stream=None until the model calls them."""
# Add a text response before the function call so we can observe stream
# state between registration and tool invocation.
def test_input_streaming_tool_registered_lazily_with_stream():
"""Test that input-streaming tools are registered lazily when called and receive a stream."""
# A text response before the function call lets us observe that the
# tool is NOT registered before the model calls it.
text_response = LlmResponse(
content=types.Content(
role='model',
@@ -1362,7 +1362,7 @@ def test_input_streaming_tool_stream_is_none_before_model_calls():
runner = _LiveTestRunner(root_agent=root_agent)
# Capture the invocation context to inspect stream state.
# Capture the invocation context to inspect registration state.
captured_context = None
original_method = runner.runner._new_invocation_context_for_live
@@ -1379,39 +1379,39 @@ def test_input_streaming_tool_stream_is_none_before_model_calls():
blob=types.Blob(data=b'test_data', mime_type='audio/pcm')
)
# Collect events and capture stream state before the tool is called.
# Collect events and check that the tool is NOT registered before
# the model calls it.
collected = []
stream_states_before_call = []
not_registered_before_call = None
async def consume(session: testing_utils.Session):
nonlocal not_registered_before_call
async for response in runner.runner.run_live(
session=session,
live_request_queue=live_request_queue,
):
collected.append(response)
# On a non-function-call event, the tool is registered but not
# yet invoked — capture the stream value at that point.
# On the first non-function-call event, verify the tool is not
# yet registered (lazy registration).
active = (
captured_context.active_streaming_tools if captured_context else {}
captured_context.active_streaming_tools if captured_context else None
)
if (
not stream_states_before_call
not_registered_before_call is None
and not response.get_function_calls()
and 'monitor_video_stream' in active
):
stream_states_before_call.append(active['monitor_video_stream'].stream)
not_registered_before_call = (
active is None or 'monitor_video_stream' not in active
)
if len(collected) >= 4:
return
runner._run_with_loop(asyncio.wait_for(consume(runner.session), timeout=5.0))
# Before the model calls the tool, stream should be None.
# Tool should not be registered before the model calls it.
assert (
stream_states_before_call
), 'Stream state was never observed before the tool call'
assert (
stream_states_before_call[0] is None
), 'Expected stream to be None before the model calls the tool'
not_registered_before_call is True
), 'Expected tool to NOT be registered before the model calls it'
# When the model calls the tool, input_stream should be provided.
assert (
stream_state_during_call is True
@@ -1458,17 +1458,20 @@ def test_stop_streaming_resets_stream_to_none():
runner = _LiveTestRunner(root_agent=root_agent)
# Capture invocation context to verify stream is reset.
captured_context = None
original_method = runner.runner._new_invocation_context_for_live
# Capture the child invocation context (created by _create_invocation_context
# inside base_agent.run_live) to inspect active_streaming_tools.
# We cannot use the parent context from _new_invocation_context_for_live
# because model_copy creates a separate child object.
captured_child_context = None
original_create = root_agent._create_invocation_context
def capturing_method(*args, **kwargs):
nonlocal captured_context
ctx = original_method(*args, **kwargs)
captured_context = ctx
def capturing_create(*args, **kwargs):
nonlocal captured_child_context
ctx = original_create(*args, **kwargs)
captured_child_context = ctx
return ctx
runner.runner._new_invocation_context_for_live = capturing_method
root_agent._create_invocation_context = capturing_create
live_request_queue = LiveRequestQueue()
live_request_queue.send_realtime(
@@ -1488,9 +1491,9 @@ def test_stop_streaming_resets_stream_to_none():
# Verify that stop_streaming reset the stream to None.
assert (
captured_context is not None
), 'Expected invocation context to be captured'
active_tools = captured_context.active_streaming_tools or {}
captured_child_context is not None
), 'Expected child invocation context to be captured'
active_tools = captured_child_context.active_streaming_tools or {}
assert (
'monitor_stock_price' in active_tools
), 'Expected monitor_stock_price in active_streaming_tools'
@@ -1499,11 +1502,18 @@ def test_stop_streaming_resets_stream_to_none():
), 'Expected stream to be reset to None after stop_streaming'
def test_output_streaming_tool_registered_at_startup():
"""Test that output-streaming tools (async generators without LiveRequestQueue) are registered at startup."""
response1 = LlmResponse(turn_complete=True)
def test_output_streaming_tool_registered_lazily_without_stream():
"""Test that output-streaming tools are registered lazily when called, with stream=None."""
function_call = types.Part.from_function_call(
name='monitor_stock_price', args={'stock_symbol': 'GOOG'}
)
response1 = LlmResponse(
content=types.Content(role='model', parts=[function_call]),
turn_complete=False,
)
response2 = LlmResponse(turn_complete=True)
mock_model = testing_utils.MockModel.create([response1])
mock_model = testing_utils.MockModel.create([response1, response2])
async def monitor_stock_price(stock_symbol: str):
"""Yield periodic price updates."""
@@ -1517,31 +1527,33 @@ def test_output_streaming_tool_registered_at_startup():
runner = _LiveTestRunner(root_agent=root_agent)
# Capture invocation context to verify registration.
captured_context = None
original_method = runner.runner._new_invocation_context_for_live
# Capture the child invocation context (created by _create_invocation_context
# inside base_agent.run_live) to inspect active_streaming_tools.
captured_child_context = None
original_create = root_agent._create_invocation_context
def capturing_method(*args, **kwargs):
nonlocal captured_context
ctx = original_method(*args, **kwargs)
captured_context = ctx
def capturing_create(*args, **kwargs):
nonlocal captured_child_context
ctx = original_create(*args, **kwargs)
captured_child_context = ctx
return ctx
runner.runner._new_invocation_context_for_live = capturing_method
root_agent._create_invocation_context = capturing_create
live_request_queue = LiveRequestQueue()
live_request_queue.send_realtime(
blob=types.Blob(data=b'test', mime_type='audio/pcm')
)
runner.run_live(live_request_queue, max_responses=1)
runner.run_live(live_request_queue, max_responses=3)
# Output-streaming tool should be registered with stream=None.
assert captured_context is not None
active_tools = captured_context.active_streaming_tools or {}
# After the model calls the tool, it should be registered with
# stream=None (output-streaming tools don't consume the live stream).
assert captured_child_context is not None
active_tools = captured_child_context.active_streaming_tools or {}
assert (
'monitor_stock_price' in active_tools
), 'Expected output-streaming tool to be registered at startup'
), 'Expected output-streaming tool to be registered when called'
assert (
active_tools['monitor_stock_price'].stream is None
), 'Expected stream to be None for output-streaming tool'
-84
View File
@@ -23,7 +23,6 @@ from unittest.mock import AsyncMock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
@@ -35,7 +34,6 @@ from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
import pytest
@@ -360,88 +358,6 @@ async def test_run_live_auto_create_session():
assert session is not None
@pytest.mark.asyncio
async def test_run_live_detects_streaming_tools_with_canonical_tools():
  """Check that run_live exposes streaming tools keyed by their tool name.

  The agent is configured with three tools: a bare async-generator function,
  an async-generator wrapped in FunctionTool, and a plain synchronous
  function. The test records the `active_streaming_tools` mapping that the
  agent's `_run_live_impl` receives and asserts that the two async-generator
  tools are present in it while the synchronous tool is absent.
  """

  # Tools under test. Their function names are what the assertions at the
  # bottom look up, so they must not be renamed.
  async def raw_streaming_tool(
      input_stream: LiveRequestQueue,
  ) -> AsyncGenerator[str, None]:
    """A raw streaming tool function."""
    yield "test"

  async def wrapped_streaming_tool(
      input_stream: LiveRequestQueue,
  ) -> AsyncGenerator[str, None]:
    """A streaming tool wrapped in FunctionTool."""
    yield "test"

  def non_streaming_tool(param: str) -> str:
    """A regular non-streaming tool."""
    return param

  # Populated by the agent below once its live implementation runs.
  seen = {}

  class StreamingToolsAgent(LlmAgent):

    async def _run_live_impl(
        self, invocation_context: InvocationContext
    ) -> AsyncGenerator[Event, None]:
      # Record what the runner handed us, then emit a single event so the
      # caller has something to await.
      seen["active_streaming_tools"] = invocation_context.active_streaming_tools
      reply = types.Content(
          role="model", parts=[types.Part(text="streaming test")]
      )
      yield Event(
          invocation_id=invocation_context.invocation_id,
          author=self.name,
          content=reply,
      )

  agent = StreamingToolsAgent(
      name="streaming_agent",
      model="gemini-2.0-flash",
      tools=[
          raw_streaming_tool,  # passed as a bare function
          FunctionTool(wrapped_streaming_tool),  # explicitly wrapped
          non_streaming_tool,  # synchronous; expected to be ignored
      ],
  )
  runner = Runner(
      app_name="streaming_test_app",
      agent=agent,
      session_service=InMemorySessionService(),
      artifact_service=InMemoryArtifactService(),
      auto_create_session=True,
  )

  live_events = runner.run_live(
      user_id="user",
      session_id="test_session",
      live_request_queue=LiveRequestQueue(),
  )
  # Pull exactly one event, then shut the generator down.
  first_event = await live_events.__anext__()
  await live_events.aclose()

  assert first_event.author == "streaming_agent"

  registered = seen.get("active_streaming_tools", {})
  for streaming_name in ("raw_streaming_tool", "wrapped_streaming_tool"):
    assert streaming_name in registered
  # The synchronous tool must not be treated as a streaming tool.
  assert "non_streaming_tool" not in registered
@pytest.mark.asyncio
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
project_root = tmp_path / "workspace"