You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Lazy register all streaming tools
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 869480442
This commit is contained in:
committed by
Copybara-Service
parent
5269a6b1d6
commit
ede925b502
@@ -825,11 +825,21 @@ async def _process_function_live_helper(
|
||||
run_tool_and_update_queue(tool, function_args, tool_context)
|
||||
)
|
||||
|
||||
# The tool is already registered in active_streaming_tools by
|
||||
# runners.py at startup (all async-generator tools are registered
|
||||
# there). Just attach the background task.
|
||||
async with streaming_lock:
|
||||
invocation_context.active_streaming_tools[tool.name].task = task
|
||||
|
||||
if invocation_context.active_streaming_tools is None:
|
||||
invocation_context.active_streaming_tools = {}
|
||||
if tool.name in invocation_context.active_streaming_tools:
|
||||
invocation_context.active_streaming_tools[tool.name].task = task
|
||||
else:
|
||||
# Register the streaming tool lazily when the model calls it.
|
||||
# For input-streaming tools (those with `input_stream:
|
||||
# LiveRequestQueue`), _call_live will set .stream to a new
|
||||
# LiveRequestQueue so _send_to_model starts duplicating data.
|
||||
invocation_context.active_streaming_tools[tool.name] = (
|
||||
ActiveStreamingTool(task=task)
|
||||
)
|
||||
logger.debug('Lazily registered streaming tool: %s', tool.name)
|
||||
|
||||
# Immediately return a pending response.
|
||||
# This is required by current live model.
|
||||
|
||||
@@ -32,7 +32,6 @@ from google.adk.apps.compaction import _run_compaction_for_sliding_window
|
||||
from google.adk.artifacts import artifact_util
|
||||
from google.genai import types
|
||||
|
||||
from .agents.active_streaming_tool import ActiveStreamingTool
|
||||
from .agents.base_agent import BaseAgent
|
||||
from .agents.base_agent import BaseAgentState
|
||||
from .agents.context_cache_config import ContextCacheConfig
|
||||
@@ -1012,42 +1011,6 @@ class Runner:
|
||||
root_agent = self.agent
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
|
||||
# Pre-processing for live streaming tools
|
||||
# Inspect the tool's parameters to find if it uses LiveRequestQueue
|
||||
invocation_context.active_streaming_tools = {}
|
||||
# For shell agents, there is no canonical_tools method so we should skip.
|
||||
if hasattr(invocation_context.agent, 'canonical_tools'):
|
||||
import inspect
|
||||
|
||||
# Use canonical_tools to get properly wrapped BaseTool instances
|
||||
canonical_tools = await invocation_context.agent.canonical_tools(
|
||||
invocation_context
|
||||
)
|
||||
# Register all async-generator tools as streaming tools.
|
||||
# A streaming tool is any tool whose underlying function is an
|
||||
# async generator (i.e. uses `yield`). There are two sub-types:
|
||||
# 1. Input-streaming tools: accept a `input_stream:
|
||||
# LiveRequestQueue` parameter to consume the live audio/video
|
||||
# stream. The stream is created lazily in `_call_live` when
|
||||
# the model actually calls the tool.
|
||||
# 2. Output-streaming tools: async generators that yield results
|
||||
# over time but don't consume the live stream. They are run
|
||||
# as background tasks when called.
|
||||
# Both types are registered here with `stream=None`. The
|
||||
# distinction between them is made at call time.
|
||||
for tool in canonical_tools:
|
||||
callable_to_inspect = tool.func if hasattr(tool, 'func') else tool
|
||||
if not callable(callable_to_inspect):
|
||||
continue
|
||||
if inspect.isasyncgenfunction(callable_to_inspect):
|
||||
if not invocation_context.active_streaming_tools:
|
||||
invocation_context.active_streaming_tools = {}
|
||||
logger.debug('Register streaming tool: %s', tool.name)
|
||||
active_streaming_tool = ActiveStreamingTool()
|
||||
invocation_context.active_streaming_tools[tool.name] = (
|
||||
active_streaming_tool
|
||||
)
|
||||
|
||||
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
||||
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
|
||||
async for event in agen:
|
||||
|
||||
@@ -1322,10 +1322,10 @@ class _LiveTestRunner(testing_utils.InMemoryRunner):
|
||||
return collected
|
||||
|
||||
|
||||
def test_input_streaming_tool_stream_is_none_before_model_calls():
|
||||
"""Test that input-streaming tools have stream=None until the model calls them."""
|
||||
# Add a text response before the function call so we can observe stream
|
||||
# state between registration and tool invocation.
|
||||
def test_input_streaming_tool_registered_lazily_with_stream():
|
||||
"""Test that input-streaming tools are registered lazily when called and receive a stream."""
|
||||
# A text response before the function call lets us observe that the
|
||||
# tool is NOT registered before the model calls it.
|
||||
text_response = LlmResponse(
|
||||
content=types.Content(
|
||||
role='model',
|
||||
@@ -1362,7 +1362,7 @@ def test_input_streaming_tool_stream_is_none_before_model_calls():
|
||||
|
||||
runner = _LiveTestRunner(root_agent=root_agent)
|
||||
|
||||
# Capture the invocation context to inspect stream state.
|
||||
# Capture the invocation context to inspect registration state.
|
||||
captured_context = None
|
||||
original_method = runner.runner._new_invocation_context_for_live
|
||||
|
||||
@@ -1379,39 +1379,39 @@ def test_input_streaming_tool_stream_is_none_before_model_calls():
|
||||
blob=types.Blob(data=b'test_data', mime_type='audio/pcm')
|
||||
)
|
||||
|
||||
# Collect events and capture stream state before the tool is called.
|
||||
# Collect events and check that the tool is NOT registered before
|
||||
# the model calls it.
|
||||
collected = []
|
||||
stream_states_before_call = []
|
||||
not_registered_before_call = None
|
||||
|
||||
async def consume(session: testing_utils.Session):
|
||||
nonlocal not_registered_before_call
|
||||
async for response in runner.runner.run_live(
|
||||
session=session,
|
||||
live_request_queue=live_request_queue,
|
||||
):
|
||||
collected.append(response)
|
||||
# On a non-function-call event, the tool is registered but not
|
||||
# yet invoked — capture the stream value at that point.
|
||||
# On the first non-function-call event, verify the tool is not
|
||||
# yet registered (lazy registration).
|
||||
active = (
|
||||
captured_context.active_streaming_tools if captured_context else {}
|
||||
captured_context.active_streaming_tools if captured_context else None
|
||||
)
|
||||
if (
|
||||
not stream_states_before_call
|
||||
not_registered_before_call is None
|
||||
and not response.get_function_calls()
|
||||
and 'monitor_video_stream' in active
|
||||
):
|
||||
stream_states_before_call.append(active['monitor_video_stream'].stream)
|
||||
not_registered_before_call = (
|
||||
active is None or 'monitor_video_stream' not in active
|
||||
)
|
||||
if len(collected) >= 4:
|
||||
return
|
||||
|
||||
runner._run_with_loop(asyncio.wait_for(consume(runner.session), timeout=5.0))
|
||||
|
||||
# Before the model calls the tool, stream should be None.
|
||||
# Tool should not be registered before the model calls it.
|
||||
assert (
|
||||
stream_states_before_call
|
||||
), 'Stream state was never observed before the tool call'
|
||||
assert (
|
||||
stream_states_before_call[0] is None
|
||||
), 'Expected stream to be None before the model calls the tool'
|
||||
not_registered_before_call is True
|
||||
), 'Expected tool to NOT be registered before the model calls it'
|
||||
# When the model calls the tool, input_stream should be provided.
|
||||
assert (
|
||||
stream_state_during_call is True
|
||||
@@ -1458,17 +1458,20 @@ def test_stop_streaming_resets_stream_to_none():
|
||||
|
||||
runner = _LiveTestRunner(root_agent=root_agent)
|
||||
|
||||
# Capture invocation context to verify stream is reset.
|
||||
captured_context = None
|
||||
original_method = runner.runner._new_invocation_context_for_live
|
||||
# Capture the child invocation context (created by _create_invocation_context
|
||||
# inside base_agent.run_live) to inspect active_streaming_tools.
|
||||
# We cannot use the parent context from _new_invocation_context_for_live
|
||||
# because model_copy creates a separate child object.
|
||||
captured_child_context = None
|
||||
original_create = root_agent._create_invocation_context
|
||||
|
||||
def capturing_method(*args, **kwargs):
|
||||
nonlocal captured_context
|
||||
ctx = original_method(*args, **kwargs)
|
||||
captured_context = ctx
|
||||
def capturing_create(*args, **kwargs):
|
||||
nonlocal captured_child_context
|
||||
ctx = original_create(*args, **kwargs)
|
||||
captured_child_context = ctx
|
||||
return ctx
|
||||
|
||||
runner.runner._new_invocation_context_for_live = capturing_method
|
||||
root_agent._create_invocation_context = capturing_create
|
||||
|
||||
live_request_queue = LiveRequestQueue()
|
||||
live_request_queue.send_realtime(
|
||||
@@ -1488,9 +1491,9 @@ def test_stop_streaming_resets_stream_to_none():
|
||||
|
||||
# Verify that stop_streaming reset the stream to None.
|
||||
assert (
|
||||
captured_context is not None
|
||||
), 'Expected invocation context to be captured'
|
||||
active_tools = captured_context.active_streaming_tools or {}
|
||||
captured_child_context is not None
|
||||
), 'Expected child invocation context to be captured'
|
||||
active_tools = captured_child_context.active_streaming_tools or {}
|
||||
assert (
|
||||
'monitor_stock_price' in active_tools
|
||||
), 'Expected monitor_stock_price in active_streaming_tools'
|
||||
@@ -1499,11 +1502,18 @@ def test_stop_streaming_resets_stream_to_none():
|
||||
), 'Expected stream to be reset to None after stop_streaming'
|
||||
|
||||
|
||||
def test_output_streaming_tool_registered_at_startup():
|
||||
"""Test that output-streaming tools (async generators without LiveRequestQueue) are registered at startup."""
|
||||
response1 = LlmResponse(turn_complete=True)
|
||||
def test_output_streaming_tool_registered_lazily_without_stream():
|
||||
"""Test that output-streaming tools are registered lazily when called, with stream=None."""
|
||||
function_call = types.Part.from_function_call(
|
||||
name='monitor_stock_price', args={'stock_symbol': 'GOOG'}
|
||||
)
|
||||
response1 = LlmResponse(
|
||||
content=types.Content(role='model', parts=[function_call]),
|
||||
turn_complete=False,
|
||||
)
|
||||
response2 = LlmResponse(turn_complete=True)
|
||||
|
||||
mock_model = testing_utils.MockModel.create([response1])
|
||||
mock_model = testing_utils.MockModel.create([response1, response2])
|
||||
|
||||
async def monitor_stock_price(stock_symbol: str):
|
||||
"""Yield periodic price updates."""
|
||||
@@ -1517,31 +1527,33 @@ def test_output_streaming_tool_registered_at_startup():
|
||||
|
||||
runner = _LiveTestRunner(root_agent=root_agent)
|
||||
|
||||
# Capture invocation context to verify registration.
|
||||
captured_context = None
|
||||
original_method = runner.runner._new_invocation_context_for_live
|
||||
# Capture the child invocation context (created by _create_invocation_context
|
||||
# inside base_agent.run_live) to inspect active_streaming_tools.
|
||||
captured_child_context = None
|
||||
original_create = root_agent._create_invocation_context
|
||||
|
||||
def capturing_method(*args, **kwargs):
|
||||
nonlocal captured_context
|
||||
ctx = original_method(*args, **kwargs)
|
||||
captured_context = ctx
|
||||
def capturing_create(*args, **kwargs):
|
||||
nonlocal captured_child_context
|
||||
ctx = original_create(*args, **kwargs)
|
||||
captured_child_context = ctx
|
||||
return ctx
|
||||
|
||||
runner.runner._new_invocation_context_for_live = capturing_method
|
||||
root_agent._create_invocation_context = capturing_create
|
||||
|
||||
live_request_queue = LiveRequestQueue()
|
||||
live_request_queue.send_realtime(
|
||||
blob=types.Blob(data=b'test', mime_type='audio/pcm')
|
||||
)
|
||||
|
||||
runner.run_live(live_request_queue, max_responses=1)
|
||||
runner.run_live(live_request_queue, max_responses=3)
|
||||
|
||||
# Output-streaming tool should be registered with stream=None.
|
||||
assert captured_context is not None
|
||||
active_tools = captured_context.active_streaming_tools or {}
|
||||
# After the model calls the tool, it should be registered with
|
||||
# stream=None (output-streaming tools don't consume the live stream).
|
||||
assert captured_child_context is not None
|
||||
active_tools = captured_child_context.active_streaming_tools or {}
|
||||
assert (
|
||||
'monitor_stock_price' in active_tools
|
||||
), 'Expected output-streaming tool to be registered at startup'
|
||||
), 'Expected output-streaming tool to be registered when called'
|
||||
assert (
|
||||
active_tools['monitor_stock_price'].stream is None
|
||||
), 'Expected stream to be None for output-streaming tool'
|
||||
|
||||
@@ -23,7 +23,6 @@ from unittest.mock import AsyncMock
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.context_cache_config import ContextCacheConfig
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.live_request_queue import LiveRequestQueue
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.apps.app import App
|
||||
@@ -35,7 +34,6 @@ from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
@@ -360,88 +358,6 @@ async def test_run_live_auto_create_session():
|
||||
assert session is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_detects_streaming_tools_with_canonical_tools():
|
||||
"""run_live should detect streaming tools using canonical_tools and tool.name."""
|
||||
|
||||
# Define streaming tools - one as raw function, one wrapped in FunctionTool
|
||||
async def raw_streaming_tool(
|
||||
input_stream: LiveRequestQueue,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""A raw streaming tool function."""
|
||||
yield "test"
|
||||
|
||||
async def wrapped_streaming_tool(
|
||||
input_stream: LiveRequestQueue,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""A streaming tool wrapped in FunctionTool."""
|
||||
yield "test"
|
||||
|
||||
def non_streaming_tool(param: str) -> str:
|
||||
"""A regular non-streaming tool."""
|
||||
return param
|
||||
|
||||
# Create a mock LlmAgent that yields an event and captures invocation context
|
||||
captured_context = {}
|
||||
|
||||
class StreamingToolsAgent(LlmAgent):
|
||||
|
||||
async def _run_live_impl(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Capture the active_streaming_tools for verification
|
||||
captured_context["active_streaming_tools"] = (
|
||||
invocation_context.active_streaming_tools
|
||||
)
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(text="streaming test")]
|
||||
),
|
||||
)
|
||||
|
||||
agent = StreamingToolsAgent(
|
||||
name="streaming_agent",
|
||||
model="gemini-2.0-flash",
|
||||
tools=[
|
||||
raw_streaming_tool, # Raw function
|
||||
FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool
|
||||
non_streaming_tool, # Non-streaming tool (should not be detected)
|
||||
],
|
||||
)
|
||||
|
||||
session_service = InMemorySessionService()
|
||||
artifact_service = InMemoryArtifactService()
|
||||
runner = Runner(
|
||||
app_name="streaming_test_app",
|
||||
agent=agent,
|
||||
session_service=session_service,
|
||||
artifact_service=artifact_service,
|
||||
auto_create_session=True,
|
||||
)
|
||||
|
||||
live_queue = LiveRequestQueue()
|
||||
|
||||
agen = runner.run_live(
|
||||
user_id="user",
|
||||
session_id="test_session",
|
||||
live_request_queue=live_queue,
|
||||
)
|
||||
|
||||
event = await agen.__anext__()
|
||||
await agen.aclose()
|
||||
|
||||
assert event.author == "streaming_agent"
|
||||
|
||||
# Verify streaming tools were detected correctly
|
||||
active_tools = captured_context.get("active_streaming_tools", {})
|
||||
assert "raw_streaming_tool" in active_tools
|
||||
assert "wrapped_streaming_tool" in active_tools
|
||||
# Non-streaming tool should not be detected
|
||||
assert "non_streaming_tool" not in active_tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
|
||||
project_root = tmp_path / "workspace"
|
||||
|
||||
Reference in New Issue
Block a user