You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Pause invocations on long running function calls for resumable apps
PiperOrigin-RevId: 811518771
This commit is contained in:
committed by
Copybara-Service
parent
dd1ffad394
commit
b2b80e7fa0
@@ -260,6 +260,45 @@ class InvocationContext(BaseModel):
|
||||
results = [event for event in results if event.branch == self.branch]
|
||||
return results
|
||||
|
||||
def should_pause_invocation(self, event: Event) -> bool:
|
||||
"""Returns whether to pause the invocation right after this event.
|
||||
|
||||
"Pausing" an invocation is different from "ending" an invocation. A paused
|
||||
invocation can be resumed later, while an ended invocation cannot.
|
||||
|
||||
Pausing the current agent's run will also pause all the agents that
|
||||
depend on its execution, i.e. the subsequent agents in a workflow, and the
|
||||
current agent's ancestors, etc.
|
||||
|
||||
Note that parallel sibling agents won't be affected, but their common
|
||||
ancestors will be paused after all the non-blocking sub-agents finished
|
||||
running.
|
||||
|
||||
Should meet all following conditions to pause an invocation:
|
||||
1. The app is resumable.
|
||||
2. The current event has a long running function call.
|
||||
|
||||
Args:
|
||||
event: The current event.
|
||||
|
||||
Returns:
|
||||
Whether to pause the invocation right after this event.
|
||||
"""
|
||||
if (
|
||||
not self.resumability_config
|
||||
or not self.resumability_config.is_resumable
|
||||
):
|
||||
return False
|
||||
|
||||
if not event.long_running_tool_ids or not event.get_function_calls():
|
||||
return False
|
||||
|
||||
for fc in event.get_function_calls():
|
||||
if fc.id in event.long_running_tool_ids:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def new_invocation_context_id() -> str:
|
||||
return "e-" + str(uuid.uuid4())
|
||||
|
||||
@@ -341,6 +341,8 @@ class LlmAgent(BaseAgent):
|
||||
async for event in agen:
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
if ctx.should_pause_invocation(event):
|
||||
return
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
|
||||
@@ -70,15 +70,26 @@ class LoopAgent(BaseAgent):
|
||||
while not self.max_iterations or times_looped < self.max_iterations:
|
||||
for sub_agent in self.sub_agents:
|
||||
should_exit = False
|
||||
pause_invocation = False
|
||||
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
if event.actions.escalate:
|
||||
should_exit = True
|
||||
if ctx.should_pause_invocation(event):
|
||||
pause_invocation = True
|
||||
|
||||
# Indicates that the loop agent should exist after running this
|
||||
# sub-agent.
|
||||
if should_exit:
|
||||
return
|
||||
|
||||
# Indicates that the invocation should be paused after running this
|
||||
# sub-agent.
|
||||
if pause_invocation:
|
||||
return
|
||||
|
||||
times_looped += 1
|
||||
return
|
||||
|
||||
|
||||
@@ -181,16 +181,26 @@ class ParallelAgent(BaseAgent):
|
||||
)
|
||||
for sub_agent in self.sub_agents
|
||||
]
|
||||
|
||||
pause_invocation = False
|
||||
try:
|
||||
# TODO remove if once Python <3.11 is no longer supported.
|
||||
if sys.version_info >= (3, 11):
|
||||
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
if ctx.should_pause_invocation(event):
|
||||
pause_invocation = True
|
||||
else:
|
||||
async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
if ctx.should_pause_invocation(event):
|
||||
pause_invocation = True
|
||||
|
||||
if pause_invocation:
|
||||
return
|
||||
|
||||
finally:
|
||||
for sub_agent_run in agent_runs:
|
||||
await sub_agent_run.aclose()
|
||||
|
||||
@@ -52,9 +52,18 @@ class SequentialAgent(BaseAgent):
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
pause_invocation = False
|
||||
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
if ctx.should_pause_invocation(event):
|
||||
pause_invocation = True
|
||||
|
||||
# Indicates the invocation should pause when receiving signal from
|
||||
# the current sub_agent.
|
||||
if pause_invocation:
|
||||
return
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
|
||||
@@ -16,11 +16,16 @@ from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.apps import ResumabilityConfig
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions.base_session_service import BaseSessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai.types import FunctionCall
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
from .. import testing_utils
|
||||
|
||||
|
||||
class TestInvocationContext:
|
||||
"""Test suite for InvocationContext."""
|
||||
@@ -117,3 +122,87 @@ class TestInvocationContext:
|
||||
current_branch=True,
|
||||
)
|
||||
assert not events
|
||||
|
||||
|
||||
class TestInvocationContextWithAppResumablity:
|
||||
"""Test suite for InvocationContext regarding app resumability."""
|
||||
|
||||
@pytest.fixture
|
||||
def long_running_function_call(self) -> FunctionCall:
|
||||
"""A long running function call."""
|
||||
return FunctionCall(
|
||||
id='tool_call_id_1',
|
||||
name='long_running_function_call',
|
||||
args={},
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def event_to_pause(self, long_running_function_call) -> Event:
|
||||
"""An event with a long running function call."""
|
||||
return Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent',
|
||||
content=testing_utils.ModelContent(
|
||||
[Part(function_call=long_running_function_call)]
|
||||
),
|
||||
long_running_tool_ids=[long_running_function_call.id],
|
||||
)
|
||||
|
||||
def _create_test_invocation_context(
|
||||
self, resumability_config
|
||||
) -> InvocationContext:
|
||||
"""Create a mock invocation context for testing."""
|
||||
ctx = InvocationContext(
|
||||
session_service=Mock(spec=BaseSessionService),
|
||||
agent=Mock(spec=BaseAgent),
|
||||
invocation_id='inv_1',
|
||||
session=Mock(spec=Session),
|
||||
resumability_config=resumability_config,
|
||||
)
|
||||
return ctx
|
||||
|
||||
def test_should_pause_invocation_with_resumable_app(self, event_to_pause):
|
||||
"""Tests should_pause_invocation with a resumable app."""
|
||||
mock_invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
|
||||
assert mock_invocation_context.should_pause_invocation(event_to_pause)
|
||||
|
||||
def test_should_not_pause_invocation_with_non_resumable_app(
|
||||
self, event_to_pause
|
||||
):
|
||||
"""Tests should_pause_invocation with a non-resumable app."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=False)
|
||||
)
|
||||
|
||||
assert not invocation_context.should_pause_invocation(event_to_pause)
|
||||
|
||||
def test_should_not_pause_invocation_with_no_long_running_tool_ids(
|
||||
self, event_to_pause
|
||||
):
|
||||
"""Tests should_pause_invocation with no long running tools."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
nonpausable_event = event_to_pause.model_copy(
|
||||
update={'long_running_tool_ids': []}
|
||||
)
|
||||
|
||||
assert not invocation_context.should_pause_invocation(nonpausable_event)
|
||||
|
||||
def test_should_not_pause_invocation_with_no_function_calls(
|
||||
self, event_to_pause
|
||||
):
|
||||
"""Tests should_pause_invocation with a non-model event."""
|
||||
mock_invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
nonpausable_event = event_to_pause.model_copy(
|
||||
update={'content': testing_utils.UserContent('test text part')}
|
||||
)
|
||||
|
||||
assert not mock_invocation_context.should_pause_invocation(
|
||||
nonpausable_event
|
||||
)
|
||||
|
||||
@@ -0,0 +1,472 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the resumption flow with different agent structures."""
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.agents.parallel_agent import ParallelAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.tools.exit_loop_tool import exit_loop
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.long_running_tool import LongRunningFunctionTool
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
from .. import testing_utils
|
||||
|
||||
|
||||
def _transfer_call_part(agent_name: str) -> Part:
|
||||
return Part.from_function_call(
|
||||
name="transfer_to_agent", args={"agent_name": agent_name}
|
||||
)
|
||||
|
||||
|
||||
def test_tool() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _TestingAgent(BaseAgent):
|
||||
"""A testing agent that generates an event after a delay."""
|
||||
|
||||
delay: float = 0
|
||||
"""The delay before the agent generates an event."""
|
||||
|
||||
def event(self, ctx: InvocationContext):
|
||||
return Event(
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=testing_utils.ModelContent(
|
||||
parts=[Part.from_text(text="Delayed message")]
|
||||
),
|
||||
)
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
await asyncio.sleep(self.delay)
|
||||
yield self.event(ctx)
|
||||
|
||||
|
||||
_TRANSFER_RESPONSE_PART = Part.from_function_response(
|
||||
name="transfer_to_agent", response={"result": None}
|
||||
)
|
||||
|
||||
|
||||
class BasePauseInvocationTest:
|
||||
"""Base class for pausing invocation tests with common fixtures."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
return BaseAgent(name="test_agent")
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, agent: BaseAgent) -> App:
|
||||
"""Provides an App for the test."""
|
||||
return App(
|
||||
name="InMemoryRunner", # Required for using TestInMemoryRunner.
|
||||
root_agent=agent,
|
||||
resumability_config=ResumabilityConfig(is_resumable=True),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self, app: App) -> testing_utils.TestInMemoryRunner:
|
||||
"""Provides an in-memory runner for the agent."""
|
||||
return testing_utils.TestInMemoryRunner(app=app, app_name=None)
|
||||
|
||||
@staticmethod
|
||||
def mock_model(responses: list[Part]) -> testing_utils.MockModel:
|
||||
"""Provides a mock model with predefined responses."""
|
||||
return testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
|
||||
class TestPauseInvocationWithSingleLlmAgent(BasePauseInvocationTest):
|
||||
"""Tests the resumption flow with a single LlmAgent."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
|
||||
def test_tool() -> str:
|
||||
return ""
|
||||
|
||||
return LlmAgent(
|
||||
name="root_agent",
|
||||
model=self.mock_model(
|
||||
responses=[Part.from_function_call(name="test_tool", args={})]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_long_running_function_call(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a single LlmAgent pauses on long running function call."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("root_agent", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
|
||||
|
||||
class TestPauseInvocationWithSequentialAgent(BasePauseInvocationTest):
|
||||
"""Tests pausing invocation with a SequentialAgent."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
sub_agent1 = LlmAgent(
|
||||
name="sub_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[Part.from_function_call(name="test_tool", args={})]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
sub_agent2 = LlmAgent(
|
||||
name="sub_agent_2",
|
||||
model=self.mock_model(
|
||||
responses=[Part.from_function_call(name="test_tool", args={})]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
return SequentialAgent(
|
||||
name="root_agent",
|
||||
sub_agents=[sub_agent1, sub_agent2],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_first_agent_on_long_running_function_call(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a single LlmAgent pauses on long running function call."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_second_agent_on_long_running_function_call(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a single LlmAgent pauses on long running function call."""
|
||||
# Change the base sequential agent, so that the first agent does not pause.
|
||||
runner.agent.sub_agents[0].tools = [FunctionTool(func=test_tool)]
|
||||
runner.agent.sub_agents[0].model = self.mock_model(
|
||||
responses=[
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
Part.from_text(text="model response after tool call"),
|
||||
]
|
||||
)
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
|
||||
(
|
||||
"sub_agent_1",
|
||||
Part.from_function_response(
|
||||
name="test_tool", response={"result": ""}
|
||||
),
|
||||
),
|
||||
("sub_agent_1", "model response after tool call"),
|
||||
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
|
||||
|
||||
class TestPauseInvocationWithParallelAgent(BasePauseInvocationTest):
|
||||
"""Tests pausing invocation with a ParallelAgent."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
sub_agent1 = LlmAgent(
|
||||
name="sub_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[Part.from_function_call(name="test_tool", args={})]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
sub_agent2 = _TestingAgent(
|
||||
name="sub_agent_2",
|
||||
delay=0.5,
|
||||
)
|
||||
return ParallelAgent(
|
||||
name="root_agent",
|
||||
sub_agents=[sub_agent1, sub_agent2],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_long_running_function_call(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a ParallelAgent pauses on long running function call."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
|
||||
("sub_agent_2", "Delayed message"),
|
||||
]
|
||||
|
||||
|
||||
class TestPauseInvocationWithNestedParallelAgent(BasePauseInvocationTest):
|
||||
"""Tests pausing invocation with a nested ParallelAgent."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
nested_sub_agent_1 = LlmAgent(
|
||||
name="nested_sub_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[Part.from_function_call(name="test_tool", args={})]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
nested_sub_agent_2 = _TestingAgent(
|
||||
name="nested_sub_agent_2",
|
||||
delay=0.5,
|
||||
)
|
||||
nested_parallel_agent = ParallelAgent(
|
||||
name="nested_parallel_agent",
|
||||
sub_agents=[nested_sub_agent_1, nested_sub_agent_2],
|
||||
)
|
||||
sub_agent_1 = _TestingAgent(
|
||||
name="sub_agent_1",
|
||||
delay=0.5,
|
||||
)
|
||||
return ParallelAgent(
|
||||
name="root_agent",
|
||||
sub_agents=[sub_agent_1, nested_parallel_agent],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_long_running_function_call_in_nested_agent(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a nested ParallelAgent pauses on long running function call."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
(
|
||||
"nested_sub_agent_1",
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
),
|
||||
("sub_agent_1", "Delayed message"),
|
||||
("nested_sub_agent_2", "Delayed message"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_multiple_long_running_function_calls(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a ParallelAgent pauses on long running function calls."""
|
||||
runner.agent.sub_agents[0] = LlmAgent(
|
||||
name="sub_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
simplified_events = testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
)
|
||||
assert len(simplified_events) == 3
|
||||
assert (
|
||||
"sub_agent_1",
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
) in simplified_events
|
||||
assert (
|
||||
"nested_sub_agent_1",
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
) in simplified_events
|
||||
|
||||
|
||||
class TestPauseInvocationWithLoopAgent(BasePauseInvocationTest):
|
||||
"""Tests pausing invocation with a LoopAgent."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> BaseAgent:
|
||||
"""Provides a BaseAgent for the test."""
|
||||
sub_agent_1 = LlmAgent(
|
||||
name="sub_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
Part.from_text(text="sub agent 1 response"),
|
||||
]
|
||||
),
|
||||
)
|
||||
sub_agent_2 = LlmAgent(
|
||||
name="sub_agent_2",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
sub_agent_3 = LlmAgent(
|
||||
name="sub_agent_3",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
Part.from_function_call(name="exit_loop", args={}),
|
||||
]
|
||||
),
|
||||
tools=[exit_loop],
|
||||
)
|
||||
return LoopAgent(
|
||||
name="root_agent",
|
||||
sub_agents=[sub_agent_1, sub_agent_2, sub_agent_3],
|
||||
max_iterations=2,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_long_running_function_call_in_loop(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a LoopAgent pauses on long running function call."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("sub_agent_1", "sub agent 1 response"),
|
||||
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
|
||||
|
||||
class TestPauseInvocationWithLlmAgentTree(BasePauseInvocationTest):
|
||||
"""Tests the pausing invocation with a tree of LlmAgents."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> LlmAgent:
|
||||
"""Provides an LlmAgent with sub-agents for the test."""
|
||||
sub_llm_agent_1 = LlmAgent(
|
||||
name="sub_llm_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
_transfer_call_part("sub_llm_agent_2"),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
)
|
||||
sub_llm_agent_2 = LlmAgent(
|
||||
name="sub_llm_agent_2",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
return LlmAgent(
|
||||
name="root_agent",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
_transfer_call_part("sub_llm_agent_1"),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
sub_agents=[sub_llm_agent_1, sub_llm_agent_2],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_transfer_call_part(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a tree of resumable LlmAgents yields checkpoint events."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("root_agent", _transfer_call_part("sub_llm_agent_1")),
|
||||
("root_agent", _TRANSFER_RESPONSE_PART),
|
||||
("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")),
|
||||
("sub_llm_agent_1", _TRANSFER_RESPONSE_PART),
|
||||
("sub_llm_agent_2", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
|
||||
|
||||
class TestPauseInvocationWithWithTransferLoop(BasePauseInvocationTest):
|
||||
"""Tests the pausing the invocation when the agent transfer forms a loop."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self) -> LlmAgent:
|
||||
"""Provides an LlmAgent with sub-agents for the test."""
|
||||
sub_llm_agent_1 = LlmAgent(
|
||||
name="sub_llm_agent_1",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
_transfer_call_part("sub_llm_agent_2"),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
)
|
||||
sub_llm_agent_2 = LlmAgent(
|
||||
name="sub_llm_agent_2",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
_transfer_call_part("root_agent"),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
)
|
||||
return LlmAgent(
|
||||
name="root_agent",
|
||||
model=self.mock_model(
|
||||
responses=[
|
||||
_transfer_call_part("sub_llm_agent_1"),
|
||||
Part.from_function_call(name="test_tool", args={}),
|
||||
"llm response not used",
|
||||
]
|
||||
),
|
||||
sub_agents=[sub_llm_agent_1, sub_llm_agent_2],
|
||||
tools=[LongRunningFunctionTool(func=test_tool)],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_tree_yields_checkpoints(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
):
|
||||
"""Tests that a tree of resumable LlmAgents yields checkpoint events."""
|
||||
assert testing_utils.simplify_events(
|
||||
await runner.run_async_with_new_session("test")
|
||||
) == [
|
||||
("root_agent", _transfer_call_part("sub_llm_agent_1")),
|
||||
("root_agent", _TRANSFER_RESPONSE_PART),
|
||||
("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")),
|
||||
("sub_llm_agent_1", _TRANSFER_RESPONSE_PART),
|
||||
("sub_llm_agent_2", _transfer_call_part("root_agent")),
|
||||
("sub_llm_agent_2", _TRANSFER_RESPONSE_PART),
|
||||
("root_agent", Part.from_function_call(name="test_tool", args={})),
|
||||
]
|
||||
@@ -19,9 +19,8 @@ from unittest import mock
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.agents.parallel_agent import ParallelAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
@@ -363,3 +362,88 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
|
||||
testing_utils.simplify_events(copy.deepcopy(events))
|
||||
== expected_parts_final
|
||||
)
|
||||
|
||||
|
||||
class TestHITLConfirmationFlowWithResumableApp:
|
||||
"""Tests the HITL confirmation flow with a resumable app."""
|
||||
|
||||
@pytest.fixture
|
||||
def tools(self) -> list[FunctionTool]:
|
||||
"""Provides the tools for the agent."""
|
||||
return [FunctionTool(func=_test_request_confirmation_function)]
|
||||
|
||||
@pytest.fixture
|
||||
def llm_responses(
|
||||
self, tools: list[FunctionTool]
|
||||
) -> list[GenerateContentResponse]:
|
||||
"""Provides mock LLM responses for the tests."""
|
||||
return [
|
||||
_create_llm_response_from_tools(tools),
|
||||
_create_llm_response_from_text("test llm response after tool call"),
|
||||
_create_llm_response_from_text(
|
||||
"test llm response after final tool call"
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(
|
||||
self, llm_responses: list[GenerateContentResponse]
|
||||
) -> testing_utils.MockModel:
|
||||
"""Provides a mock model with predefined responses."""
|
||||
return testing_utils.MockModel(responses=llm_responses)
|
||||
|
||||
@pytest.fixture
|
||||
def agent(
|
||||
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
|
||||
) -> LlmAgent:
|
||||
"""Provides a single LlmAgent for the test."""
|
||||
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self, agent: LlmAgent) -> testing_utils.TestInMemoryRunner:
|
||||
"""Provides an in-memory runner for the agent."""
|
||||
# Mark the app as resumable. So that the invocation will be paused after the
|
||||
# long running tool call.
|
||||
app = App(
|
||||
name="InMemoryRunner", # Required for using TestInMemoryRunner.
|
||||
resumability_config=ResumabilityConfig(is_resumable=True),
|
||||
root_agent=agent,
|
||||
)
|
||||
return testing_utils.TestInMemoryRunner(app=app, app_name=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_on_request_confirmation(
|
||||
self,
|
||||
runner: testing_utils.TestInMemoryRunner,
|
||||
agent: LlmAgent,
|
||||
):
|
||||
"""Tests HITL flow where all tool calls are confirmed."""
|
||||
events = await runner.run_async_with_new_session("test user query")
|
||||
|
||||
# Verify that the invocation is paused after the long running tool call.
|
||||
# So that no intermediate function response and llm response is generated.
|
||||
assert testing_utils.simplify_events(copy.deepcopy(events)) == [
|
||||
(
|
||||
agent.name,
|
||||
Part(function_call=FunctionCall(name=agent.tools[0].name, args={})),
|
||||
),
|
||||
(
|
||||
agent.name,
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
args={
|
||||
"originalFunctionCall": {
|
||||
"name": agent.tools[0].name,
|
||||
"id": mock.ANY,
|
||||
"args": {},
|
||||
},
|
||||
"toolConfirmation": {
|
||||
"hint": "test hint for request_confirmation",
|
||||
"confirmed": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user