From ede925b5025972cffcfaf178b2f81679fabbe90f Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Feb 2026 18:50:07 -0800 Subject: [PATCH] chore: Lazy register all streaming tools Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 869480442 --- src/google/adk/flows/llm_flows/functions.py | 18 +++- src/google/adk/runners.py | 37 ------- tests/unittests/streaming/test_streaming.py | 106 +++++++++++--------- tests/unittests/test_runners.py | 84 ---------------- 4 files changed, 73 insertions(+), 172 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 3d041a46..4c120b73 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -825,11 +825,21 @@ async def _process_function_live_helper( run_tool_and_update_queue(tool, function_args, tool_context) ) - # The tool is already registered in active_streaming_tools by - # runners.py at startup (all async-generator tools are registered - # there). Just attach the background task. async with streaming_lock: - invocation_context.active_streaming_tools[tool.name].task = task + + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + if tool.name in invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools[tool.name].task = task + else: + # Register the streaming tool lazily when the model calls it. + # For input-streaming tools (those with `input_stream: + # LiveRequestQueue`), _call_live will set .stream to a new + # LiveRequestQueue so _send_to_model starts duplicating data. + invocation_context.active_streaming_tools[tool.name] = ( + ActiveStreamingTool(task=task) + ) + logger.debug('Lazily registered streaming tool: %s', tool.name) # Immediately return a pending response. # This is required by current live model. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index e66ecb72..bc0251a8 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -32,7 +32,6 @@ from google.adk.apps.compaction import _run_compaction_for_sliding_window from google.adk.artifacts import artifact_util from google.genai import types -from .agents.active_streaming_tool import ActiveStreamingTool from .agents.base_agent import BaseAgent from .agents.base_agent import BaseAgentState from .agents.context_cache_config import ContextCacheConfig @@ -1012,42 +1011,6 @@ class Runner: root_agent = self.agent invocation_context.agent = self._find_agent_to_run(session, root_agent) - # Pre-processing for live streaming tools - # Inspect the tool's parameters to find if it uses LiveRequestQueue - invocation_context.active_streaming_tools = {} - # For shell agents, there is no canonical_tools method so we should skip. - if hasattr(invocation_context.agent, 'canonical_tools'): - import inspect - - # Use canonical_tools to get properly wrapped BaseTool instances - canonical_tools = await invocation_context.agent.canonical_tools( - invocation_context - ) - # Register all async-generator tools as streaming tools. - # A streaming tool is any tool whose underlying function is an - # async generator (i.e. uses `yield`). There are two sub-types: - # 1. Input-streaming tools: accept a `input_stream: - # LiveRequestQueue` parameter to consume the live audio/video - # stream. The stream is created lazily in `_call_live` when - # the model actually calls the tool. - # 2. Output-streaming tools: async generators that yield results - # over time but don't consume the live stream. They are run - # as background tasks when called. - # Both types are registered here with `stream=None`. The - # distinction between them is made at call time. - for tool in canonical_tools: - callable_to_inspect = tool.func if hasattr(tool, 'func') else tool - if not callable(callable_to_inspect): - continue - if inspect.isasyncgenfunction(callable_to_inspect): - if not invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools = {} - logger.debug('Register streaming tool: %s', tool.name) - active_streaming_tool = ActiveStreamingTool() - invocation_context.active_streaming_tools[tool.name] = ( - active_streaming_tool - ) - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_live(ctx)) as agen: async for event in agen: diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index 5dc0fb7c..8c54502e 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -1322,10 +1322,10 @@ class _LiveTestRunner(testing_utils.InMemoryRunner): return collected -def test_input_streaming_tool_stream_is_none_before_model_calls(): - """Test that input-streaming tools have stream=None until the model calls them.""" - # Add a text response before the function call so we can observe stream - # state between registration and tool invocation. +def test_input_streaming_tool_registered_lazily_with_stream(): + """Test that input-streaming tools are registered lazily when called and receive a stream.""" + # A text response before the function call lets us observe that the + # tool is NOT registered before the model calls it. text_response = LlmResponse( content=types.Content( role='model', @@ -1362,7 +1362,7 @@ def test_input_streaming_tool_stream_is_none_before_model_calls(): runner = _LiveTestRunner(root_agent=root_agent) - # Capture the invocation context to inspect stream state. + # Capture the invocation context to inspect registration state. captured_context = None original_method = runner.runner._new_invocation_context_for_live @@ -1379,39 +1379,39 @@ def test_input_streaming_tool_stream_is_none_before_model_calls(): blob=types.Blob(data=b'test_data', mime_type='audio/pcm') ) - # Collect events and capture stream state before the tool is called. + # Collect events and check that the tool is NOT registered before + # the model calls it. collected = [] - stream_states_before_call = [] + not_registered_before_call = None async def consume(session: testing_utils.Session): + nonlocal not_registered_before_call async for response in runner.runner.run_live( session=session, live_request_queue=live_request_queue, ): collected.append(response) - # On a non-function-call event, the tool is registered but not - # yet invoked — capture the stream value at that point. + # On the first non-function-call event, verify the tool is not + # yet registered (lazy registration). active = ( - captured_context.active_streaming_tools if captured_context else {} + captured_context.active_streaming_tools if captured_context else None ) if ( - not stream_states_before_call + not_registered_before_call is None and not response.get_function_calls() - and 'monitor_video_stream' in active ): - stream_states_before_call.append(active['monitor_video_stream'].stream) + not_registered_before_call = ( + active is None or 'monitor_video_stream' not in active + ) if len(collected) >= 4: return runner._run_with_loop(asyncio.wait_for(consume(runner.session), timeout=5.0)) - # Before the model calls the tool, stream should be None. + # Tool should not be registered before the model calls it. assert ( - stream_states_before_call - ), 'Stream state was never observed before the tool call' - assert ( - stream_states_before_call[0] is None - ), 'Expected stream to be None before the model calls the tool' + not_registered_before_call is True + ), 'Expected tool to NOT be registered before the model calls it' # When the model calls the tool, input_stream should be provided. assert ( stream_state_during_call is True @@ -1458,17 +1458,20 @@ def test_stop_streaming_resets_stream_to_none(): runner = _LiveTestRunner(root_agent=root_agent) - # Capture invocation context to verify stream is reset. - captured_context = None - original_method = runner.runner._new_invocation_context_for_live + # Capture the child invocation context (created by _create_invocation_context + # inside base_agent.run_live) to inspect active_streaming_tools. + # We cannot use the parent context from _new_invocation_context_for_live + # because model_copy creates a separate child object. + captured_child_context = None + original_create = root_agent._create_invocation_context - def capturing_method(*args, **kwargs): - nonlocal captured_context - ctx = original_method(*args, **kwargs) - captured_context = ctx + def capturing_create(*args, **kwargs): + nonlocal captured_child_context + ctx = original_create(*args, **kwargs) + captured_child_context = ctx return ctx - runner.runner._new_invocation_context_for_live = capturing_method + root_agent._create_invocation_context = capturing_create live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( @@ -1488,9 +1491,9 @@ def test_stop_streaming_resets_stream_to_none(): # Verify that stop_streaming reset the stream to None. assert ( - captured_context is not None - ), 'Expected invocation context to be captured' - active_tools = captured_context.active_streaming_tools or {} + captured_child_context is not None + ), 'Expected child invocation context to be captured' + active_tools = captured_child_context.active_streaming_tools or {} assert ( 'monitor_stock_price' in active_tools ), 'Expected monitor_stock_price in active_streaming_tools' @@ -1499,11 +1502,18 @@ def test_stop_streaming_resets_stream_to_none(): ), 'Expected stream to be reset to None after stop_streaming' -def test_output_streaming_tool_registered_at_startup(): - """Test that output-streaming tools (async generators without LiveRequestQueue) are registered at startup.""" - response1 = LlmResponse(turn_complete=True) +def test_output_streaming_tool_registered_lazily_without_stream(): + """Test that output-streaming tools are registered lazily when called, with stream=None.""" + function_call = types.Part.from_function_call( + name='monitor_stock_price', args={'stock_symbol': 'GOOG'} + ) + response1 = LlmResponse( + content=types.Content(role='model', parts=[function_call]), + turn_complete=False, + ) + response2 = LlmResponse(turn_complete=True) - mock_model = testing_utils.MockModel.create([response1]) + mock_model = testing_utils.MockModel.create([response1, response2]) async def monitor_stock_price(stock_symbol: str): """Yield periodic price updates.""" @@ -1517,31 +1527,33 @@ def test_output_streaming_tool_registered_at_startup(): runner = _LiveTestRunner(root_agent=root_agent) - # Capture invocation context to verify registration. - captured_context = None - original_method = runner.runner._new_invocation_context_for_live + # Capture the child invocation context (created by _create_invocation_context + # inside base_agent.run_live) to inspect active_streaming_tools. + captured_child_context = None + original_create = root_agent._create_invocation_context - def capturing_method(*args, **kwargs): - nonlocal captured_context - ctx = original_method(*args, **kwargs) - captured_context = ctx + def capturing_create(*args, **kwargs): + nonlocal captured_child_context + ctx = original_create(*args, **kwargs) + captured_child_context = ctx return ctx - runner.runner._new_invocation_context_for_live = capturing_method + root_agent._create_invocation_context = capturing_create live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob(data=b'test', mime_type='audio/pcm') ) - runner.run_live(live_request_queue, max_responses=1) + runner.run_live(live_request_queue, max_responses=3) - # Output-streaming tool should be registered with stream=None. - assert captured_context is not None - active_tools = captured_context.active_streaming_tools or {} + # After the model calls the tool, it should be registered with + # stream=None (output-streaming tools don't consume the live stream). + assert captured_child_context is not None + active_tools = captured_child_context.active_streaming_tools or {} assert ( 'monitor_stock_price' in active_tools - ), 'Expected output-streaming tool to be registered at startup' + ), 'Expected output-streaming tool to be registered when called' assert ( active_tools['monitor_stock_price'].stream is None ), 'Expected stream to be None for output-streaming tool' diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 62b8d733..ca7eb375 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -23,7 +23,6 @@ from unittest.mock import AsyncMock from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.invocation_context import InvocationContext -from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App @@ -35,7 +34,6 @@ from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session -from google.adk.tools.function_tool import FunctionTool from google.genai import types import pytest @@ -360,88 +358,6 @@ async def test_run_live_auto_create_session(): assert session is not None -@pytest.mark.asyncio -async def test_run_live_detects_streaming_tools_with_canonical_tools(): - """run_live should detect streaming tools using canonical_tools and tool.name.""" - - # Define streaming tools - one as raw function, one wrapped in FunctionTool - async def raw_streaming_tool( - input_stream: LiveRequestQueue, - ) -> AsyncGenerator[str, None]: - """A raw streaming tool function.""" - yield "test" - - async def wrapped_streaming_tool( - input_stream: LiveRequestQueue, - ) -> AsyncGenerator[str, None]: - """A streaming tool wrapped in FunctionTool.""" - yield "test" - - def non_streaming_tool(param: str) -> str: - """A regular non-streaming tool.""" - return param - - # Create a mock LlmAgent that yields an event and captures invocation context - captured_context = {} - - class StreamingToolsAgent(LlmAgent): - - async def _run_live_impl( - self, invocation_context: InvocationContext - ) -> AsyncGenerator[Event, None]: - # Capture the active_streaming_tools for verification - captured_context["active_streaming_tools"] = ( - invocation_context.active_streaming_tools - ) - yield Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - content=types.Content( - role="model", parts=[types.Part(text="streaming test")] - ), - ) - - agent = StreamingToolsAgent( - name="streaming_agent", - model="gemini-2.0-flash", - tools=[ - raw_streaming_tool, # Raw function - FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool - non_streaming_tool, # Non-streaming tool (should not be detected) - ], - ) - - session_service = InMemorySessionService() - artifact_service = InMemoryArtifactService() - runner = Runner( - app_name="streaming_test_app", - agent=agent, - session_service=session_service, - artifact_service=artifact_service, - auto_create_session=True, - ) - - live_queue = LiveRequestQueue() - - agen = runner.run_live( - user_id="user", - session_id="test_session", - live_request_queue=live_queue, - ) - - event = await agen.__anext__() - await agen.aclose() - - assert event.author == "streaming_agent" - - # Verify streaming tools were detected correctly - active_tools = captured_context.get("active_streaming_tools", {}) - assert "raw_streaming_tool" in active_tools - assert "wrapped_streaming_tool" in active_tools - # Non-streaming tool should not be detected - assert "non_streaming_tool" not in active_tools - - @pytest.mark.asyncio async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch): project_root = tmp_path / "workspace"