You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
2367901ec5
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 858763407
1326 lines
42 KiB
Python
1326 lines
42 KiB
Python
# Copyright 2026 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import importlib
|
|
from pathlib import Path
|
|
import sys
|
|
import textwrap
|
|
from typing import AsyncGenerator
|
|
from typing import Optional
|
|
from unittest.mock import AsyncMock
|
|
|
|
from google.adk.agents.base_agent import BaseAgent
|
|
from google.adk.agents.context_cache_config import ContextCacheConfig
|
|
from google.adk.agents.invocation_context import InvocationContext
|
|
from google.adk.agents.live_request_queue import LiveRequestQueue
|
|
from google.adk.agents.llm_agent import LlmAgent
|
|
from google.adk.agents.run_config import RunConfig
|
|
from google.adk.apps.app import App
|
|
from google.adk.apps.app import ResumabilityConfig
|
|
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
|
from google.adk.cli.utils.agent_loader import AgentLoader
|
|
from google.adk.events.event import Event
|
|
from google.adk.plugins.base_plugin import BasePlugin
|
|
from google.adk.runners import Runner
|
|
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
|
from google.adk.sessions.session import Session
|
|
from google.adk.tools.function_tool import FunctionTool
|
|
from google.genai import types
|
|
import pytest
|
|
|
|
# Identifiers shared by tests that create a session directly through the
# session service and then drive the runner against that same session.
TEST_APP_ID = "test_app"
TEST_USER_ID = "test_user"
TEST_SESSION_ID = "test_session"
|
|
|
|
|
|
class MockAgent(BaseAgent):
  """Minimal non-LLM agent that answers every run with one fixed text event.

  It derives from BaseAgent rather than LlmAgent, so it has no
  `disallow_transfer_to_parent` field. Tests rely on that to exercise
  non-LLM agents deliberately.
  """

  def __init__(
      self,
      name: str,
      parent_agent: Optional[BaseAgent] = None,
  ):
    super().__init__(name=name, sub_agents=[])
    # Only attach a parent when a truthy one was supplied.
    if not parent_agent:
      return
    self.parent_agent = parent_agent

  async def _run_async_impl(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    reply = types.Content(
        role="model", parts=[types.Part(text="Test response")]
    )
    yield Event(
        invocation_id=invocation_context.invocation_id,
        author=self.name,
        content=reply,
    )
|
|
|
|
|
|
class MockLiveAgent(BaseAgent):
  """Agent stub for `run_live` that yields a single fixed model event."""

  def __init__(self, name: str):
    super().__init__(name=name, sub_agents=[])

  async def _run_live_impl(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    greeting = types.Part(text="live hello")
    yield Event(
        invocation_id=invocation_context.invocation_id,
        author=self.name,
        content=types.Content(role="model", parts=[greeting]),
    )
|
|
|
|
|
|
class MockLlmAgent(LlmAgent):
  """LlmAgent whose async run yields a canned reply instead of calling a model.

  A model name string is still passed so that LlmAgent construction succeeds.
  The parent link and transfer flag are assigned after construction, so tests
  can wire agent trees by hand.
  """

  def __init__(
      self,
      name: str,
      disallow_transfer_to_parent: bool = False,
      parent_agent: Optional[BaseAgent] = None,
  ):
    super().__init__(name=name, model="gemini-1.5-pro", sub_agents=[])
    self.disallow_transfer_to_parent = disallow_transfer_to_parent
    self.parent_agent = parent_agent

  async def _run_async_impl(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    canned = types.Part(text="Test LLM response")
    yield Event(
        invocation_id=invocation_context.invocation_id,
        author=self.name,
        content=types.Content(role="model", parts=[canned]),
    )
|
|
|
|
|
|
class MockAgentWithMetadata(BaseAgent):
  """Agent stub whose single event carries its own `custom_metadata`."""

  def __init__(self, name: str):
    super().__init__(name=name, sub_agents=[])

  async def _run_async_impl(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    event_metadata = {"event_key": "event_value"}
    body = types.Content(
        role="model", parts=[types.Part(text="Test response")]
    )
    yield Event(
        invocation_id=invocation_context.invocation_id,
        author=self.name,
        content=body,
        custom_metadata=event_metadata,
    )
|
|
|
|
|
|
class MockPlugin(BasePlugin):
  """Plugin test double with switchable message- and event-rewriting hooks.

  Both rewriting callbacks are off by default. Tests flip the `enable_*`
  flags to activate them. `before_run_callback` always records the user
  content it observes, so tests can inspect it afterwards.
  """

  ON_USER_CALLBACK_MSG = (
      "Modified user message ON_USER_CALLBACK_MSG from MockPlugin"
  )
  ON_EVENT_CALLBACK_MSG = "Modified event ON_EVENT_CALLBACK_MSG from MockPlugin"

  def __init__(self):
    super().__init__(name="mock_plugin")
    self.enable_user_message_callback = False
    self.enable_event_callback = False
    self.user_content_seen_in_before_run_callback = None

  async def on_user_message_callback(
      self,
      *,
      invocation_context: InvocationContext,
      user_message: types.Content,
  ) -> Optional[types.Content]:
    # A None return means no replacement content is provided.
    if self.enable_user_message_callback:
      return types.Content(
          role="model",
          parts=[types.Part(text=self.ON_USER_CALLBACK_MSG)],
      )
    return None

  async def before_run_callback(
      self,
      *,
      invocation_context: InvocationContext,
  ) -> None:
    observed = invocation_context.user_content
    self.user_content_seen_in_before_run_callback = observed

  async def on_event_callback(
      self, *, invocation_context: InvocationContext, event: Event
  ) -> Optional[Event]:
    if self.enable_event_callback:
      # Replacement keeps the incoming event's role but swaps its text.
      replacement_part = types.Part(text=self.ON_EVENT_CALLBACK_MSG)
      return Event(
          invocation_id="",
          author="",
          content=types.Content(
              parts=[replacement_part],
              role=event.content.role,
          ),
      )
    return None
|
|
|
|
|
|
class TestRunnerFindAgentToRun:
  """Tests for Runner._find_agent_to_run and _is_transferable_across_agent_tree.

  All test methods live inside this class so that pytest collects them and
  `setup_method` provides `self.runner` and the agent tree they read.
  """

  def setup_method(self):
    """Set up test fixtures."""
    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()

    # The agent tree is wired by hand: each sub-agent receives its parent at
    # construction time, and the root's sub_agents list is assigned afterwards.
    self.root_agent = MockLlmAgent("root_agent")
    self.sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=self.root_agent)
    self.sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=self.root_agent)
    self.non_transferable_agent = MockLlmAgent(
        "non_transferable",
        disallow_transfer_to_parent=True,
        parent_agent=self.root_agent,
    )

    self.root_agent.sub_agents = [
        self.sub_agent1,
        self.sub_agent2,
        self.non_transferable_agent,
    ]

    self.runner = Runner(
        app_name="test_app",
        agent=self.root_agent,
        session_service=self.session_service,
        artifact_service=self.artifact_service,
    )

  def test_find_agent_to_run_with_function_response_scenario(self):
    """Test finding agent when last event is function response."""
    # Create a function call from sub_agent1
    function_call = types.FunctionCall(id="func_123", name="test_func", args={})
    function_response = types.FunctionResponse(
        id="func_123", name="test_func", response={}
    )

    call_event = Event(
        invocation_id="inv1",
        author="sub_agent1",
        content=types.Content(
            role="model", parts=[types.Part(function_call=function_call)]
        ),
    )

    response_event = Event(
        invocation_id="inv2",
        author="user",
        content=types.Content(
            role="user", parts=[types.Part(function_response=function_response)]
        ),
    )

    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[call_event, response_event],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.sub_agent1

  def test_find_agent_to_run_returns_root_agent_when_no_events(self):
    """Test that root agent is returned when session has no non-user events."""
    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[
            Event(
                invocation_id="inv1",
                author="user",
                content=types.Content(
                    role="user", parts=[types.Part(text="Hello")]
                ),
            )
        ],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.root_agent

  def test_find_agent_to_run_returns_root_agent_when_found_in_events(self):
    """Test that root agent is returned when it's found in session events."""
    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[
            Event(
                invocation_id="inv1",
                author="root_agent",
                content=types.Content(
                    role="model", parts=[types.Part(text="Root response")]
                ),
            )
        ],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.root_agent

  def test_find_agent_to_run_returns_transferable_sub_agent(self):
    """Test that transferable sub agent is returned when found."""
    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[
            Event(
                invocation_id="inv1",
                author="sub_agent1",
                content=types.Content(
                    role="model", parts=[types.Part(text="Sub agent response")]
                ),
            )
        ],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.sub_agent1

  def test_find_agent_to_run_skips_non_transferable_agent(self):
    """Test that non-transferable agent is skipped and root agent is returned."""
    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[
            Event(
                invocation_id="inv1",
                author="non_transferable",
                content=types.Content(
                    role="model",
                    parts=[types.Part(text="Non-transferable response")],
                ),
            )
        ],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.root_agent

  def test_find_agent_to_run_skips_unknown_agent(self):
    """Test that unknown agent is skipped and root agent is returned."""
    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[
            Event(
                invocation_id="inv1",
                author="unknown_agent",
                content=types.Content(
                    role="model",
                    parts=[types.Part(text="Unknown agent response")],
                ),
            ),
            Event(
                invocation_id="inv2",
                author="root_agent",
                content=types.Content(
                    role="model", parts=[types.Part(text="Root response")]
                ),
            ),
        ],
    )

    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.root_agent

  def test_find_agent_to_run_function_response_takes_precedence(self):
    """Test that function response scenario takes precedence over other logic."""
    # Create a function call from sub_agent2
    function_call = types.FunctionCall(id="func_456", name="test_func", args={})
    function_response = types.FunctionResponse(
        id="func_456", name="test_func", response={}
    )

    call_event = Event(
        invocation_id="inv1",
        author="sub_agent2",
        content=types.Content(
            role="model", parts=[types.Part(function_call=function_call)]
        ),
    )

    # Add another event from root_agent between the call and its response.
    root_event = Event(
        invocation_id="inv2",
        author="root_agent",
        content=types.Content(
            role="model", parts=[types.Part(text="Root response")]
        ),
    )

    response_event = Event(
        invocation_id="inv3",
        author="user",
        content=types.Content(
            role="user", parts=[types.Part(function_response=function_response)]
        ),
    )

    session = Session(
        id="test_session",
        user_id="test_user",
        app_name="test_app",
        events=[call_event, root_event, response_event],
    )

    # Should return sub_agent2 due to function response, not root_agent
    result = self.runner._find_agent_to_run(session, self.root_agent)
    assert result == self.sub_agent2

  def test_is_transferable_across_agent_tree_with_llm_agent(self):
    """Test _is_transferable_across_agent_tree with LLM agent."""
    result = self.runner._is_transferable_across_agent_tree(self.sub_agent1)
    assert result is True

  def test_is_transferable_across_agent_tree_with_non_transferable_agent(self):
    """Test _is_transferable_across_agent_tree with non-transferable agent."""
    result = self.runner._is_transferable_across_agent_tree(
        self.non_transferable_agent
    )
    assert result is False

  def test_is_transferable_across_agent_tree_with_non_llm_agent(self):
    """Test _is_transferable_across_agent_tree with non-LLM agent."""
    non_llm_agent = MockAgent("non_llm_agent")
    # MockAgent inherits from BaseAgent, not LlmAgent, so it should return False
    result = self.runner._is_transferable_across_agent_tree(non_llm_agent)
    assert result is False


@pytest.mark.asyncio
async def test_session_not_found_message_includes_alignment_hint():
  """A missing session raises ValueError that names both candidate app names."""

  class RunnerWithMismatch(Runner):

    def _infer_agent_origin(
        self, agent: BaseAgent
    ) -> tuple[Optional[str], Optional[Path]]:
      del agent
      return "expected_app", Path("/workspace/agents/expected_app")

  session_service = InMemorySessionService()
  runner = RunnerWithMismatch(
      app_name="configured_app",
      agent=MockLlmAgent("root_agent"),
      session_service=session_service,
      artifact_service=InMemoryArtifactService(),
  )

  agen = runner.run_async(
      user_id="user",
      session_id="missing",
      new_message=types.Content(role="user", parts=[]),
  )

  with pytest.raises(ValueError) as excinfo:
    await agen.__anext__()

  await agen.aclose()

  message = str(excinfo.value)
  assert "Session not found" in message
  assert "configured_app" in message
  assert "expected_app" in message
  assert "Ensure the runner app_name matches" in message


@pytest.mark.asyncio
async def test_session_auto_creation():
  """With auto_create_session=True, run_async creates a missing session."""

  # Origin inference is stubbed so that it agrees with the configured app_name.
  class RunnerWithMismatch(Runner):

    def _infer_agent_origin(
        self, agent: BaseAgent
    ) -> tuple[Optional[str], Optional[Path]]:
      del agent
      return "expected_app", Path("/workspace/agents/expected_app")

  session_service = InMemorySessionService()
  runner = RunnerWithMismatch(
      app_name="expected_app",
      agent=MockLlmAgent("test_agent"),
      session_service=session_service,
      artifact_service=InMemoryArtifactService(),
      auto_create_session=True,
  )

  agen = runner.run_async(
      user_id="user",
      session_id="missing",
      new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
  )

  event = await agen.__anext__()
  await agen.aclose()

  # Verify that session_id="missing" doesn't error out - session is auto-created
  assert event.author == "test_agent"
  assert event.content.parts[0].text == "Test LLM response"


@pytest.mark.asyncio
async def test_rewind_auto_create_session_on_missing_session():
  """When auto_create_session=True, rewind should create session if missing.

  The newly created session won't contain the target invocation, so
  `rewind_async` should raise an Invocation ID not found error (rather than
  a session not found error), demonstrating auto-creation occurred.
  """
  session_service = InMemorySessionService()
  runner = Runner(
      app_name="auto_create_app",
      agent=MockLlmAgent("agent_for_rewind"),
      session_service=session_service,
      artifact_service=InMemoryArtifactService(),
      auto_create_session=True,
  )

  with pytest.raises(ValueError, match=r"Invocation ID not found: inv_missing"):
    await runner.rewind_async(
        user_id="user",
        session_id="missing",
        rewind_before_invocation_id="inv_missing",
    )

  # Verify the session actually exists now due to auto-creation.
  session = await session_service.get_session(
      app_name="auto_create_app", user_id="user", session_id="missing"
  )
  assert session is not None
  assert session.app_name == "auto_create_app"


@pytest.mark.asyncio
async def test_run_live_auto_create_session():
  """run_live should auto-create session when missing and yield events."""
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name="live_app",
      agent=MockLiveAgent("live_agent"),
      session_service=session_service,
      artifact_service=artifact_service,
      auto_create_session=True,
  )

  # An empty LiveRequestQueue is sufficient for our mock agent.
  live_queue = LiveRequestQueue()

  agen = runner.run_live(
      user_id="user",
      session_id="missing",
      live_request_queue=live_queue,
  )

  event = await agen.__anext__()
  await agen.aclose()

  assert event.author == "live_agent"
  assert event.content.parts[0].text == "live hello"

  # Session should have been created automatically.
  session = await session_service.get_session(
      app_name="live_app", user_id="user", session_id="missing"
  )
  assert session is not None


@pytest.mark.asyncio
async def test_run_live_detects_streaming_tools_with_canonical_tools():
  """run_live should detect streaming tools using canonical_tools and tool.name."""

  # Define streaming tools - one as raw function, one wrapped in FunctionTool
  async def raw_streaming_tool(
      input_stream: LiveRequestQueue,
  ) -> AsyncGenerator[str, None]:
    """A raw streaming tool function."""
    yield "test"

  async def wrapped_streaming_tool(
      input_stream: LiveRequestQueue,
  ) -> AsyncGenerator[str, None]:
    """A streaming tool wrapped in FunctionTool."""
    yield "test"

  def non_streaming_tool(param: str) -> str:
    """A regular non-streaming tool."""
    return param

  # Create a mock LlmAgent that yields an event and captures invocation context
  captured_context = {}

  class StreamingToolsAgent(LlmAgent):

    async def _run_live_impl(
        self, invocation_context: InvocationContext
    ) -> AsyncGenerator[Event, None]:
      # Capture the active_streaming_tools for verification
      captured_context["active_streaming_tools"] = (
          invocation_context.active_streaming_tools
      )
      yield Event(
          invocation_id=invocation_context.invocation_id,
          author=self.name,
          content=types.Content(
              role="model", parts=[types.Part(text="streaming test")]
          ),
      )

  agent = StreamingToolsAgent(
      name="streaming_agent",
      model="gemini-2.0-flash",
      tools=[
          raw_streaming_tool,  # Raw function
          FunctionTool(wrapped_streaming_tool),  # Wrapped in FunctionTool
          non_streaming_tool,  # Non-streaming tool (should not be detected)
      ],
  )

  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name="streaming_test_app",
      agent=agent,
      session_service=session_service,
      artifact_service=artifact_service,
      auto_create_session=True,
  )

  live_queue = LiveRequestQueue()

  agen = runner.run_live(
      user_id="user",
      session_id="test_session",
      live_request_queue=live_queue,
  )

  event = await agen.__anext__()
  await agen.aclose()

  assert event.author == "streaming_agent"

  # Verify streaming tools were detected correctly
  active_tools = captured_context.get("active_streaming_tools", {})
  assert "raw_streaming_tool" in active_tools
  assert "wrapped_streaming_tool" in active_tools
  # Non-streaming tool should not be detected
  assert "non_streaming_tool" not in active_tools


@pytest.mark.asyncio
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
  """An agent loaded from a nested package runs without an app_name hint."""
  project_root = tmp_path / "workspace"
  agent_dir = project_root / "agents" / "examples" / "001_hello_world"
  agent_dir.mkdir(parents=True)
  # Make package structure importable.
  for pkg_dir in [
      project_root / "agents",
      project_root / "agents" / "examples",
      agent_dir,
  ]:
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")
  # Extra directories that previously confused origin inference, e.g. virtualenv.
  (project_root / "agents" / ".venv").mkdir()

  agent_source = textwrap.dedent("""\
      from google.adk.events.event import Event
      from google.adk.agents.base_agent import BaseAgent
      from google.genai import types


      class SimpleAgent(BaseAgent):

        def __init__(self):
          super().__init__(name='simplest_agent', sub_agents=[])

        async def _run_async_impl(self, invocation_context):
          yield Event(
              invocation_id=invocation_context.invocation_id,
              author=self.name,
              content=types.Content(
                  role='model',
                  parts=[types.Part(text='hello from nested')],
              ),
          )


      root_agent = SimpleAgent()
      """)
  (agent_dir / "agent.py").write_text(agent_source, encoding="utf-8")

  monkeypatch.chdir(project_root)
  loader = AgentLoader(agents_dir="agents/examples")
  loaded_agent = loader.load_agent("001_hello_world")

  assert isinstance(loaded_agent, BaseAgent)
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name="001_hello_world",
      agent=loaded_agent,
      session_service=session_service,
      artifact_service=artifact_service,
  )
  assert runner._app_name_alignment_hint is None

  session = await session_service.create_session(
      app_name="001_hello_world",
      user_id="user",
  )
  agen = runner.run_async(
      user_id=session.user_id,
      session_id=session.id,
      new_message=types.Content(
          role="user",
          parts=[types.Part(text="hi")],
      ),
  )
  event = await agen.__anext__()
  await agen.aclose()

  assert event.author == "simplest_agent"
  assert event.content
  assert event.content.parts
  assert event.content.parts[0].text == "hello from nested"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_config_custom_metadata_propagates_to_events():
  """RunConfig metadata is attached to agent events and the stored user event."""
  sessions = InMemorySessionService()
  runner = Runner(
      app_name=TEST_APP_ID,
      agent=MockAgentWithMetadata("metadata_agent"),
      session_service=sessions,
      artifact_service=InMemoryArtifactService(),
  )
  session_key = dict(
      app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
  )
  await sessions.create_session(**session_key)

  collected = []
  async for emitted in runner.run_async(
      user_id=TEST_USER_ID,
      session_id=TEST_SESSION_ID,
      new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
      run_config=RunConfig(custom_metadata={"request_id": "req-1"}),
  ):
    collected.append(emitted)

  first_metadata = collected[0].custom_metadata
  assert first_metadata is not None
  # The run-level key and the agent's own key both end up on the event.
  assert first_metadata["request_id"] == "req-1"
  assert first_metadata["event_key"] == "event_value"

  stored = await sessions.get_session(**session_key)
  user_event = next(e for e in stored.events if e.author == "user")
  assert user_event.custom_metadata == {"request_id": "req-1"}
|
|
|
|
|
|
class TestRunnerWithPlugins:
  """Runner behaviour when a MockPlugin is registered via `plugins=`."""

  def setup_method(self):
    self.plugin = MockPlugin()
    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()
    # Used by the constructor-validation tests; the runner below is built
    # with a separate agent.
    self.root_agent = MockLlmAgent("root_agent")
    self.runner = Runner(
        app_name="test_app",
        agent=MockLlmAgent("test_agent"),
        session_service=self.session_service,
        artifact_service=self.artifact_service,
        plugins=[self.plugin],
    )

  async def run_test(self, original_user_input="Hello") -> list[Event]:
    """Creates the shared session, runs the runner once, returns its events."""
    await self.session_service.create_session(
        app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
    )
    user_message = types.Content(
        role="user", parts=[types.Part(text=original_user_input)]
    )
    return [
        emitted
        async for emitted in self.runner.run_async(
            user_id=TEST_USER_ID,
            session_id=TEST_SESSION_ID,
            new_message=user_message,
        )
    ]

  @pytest.mark.asyncio
  async def test_runner_is_initialized_with_plugins(self):
    """After a run, the runner exposes a plugin manager."""
    await self.run_test()
    assert self.runner.plugin_manager is not None

  @pytest.mark.asyncio
  async def test_runner_modifies_user_message_before_execution(self):
    """The user-message hook's output replaces the stored user message."""
    self.plugin.enable_user_message_callback = True
    await self.run_test(original_user_input="original_input")

    session = await self.session_service.get_session(
        app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
    )
    stored_text = session.events[0].content.parts[0].text
    assert stored_text == MockPlugin.ON_USER_CALLBACK_MSG

    # before_run_callback must observe the rewritten message.
    observed = self.plugin.user_content_seen_in_before_run_callback
    assert observed is not None
    assert observed.parts[0].text == MockPlugin.ON_USER_CALLBACK_MSG

  @pytest.mark.asyncio
  async def test_runner_modifies_event_after_execution(self):
    """The event hook's output replaces the event yielded to the caller."""
    self.plugin.enable_event_callback = True
    produced = await self.run_test()
    first_text = produced[0].content.parts[0].text
    assert first_text == MockPlugin.ON_EVENT_CALLBACK_MSG

  @pytest.mark.asyncio
  async def test_runner_close_calls_plugin_close(self):
    """Runner.close() awaits the plugin manager's close() exactly once."""
    close_spy = AsyncMock()
    self.runner.plugin_manager.close = close_spy
    await self.runner.close()
    close_spy.assert_awaited_once()

  @pytest.mark.asyncio
  async def test_runner_passes_plugin_close_timeout(self):
    """plugin_close_timeout is forwarded to the PluginManager."""
    timed_runner = Runner(
        app_name="test_app",
        agent=MockLlmAgent("test_agent"),
        session_service=self.session_service,
        artifact_service=self.artifact_service,
        plugins=[self.plugin],
        plugin_close_timeout=10.0,
    )
    assert timed_runner.plugin_manager._close_timeout == 10.0

  @pytest.mark.filterwarnings(
      "ignore:The `plugins` argument is deprecated:DeprecationWarning"
  )
  def test_runner_init_raises_error_with_app_and_agent(self):
    """Supplying both `app` and `agent` is rejected with ValueError."""
    with pytest.raises(
        ValueError,
        match="When app is provided, agent should not be provided.",
    ):
      Runner(
          app=App(name="test_app", root_agent=self.root_agent),
          agent=self.root_agent,
          session_service=self.session_service,
          artifact_service=self.artifact_service,
      )

  @pytest.mark.filterwarnings(
      "ignore:The `plugins` argument is deprecated:DeprecationWarning"
  )
  def test_runner_init_allows_app_name_override_with_app(self):
    """An explicit app_name takes precedence over app.name."""
    source_app = App(name="test_app", root_agent=self.root_agent)
    overridden = Runner(
        app=source_app,
        app_name="override_name",
        session_service=self.session_service,
        artifact_service=self.artifact_service,
    )
    assert overridden.app_name == "override_name"
    assert overridden.agent == self.root_agent
    assert overridden.app == source_app

  def test_runner_init_raises_error_without_app_and_app_name(self):
    """Omitting both `app` and `app_name` is rejected with ValueError."""
    with pytest.raises(
        ValueError,
        match="Either app or both app_name and agent must be provided.",
    ):
      Runner(
          agent=self.root_agent,
          session_service=self.session_service,
          artifact_service=self.artifact_service,
      )

  def test_runner_init_raises_error_without_app_and_agent(self):
    """Omitting both `app` and `agent` is rejected with ValueError."""
    with pytest.raises(
        ValueError,
        match="Either app or both app_name and agent must be provided.",
    ):
      Runner(
          app_name="test_app",
          session_service=self.session_service,
          artifact_service=self.artifact_service,
      )
|
|
|
|
|
|
class TestRunnerCacheConfig:
|
|
"""Tests for Runner cache config extraction and handling."""
|
|
|
|
def setup_method(self):
  """Build a fresh root LLM agent plus in-memory services for each test."""
  self.root_agent = MockLlmAgent("root_agent")
  self.session_service = InMemorySessionService()
  self.artifact_service = InMemoryArtifactService()
|
|
|
|
def test_runner_extracts_cache_config_from_app(self):
  """A Runner built from an App adopts the App's context cache config."""
  expected = ContextCacheConfig(
      cache_intervals=15, ttl_seconds=3600, min_tokens=1024
  )
  runner = Runner(
      app=App(
          name="test_app",
          root_agent=self.root_agent,
          context_cache_config=expected,
      ),
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  adopted = runner.context_cache_config
  assert adopted == expected
  assert adopted.cache_intervals == 15
  assert adopted.ttl_seconds == 3600
  assert adopted.min_tokens == 1024
|
|
|
|
def test_runner_with_app_without_cache_config(self):
  """An App with no context cache config yields a Runner with none."""
  cacheless_app = App(
      name="test_app", root_agent=self.root_agent, context_cache_config=None
  )
  runner = Runner(
      app=cacheless_app,
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )
  assert runner.context_cache_config is None
|
|
|
|
def test_runner_without_app_has_no_cache_config(self):
  """A Runner built from app_name + agent has no context cache config."""
  plain_runner = Runner(
      app_name="test_app",
      agent=self.root_agent,
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )
  assert plain_runner.context_cache_config is None
|
|
|
|
def test_runner_cache_config_passed_to_invocation_context(self):
  """Invocation contexts created by the Runner carry the App's cache config."""
  cache_config = ContextCacheConfig(
      cache_intervals=20, ttl_seconds=7200, min_tokens=2048
  )
  runner = Runner(
      app=App(
          name="test_app",
          root_agent=self.root_agent,
          context_cache_config=cache_config,
      ),
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  # A Session built directly (not via the service) is enough for a context.
  empty_session = Session(
      id=TEST_SESSION_ID,
      app_name=TEST_APP_ID,
      user_id=TEST_USER_ID,
      events=[],
  )
  ctx = runner._new_invocation_context(empty_session)

  assert ctx.context_cache_config == cache_config
  assert ctx.context_cache_config.cache_intervals == 20
|
|
|
|
def test_runner_validate_params_return_order(self):
  """_validate_runner_params yields its five values in the expected order."""
  cache_config = ContextCacheConfig(cache_intervals=25)
  app = App(
      name="order_test_app",
      root_agent=self.root_agent,
      context_cache_config=cache_config,
      resumability_config=ResumabilityConfig(is_resumable=True),
  )
  runner = Runner(
      app=app,
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  # Invoke the validator directly, passing only the App.
  validated = runner._validate_runner_params(app, None, None, None)
  (
      returned_name,
      returned_agent,
      returned_cache_config,
      returned_resumability,
      returned_plugins,
  ) = validated

  assert returned_name == "order_test_app"
  assert returned_agent == self.root_agent
  assert returned_cache_config == cache_config
  assert returned_cache_config.cache_intervals == 25
  assert returned_resumability == app.resumability_config
  assert returned_plugins == []
|
|
|
|
def test_runner_validate_params_without_app(self):
  """Without an App, _validate_runner_params reports no cache/resume config."""
  runner = Runner(
      app_name="test_app",
      agent=self.root_agent,
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  validated = runner._validate_runner_params(
      None, "test_app", self.root_agent, None
  )
  (
      returned_name,
      returned_agent,
      returned_cache_config,
      returned_resumability,
      returned_plugins,
  ) = validated

  assert returned_name == "test_app"
  assert returned_agent == self.root_agent
  assert returned_cache_config is None
  assert returned_resumability is None
  assert returned_plugins is None
|
|
|
|
def test_runner_app_name_and_agent_extracted_correctly(self):
  """Runner copies name, root agent and cache config off the App it wraps."""
  default_cache_config = ContextCacheConfig()
  source_app = App(
      name="extracted_app",
      root_agent=self.root_agent,
      context_cache_config=default_cache_config,
  )

  runner = Runner(
      app=source_app,
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  assert runner.app_name == "extracted_app"
  assert runner.agent == self.root_agent
  assert runner.context_cache_config == default_cache_config
|
|
|
|
def test_runner_realistic_cache_config_scenario(self):
  """A production-like cache config is preserved unchanged on the Runner."""
  production_cache_config = ContextCacheConfig(
      cache_intervals=30,
      ttl_seconds=14400,  # 4 hours
      min_tokens=4096,
  )
  runner = Runner(
      app=App(
          name="production_app",
          root_agent=self.root_agent,
          context_cache_config=production_cache_config,
      ),
      session_service=self.session_service,
      artifact_service=self.artifact_service,
  )

  cfg = runner.context_cache_config

  # Each configured field must survive the trip through App -> Runner.
  assert cfg.cache_intervals == 30
  assert cfg.ttl_seconds == 14400
  assert cfg.ttl_string == "14400s"
  assert cfg.min_tokens == 4096

  # The string form summarizes all three settings.
  assert str(cfg) == (
      "ContextCacheConfig(cache_intervals=30, ttl=14400s, min_tokens=4096)"
  )
|
|
|
|
|
|
class TestRunnerShouldAppendEvent:
  """Covers Runner._should_append_event for live and non-live invocations."""

  def setup_method(self):
    """Create a Runner backed by fresh in-memory services."""
    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()
    self.root_agent = MockLlmAgent("root_agent")
    self.runner = Runner(
        app_name="test_app",
        agent=self.root_agent,
        session_service=self.session_service,
        artifact_service=self.artifact_service,
    )

  @staticmethod
  def _model_audio_event():
    """Build a model-authored event whose single part is inline PCM audio."""
    audio_blob = types.Blob(data=b"123", mime_type="audio/pcm")
    return Event(
        invocation_id="inv1",
        author="model",
        content=types.Content(parts=[types.Part(inline_data=audio_blob)]),
    )

  def test_should_append_event_finished_input_transcription(self):
    transcription = types.Transcription(text="hello", finished=True)
    event = Event(
        invocation_id="inv1", author="user", input_transcription=transcription
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_unfinished_input_transcription(self):
    transcription = types.Transcription(text="hello", finished=False)
    event = Event(
        invocation_id="inv1", author="user", input_transcription=transcription
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_finished_output_transcription(self):
    transcription = types.Transcription(text="world", finished=True)
    event = Event(
        invocation_id="inv1",
        author="model",
        output_transcription=transcription,
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_append_event_unfinished_output_transcription(self):
    transcription = types.Transcription(text="world", finished=False)
    event = Event(
        invocation_id="inv1",
        author="model",
        output_transcription=transcription,
    )
    assert self.runner._should_append_event(event, is_live_call=True) is True

  def test_should_not_append_event_live_model_audio(self):
    event = self._model_audio_event()
    assert self.runner._should_append_event(event, is_live_call=True) is False

  def test_should_append_event_non_live_model_audio(self):
    event = self._model_audio_event()
    assert self.runner._should_append_event(event, is_live_call=False) is True

  def test_should_append_event_other_event(self):
    text_content = types.Content(parts=[types.Part(text="text")])
    event = Event(invocation_id="inv1", author="model", content=text_content)
    assert self.runner._should_append_event(event, is_live_call=True) is True
|
|
|
|
|
|
@pytest.fixture
def user_agent_module(tmp_path, monkeypatch):
  """Fixture that creates temporary user agent packages for testing.

  Yields a factory ``_create_agent(agent_dir_name)``. Each call writes
  ``agents/<agent_dir_name>/agent.py`` under ``tmp_path``. That file defines an
  ``LlmAgent`` subclass and a ``root_agent`` named after the directory. The
  factory imports the file as ``agents.<agent_dir_name>.agent`` and returns
  its ``root_agent``.

  The working directory is switched to ``tmp_path`` (restored by
  ``monkeypatch``), and ``tmp_path`` is prepended to ``sys.path`` for the
  duration of the test. Any ``agents`` package modules are purged from
  ``sys.modules`` both before and after the test, so imports always resolve
  to the freshly written files.
  """
  agents_root = tmp_path / "agents"
  sys_path_entry = str(tmp_path)

  def _purge_agents_modules():
    # Match only the "agents" package and its submodules. A bare
    # startswith("agents") would also evict unrelated modules such as
    # "agents_utils".
    for mod_name in list(sys.modules):
      if mod_name == "agents" or mod_name.startswith("agents."):
        del sys.modules[mod_name]

  # A previously imported "agents" package would otherwise shadow the one
  # written into tmp_path below.
  _purge_agents_modules()
  monkeypatch.chdir(tmp_path)
  sys.path.insert(0, sys_path_entry)

  def _create_agent(agent_dir_name: str):
    agent_dir = agents_root / agent_dir_name
    agent_dir.mkdir(parents=True, exist_ok=True)
    (agents_root / "__init__.py").write_text("", encoding="utf-8")
    (agent_dir / "__init__.py").write_text("", encoding="utf-8")

    agent_source = textwrap.dedent(f"""\
        from google.adk.agents.llm_agent import LlmAgent


        class MyAgent(LlmAgent):
          pass


        root_agent = MyAgent(name="{agent_dir_name}", model="gemini-2.0-flash")
        """)
    (agent_dir / "agent.py").write_text(agent_source, encoding="utf-8")

    # These source files were created at runtime. The path finders may
    # already hold cached directory listings, so the caches must be
    # invalidated for the new modules to be found.
    importlib.invalidate_caches()
    module = importlib.import_module(f"agents.{agent_dir_name}.agent")
    return module.root_agent

  yield _create_agent

  # Cleanup
  if sys_path_entry in sys.path:
    sys.path.remove(sys_path_entry)
  _purge_agents_modules()
|
|
|
|
|
|
class TestRunnerInferAgentOrigin:
  """Covers Runner._infer_agent_origin and the resulting app-name hint."""

  def setup_method(self):
    """Create fresh in-memory services for each test."""
    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()

  def _make_runner(self, app_name, agent):
    """Build a Runner for ``agent`` using this test's in-memory services."""
    return Runner(
        app_name=app_name,
        agent=agent,
        session_service=self.session_service,
        artifact_service=self.artifact_service,
    )

  def test_infer_agent_origin_uses_adk_metadata_when_available(self):
    """Origin comes from the _adk_origin_* attributes when they are set."""
    agent = MockLlmAgent("test_agent")
    # Mimic the origin attributes that AgentLoader attaches to agents.
    agent._adk_origin_app_name = "my_app"
    agent._adk_origin_path = Path("/workspace/agents/my_app")

    runner = self._make_runner("my_app", agent)

    inferred_name, inferred_path = runner._infer_agent_origin(agent)
    assert inferred_name == "my_app"
    assert inferred_path == Path("/workspace/agents/my_app")

  def test_infer_agent_origin_no_false_positive_for_direct_llm_agent(self):
    """A directly instantiated LlmAgent must not produce a mismatch hint.

    Regression test for GitHub issue #3143. Users who instantiated LlmAgent
    directly and ran from a directory that is a parent of the ADK
    installation received spurious 'App name mismatch' warnings.

    It also checks that _infer_agent_origin yields no name for classes that
    live in ADK's own modules (google.adk.*).
    """
    agent = LlmAgent(name="my_custom_agent", model="gemini-2.0-flash")

    runner = self._make_runner("my_custom_agent", agent)

    # ADK-internal classes must not be treated as user code.
    inferred_name, _ = runner._infer_agent_origin(agent)
    assert inferred_name is None
    # ...and therefore no alignment hint is produced.
    assert runner._app_name_alignment_hint is None

  def test_infer_agent_origin_with_subclassed_agent_in_user_code(
      self, user_agent_module
  ):
    """Subclasses defined in user code still get their origin inferred."""
    agent = user_agent_module("my_agent")

    runner = self._make_runner("my_agent", agent)

    # The origin name is derived from the user's agent package.
    inferred_name, _ = runner._infer_agent_origin(agent)
    assert inferred_name == "my_agent"
    assert runner._app_name_alignment_hint is None

  def test_infer_agent_origin_detects_mismatch_for_user_agent(
      self, user_agent_module
  ):
    """A wrong app_name for a user-defined agent yields an alignment hint."""
    agent = user_agent_module("actual_name")

    # The app_name deliberately disagrees with the agent's package name.
    runner = self._make_runner("wrong_name", agent)

    hint = runner._app_name_alignment_hint
    assert hint is not None
    assert "wrong_name" in hint
    assert "actual_name" in hint
|
|
|
|
|
|
if __name__ == "__main__":
  # pytest.main returns the exit code rather than raising SystemExit, so
  # forward it. Otherwise a direct run reports success even when tests fail.
  sys.exit(pytest.main([__file__]))
|