feat: Pause invocations on long running function calls for resumable apps

PiperOrigin-RevId: 811518771
This commit is contained in:
Xinran (Sherry) Tang
2025-09-25 15:10:38 -07:00
committed by Copybara-Service
parent dd1ffad394
commit b2b80e7fa0
8 changed files with 719 additions and 3 deletions
@@ -260,6 +260,45 @@ class InvocationContext(BaseModel):
results = [event for event in results if event.branch == self.branch]
return results
def should_pause_invocation(self, event: Event) -> bool:
"""Returns whether to pause the invocation right after this event.
"Pausing" an invocation is different from "ending" an invocation. A paused
invocation can be resumed later, while an ended invocation cannot.
Pausing the current agent's run will also pause all the agents that
depend on its execution, i.e. the subsequent agents in a workflow, and the
current agent's ancestors, etc.
Note that parallel sibling agents won't be affected, but their common
ancestors will be paused after all the non-blocking sub-agents finished
running.
Should meet all following conditions to pause an invocation:
1. The app is resumable.
2. The current event has a long running function call.
Args:
event: The current event.
Returns:
Whether to pause the invocation right after this event.
"""
if (
not self.resumability_config
or not self.resumability_config.is_resumable
):
return False
if not event.long_running_tool_ids or not event.get_function_calls():
return False
for fc in event.get_function_calls():
if fc.id in event.long_running_tool_ids:
return True
return False
def new_invocation_context_id() -> str:
return "e-" + str(uuid.uuid4())
+2
View File
@@ -341,6 +341,8 @@ class LlmAgent(BaseAgent):
async for event in agen:
self.__maybe_save_output_to_state(event)
yield event
if ctx.should_pause_invocation(event):
return
@override
async def _run_live_impl(
+11
View File
@@ -70,15 +70,26 @@ class LoopAgent(BaseAgent):
while not self.max_iterations or times_looped < self.max_iterations:
for sub_agent in self.sub_agents:
should_exit = False
pause_invocation = False
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
yield event
if event.actions.escalate:
should_exit = True
if ctx.should_pause_invocation(event):
pause_invocation = True
# Indicates that the loop agent should exist after running this
# sub-agent.
if should_exit:
return
# Indicates that the invocation should be paused after running this
# sub-agent.
if pause_invocation:
return
times_looped += 1
return
+10
View File
@@ -181,16 +181,26 @@ class ParallelAgent(BaseAgent):
)
for sub_agent in self.sub_agents
]
pause_invocation = False
try:
# TODO remove if once Python <3.11 is no longer supported.
if sys.version_info >= (3, 11):
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
async for event in agen:
yield event
if ctx.should_pause_invocation(event):
pause_invocation = True
else:
async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen:
async for event in agen:
yield event
if ctx.should_pause_invocation(event):
pause_invocation = True
if pause_invocation:
return
finally:
for sub_agent_run in agent_runs:
await sub_agent_run.aclose()
@@ -52,9 +52,18 @@ class SequentialAgent(BaseAgent):
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for sub_agent in self.sub_agents:
pause_invocation = False
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
yield event
if ctx.should_pause_invocation(event):
pause_invocation = True
# Indicates the invocation should pause when receiving signal from
# the current sub_agent.
if pause_invocation:
return
@override
async def _run_live_impl(
@@ -16,11 +16,16 @@ from unittest.mock import Mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.apps import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.sessions.base_session_service import BaseSessionService
from google.adk.sessions.session import Session
from google.genai.types import FunctionCall
from google.genai.types import Part
import pytest
from .. import testing_utils
class TestInvocationContext:
"""Test suite for InvocationContext."""
@@ -117,3 +122,87 @@ class TestInvocationContext:
current_branch=True,
)
assert not events
class TestInvocationContextWithAppResumablity:
"""Test suite for InvocationContext regarding app resumability."""
@pytest.fixture
def long_running_function_call(self) -> FunctionCall:
"""A long running function call."""
return FunctionCall(
id='tool_call_id_1',
name='long_running_function_call',
args={},
)
@pytest.fixture
def event_to_pause(self, long_running_function_call) -> Event:
"""An event with a long running function call."""
return Event(
invocation_id='inv_1',
author='agent',
content=testing_utils.ModelContent(
[Part(function_call=long_running_function_call)]
),
long_running_tool_ids=[long_running_function_call.id],
)
def _create_test_invocation_context(
self, resumability_config
) -> InvocationContext:
"""Create a mock invocation context for testing."""
ctx = InvocationContext(
session_service=Mock(spec=BaseSessionService),
agent=Mock(spec=BaseAgent),
invocation_id='inv_1',
session=Mock(spec=Session),
resumability_config=resumability_config,
)
return ctx
def test_should_pause_invocation_with_resumable_app(self, event_to_pause):
"""Tests should_pause_invocation with a resumable app."""
mock_invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
assert mock_invocation_context.should_pause_invocation(event_to_pause)
def test_should_not_pause_invocation_with_non_resumable_app(
self, event_to_pause
):
"""Tests should_pause_invocation with a non-resumable app."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=False)
)
assert not invocation_context.should_pause_invocation(event_to_pause)
def test_should_not_pause_invocation_with_no_long_running_tool_ids(
self, event_to_pause
):
"""Tests should_pause_invocation with no long running tools."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
nonpausable_event = event_to_pause.model_copy(
update={'long_running_tool_ids': []}
)
assert not invocation_context.should_pause_invocation(nonpausable_event)
def test_should_not_pause_invocation_with_no_function_calls(
self, event_to_pause
):
"""Tests should_pause_invocation with a non-model event."""
mock_invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
nonpausable_event = event_to_pause.model_copy(
update={'content': testing_utils.UserContent('test text part')}
)
assert not mock_invocation_context.should_pause_invocation(
nonpausable_event
)
@@ -0,0 +1,472 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the resumption flow with different agent structures."""
import asyncio
from typing import AsyncGenerator
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.tools.exit_loop_tool import exit_loop
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.long_running_tool import LongRunningFunctionTool
from google.genai.types import Part
import pytest
from .. import testing_utils
def _transfer_call_part(agent_name: str) -> Part:
return Part.from_function_call(
name="transfer_to_agent", args={"agent_name": agent_name}
)
def test_tool() -> str:
return ""
class _TestingAgent(BaseAgent):
"""A testing agent that generates an event after a delay."""
delay: float = 0
"""The delay before the agent generates an event."""
def event(self, ctx: InvocationContext):
return Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
content=testing_utils.ModelContent(
parts=[Part.from_text(text="Delayed message")]
),
)
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self.event(ctx)
_TRANSFER_RESPONSE_PART = Part.from_function_response(
name="transfer_to_agent", response={"result": None}
)
class BasePauseInvocationTest:
"""Base class for pausing invocation tests with common fixtures."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
return BaseAgent(name="test_agent")
@pytest.fixture
def app(self, agent: BaseAgent) -> App:
"""Provides an App for the test."""
return App(
name="InMemoryRunner", # Required for using TestInMemoryRunner.
root_agent=agent,
resumability_config=ResumabilityConfig(is_resumable=True),
)
@pytest.fixture
def runner(self, app: App) -> testing_utils.TestInMemoryRunner:
"""Provides an in-memory runner for the agent."""
return testing_utils.TestInMemoryRunner(app=app, app_name=None)
@staticmethod
def mock_model(responses: list[Part]) -> testing_utils.MockModel:
"""Provides a mock model with predefined responses."""
return testing_utils.MockModel.create(responses=responses)
class TestPauseInvocationWithSingleLlmAgent(BasePauseInvocationTest):
"""Tests the resumption flow with a single LlmAgent."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
def test_tool() -> str:
return ""
return LlmAgent(
name="root_agent",
model=self.mock_model(
responses=[Part.from_function_call(name="test_tool", args={})]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
@pytest.mark.asyncio
async def test_pause_on_long_running_function_call(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a single LlmAgent pauses on long running function call."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("root_agent", Part.from_function_call(name="test_tool", args={})),
]
class TestPauseInvocationWithSequentialAgent(BasePauseInvocationTest):
"""Tests pausing invocation with a SequentialAgent."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
sub_agent1 = LlmAgent(
name="sub_agent_1",
model=self.mock_model(
responses=[Part.from_function_call(name="test_tool", args={})]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
sub_agent2 = LlmAgent(
name="sub_agent_2",
model=self.mock_model(
responses=[Part.from_function_call(name="test_tool", args={})]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
return SequentialAgent(
name="root_agent",
sub_agents=[sub_agent1, sub_agent2],
)
@pytest.mark.asyncio
async def test_pause_first_agent_on_long_running_function_call(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a single LlmAgent pauses on long running function call."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
]
@pytest.mark.asyncio
async def test_pause_second_agent_on_long_running_function_call(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a single LlmAgent pauses on long running function call."""
# Change the base sequential agent, so that the first agent does not pause.
runner.agent.sub_agents[0].tools = [FunctionTool(func=test_tool)]
runner.agent.sub_agents[0].model = self.mock_model(
responses=[
Part.from_function_call(name="test_tool", args={}),
Part.from_text(text="model response after tool call"),
]
)
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
(
"sub_agent_1",
Part.from_function_response(
name="test_tool", response={"result": ""}
),
),
("sub_agent_1", "model response after tool call"),
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
]
class TestPauseInvocationWithParallelAgent(BasePauseInvocationTest):
"""Tests pausing invocation with a ParallelAgent."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
sub_agent1 = LlmAgent(
name="sub_agent_1",
model=self.mock_model(
responses=[Part.from_function_call(name="test_tool", args={})]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
sub_agent2 = _TestingAgent(
name="sub_agent_2",
delay=0.5,
)
return ParallelAgent(
name="root_agent",
sub_agents=[sub_agent1, sub_agent2],
)
@pytest.mark.asyncio
async def test_pause_on_long_running_function_call(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a ParallelAgent pauses on long running function call."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
("sub_agent_2", "Delayed message"),
]
class TestPauseInvocationWithNestedParallelAgent(BasePauseInvocationTest):
"""Tests pausing invocation with a nested ParallelAgent."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
nested_sub_agent_1 = LlmAgent(
name="nested_sub_agent_1",
model=self.mock_model(
responses=[Part.from_function_call(name="test_tool", args={})]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
nested_sub_agent_2 = _TestingAgent(
name="nested_sub_agent_2",
delay=0.5,
)
nested_parallel_agent = ParallelAgent(
name="nested_parallel_agent",
sub_agents=[nested_sub_agent_1, nested_sub_agent_2],
)
sub_agent_1 = _TestingAgent(
name="sub_agent_1",
delay=0.5,
)
return ParallelAgent(
name="root_agent",
sub_agents=[sub_agent_1, nested_parallel_agent],
)
@pytest.mark.asyncio
async def test_pause_on_long_running_function_call_in_nested_agent(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a nested ParallelAgent pauses on long running function call."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
(
"nested_sub_agent_1",
Part.from_function_call(name="test_tool", args={}),
),
("sub_agent_1", "Delayed message"),
("nested_sub_agent_2", "Delayed message"),
]
@pytest.mark.asyncio
async def test_pause_on_multiple_long_running_function_calls(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a ParallelAgent pauses on long running function calls."""
runner.agent.sub_agents[0] = LlmAgent(
name="sub_agent_1",
model=self.mock_model(
responses=[
Part.from_function_call(name="test_tool", args={}),
Part.from_function_call(name="test_tool", args={}),
]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
simplified_events = testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
)
assert len(simplified_events) == 3
assert (
"sub_agent_1",
Part.from_function_call(name="test_tool", args={}),
) in simplified_events
assert (
"nested_sub_agent_1",
Part.from_function_call(name="test_tool", args={}),
) in simplified_events
class TestPauseInvocationWithLoopAgent(BasePauseInvocationTest):
"""Tests pausing invocation with a LoopAgent."""
@pytest.fixture
def agent(self) -> BaseAgent:
"""Provides a BaseAgent for the test."""
sub_agent_1 = LlmAgent(
name="sub_agent_1",
model=self.mock_model(
responses=[
Part.from_text(text="sub agent 1 response"),
]
),
)
sub_agent_2 = LlmAgent(
name="sub_agent_2",
model=self.mock_model(
responses=[
Part.from_function_call(name="test_tool", args={}),
]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
sub_agent_3 = LlmAgent(
name="sub_agent_3",
model=self.mock_model(
responses=[
Part.from_function_call(name="exit_loop", args={}),
]
),
tools=[exit_loop],
)
return LoopAgent(
name="root_agent",
sub_agents=[sub_agent_1, sub_agent_2, sub_agent_3],
max_iterations=2,
)
@pytest.mark.asyncio
async def test_pause_on_long_running_function_call_in_loop(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a LoopAgent pauses on long running function call."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("sub_agent_1", "sub agent 1 response"),
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
]
class TestPauseInvocationWithLlmAgentTree(BasePauseInvocationTest):
"""Tests the pausing invocation with a tree of LlmAgents."""
@pytest.fixture
def agent(self) -> LlmAgent:
"""Provides an LlmAgent with sub-agents for the test."""
sub_llm_agent_1 = LlmAgent(
name="sub_llm_agent_1",
model=self.mock_model(
responses=[
_transfer_call_part("sub_llm_agent_2"),
"llm response not used",
]
),
)
sub_llm_agent_2 = LlmAgent(
name="sub_llm_agent_2",
model=self.mock_model(
responses=[
Part.from_function_call(name="test_tool", args={}),
"llm response not used",
]
),
tools=[LongRunningFunctionTool(func=test_tool)],
)
return LlmAgent(
name="root_agent",
model=self.mock_model(
responses=[
_transfer_call_part("sub_llm_agent_1"),
"llm response not used",
]
),
sub_agents=[sub_llm_agent_1, sub_llm_agent_2],
)
@pytest.mark.asyncio
async def test_pause_on_transfer_call_part(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a tree of resumable LlmAgents yields checkpoint events."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("root_agent", _transfer_call_part("sub_llm_agent_1")),
("root_agent", _TRANSFER_RESPONSE_PART),
("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")),
("sub_llm_agent_1", _TRANSFER_RESPONSE_PART),
("sub_llm_agent_2", Part.from_function_call(name="test_tool", args={})),
]
class TestPauseInvocationWithWithTransferLoop(BasePauseInvocationTest):
"""Tests the pausing the invocation when the agent transfer forms a loop."""
@pytest.fixture
def agent(self) -> LlmAgent:
"""Provides an LlmAgent with sub-agents for the test."""
sub_llm_agent_1 = LlmAgent(
name="sub_llm_agent_1",
model=self.mock_model(
responses=[
_transfer_call_part("sub_llm_agent_2"),
"llm response not used",
]
),
)
sub_llm_agent_2 = LlmAgent(
name="sub_llm_agent_2",
model=self.mock_model(
responses=[
_transfer_call_part("root_agent"),
"llm response not used",
]
),
)
return LlmAgent(
name="root_agent",
model=self.mock_model(
responses=[
_transfer_call_part("sub_llm_agent_1"),
Part.from_function_call(name="test_tool", args={}),
"llm response not used",
]
),
sub_agents=[sub_llm_agent_1, sub_llm_agent_2],
tools=[LongRunningFunctionTool(func=test_tool)],
)
@pytest.mark.asyncio
async def test_agent_tree_yields_checkpoints(
self,
runner: testing_utils.TestInMemoryRunner,
):
"""Tests that a tree of resumable LlmAgents yields checkpoint events."""
assert testing_utils.simplify_events(
await runner.run_async_with_new_session("test")
) == [
("root_agent", _transfer_call_part("sub_llm_agent_1")),
("root_agent", _TRANSFER_RESPONSE_PART),
("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")),
("sub_llm_agent_1", _TRANSFER_RESPONSE_PART),
("sub_llm_agent_2", _transfer_call_part("root_agent")),
("sub_llm_agent_2", _TRANSFER_RESPONSE_PART),
("root_agent", Part.from_function_call(name="test_tool", args={})),
]
@@ -19,9 +19,8 @@ from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
@@ -363,3 +362,88 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
testing_utils.simplify_events(copy.deepcopy(events))
== expected_parts_final
)
class TestHITLConfirmationFlowWithResumableApp:
"""Tests the HITL confirmation flow with a resumable app."""
@pytest.fixture
def tools(self) -> list[FunctionTool]:
"""Provides the tools for the agent."""
return [FunctionTool(func=_test_request_confirmation_function)]
@pytest.fixture
def llm_responses(
self, tools: list[FunctionTool]
) -> list[GenerateContentResponse]:
"""Provides mock LLM responses for the tests."""
return [
_create_llm_response_from_tools(tools),
_create_llm_response_from_text("test llm response after tool call"),
_create_llm_response_from_text(
"test llm response after final tool call"
),
]
@pytest.fixture
def mock_model(
self, llm_responses: list[GenerateContentResponse]
) -> testing_utils.MockModel:
"""Provides a mock model with predefined responses."""
return testing_utils.MockModel(responses=llm_responses)
@pytest.fixture
def agent(
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
) -> LlmAgent:
"""Provides a single LlmAgent for the test."""
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
@pytest.fixture
def runner(self, agent: LlmAgent) -> testing_utils.TestInMemoryRunner:
"""Provides an in-memory runner for the agent."""
# Mark the app as resumable. So that the invocation will be paused after the
# long running tool call.
app = App(
name="InMemoryRunner", # Required for using TestInMemoryRunner.
resumability_config=ResumabilityConfig(is_resumable=True),
root_agent=agent,
)
return testing_utils.TestInMemoryRunner(app=app, app_name=None)
@pytest.mark.asyncio
async def test_pause_on_request_confirmation(
self,
runner: testing_utils.TestInMemoryRunner,
agent: LlmAgent,
):
"""Tests HITL flow where all tool calls are confirmed."""
events = await runner.run_async_with_new_session("test user query")
# Verify that the invocation is paused after the long running tool call.
# So that no intermediate function response and llm response is generated.
assert testing_utils.simplify_events(copy.deepcopy(events)) == [
(
agent.name,
Part(function_call=FunctionCall(name=agent.tools[0].name, args={})),
),
(
agent.name,
Part(
function_call=FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
"originalFunctionCall": {
"name": agent.tools[0].name,
"id": mock.ANY,
"args": {},
},
"toolConfirmation": {
"hint": "test hint for request_confirmation",
"confirmed": False,
},
},
)
),
),
]