feat: Implement checkpoint and resume logic for SequentialAgent

PiperOrigin-RevId: 811977004
This commit is contained in:
Shangjie Chen
2025-09-26 15:25:59 -07:00
committed by Copybara-Service
parent 28d44a365a
commit 1ee01cc05a
8 changed files with 245 additions and 72 deletions
+12 -15
View File
@@ -163,44 +163,41 @@ class BaseAgent(BaseModel):
self,
ctx: InvocationContext,
state_type: Type[AgentState],
default_state: AgentState,
) -> tuple[AgentState, bool]:
) -> Optional[AgentState]:
"""Loads the agent state from the invocation context, handling resumption.
Args:
ctx: The invocation context.
state_type: The type of the agent state.
default_state: The default state to use if not resuming.
Returns:
tuple[AgentState, bool]: The current state and a boolean indicating if
resuming.
The current state if resuming, otherwise None.
"""
if not ctx.is_resumable:
return None
if self.name not in ctx.agent_states:
return default_state, False
return None
else:
return state_type.model_validate(ctx.agent_states.get(self.name)), True
return state_type.model_validate(ctx.agent_states.get(self.name))
def _create_agent_state_event(
self,
ctx: InvocationContext,
*,
state: Optional[BaseAgentState] = None,
agent_state: Optional[BaseAgentState] = None,
end_of_agent: bool = False,
) -> Event:
"""Creates an event for agent state.
"""Returns an event with agent state.
Args:
ctx: The invocation context.
state: The agent state to checkpoint.
agent_state: The agent state to checkpoint.
end_of_agent: Whether the agent is finished running.
Returns:
An Event object representing the checkpoint.
"""
event_actions = EventActions()
if state:
event_actions.agent_state = state.model_dump(mode='json')
if agent_state:
event_actions.agent_state = agent_state.model_dump(mode='json')
if end_of_agent:
event_actions.end_of_agent = True
return Event(
+9 -4
View File
@@ -208,6 +208,14 @@ class InvocationContext(BaseModel):
of this invocation.
"""
@property
def is_resumable(self) -> bool:
  """Whether the current invocation is resumable.

  Returns:
    False when no resumability config is set; otherwise the value of the
    config's own `is_resumable` attribute.
  """
  config = self.resumability_config
  if config is None:
    # Without a config there is nothing that could enable resumption.
    return False
  return config.is_resumable
def reset_agent_state(self, agent_name: str) -> None:
  """Drops the stored state of one agent so that it can be re-run.

  Args:
    agent_name: Key of the entry to remove from `self.agent_states`. A name
      with no stored entry is ignored.
  """
  stored_states = self.agent_states
  # Passing a default to pop() keeps a missing entry from raising KeyError.
  stored_states.pop(agent_name, None)
@@ -284,10 +292,7 @@ class InvocationContext(BaseModel):
Returns:
Whether to pause the invocation right after this event.
"""
if (
not self.resumability_config
or not self.resumability_config.is_resumable
):
if not self.is_resumable:
return False
if not event.long_running_tool_ids or not event.get_function_calls():
+53 -5
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import logging
from typing import AsyncGenerator
from typing import ClassVar
from typing import Type
@@ -32,6 +33,8 @@ from .invocation_context import InvocationContext
from .llm_agent import LlmAgent
from .sequential_agent_config import SequentialAgentConfig
logger = logging.getLogger('google_adk.' + __name__)
@experimental
class SequentialAgentState(BaseAgentState):
@@ -51,12 +54,27 @@ class SequentialAgent(BaseAgent):
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
# Skip if there is no sub-agent.
if not self.sub_agents:
return
for sub_agent in self.sub_agents:
pause_invocation = False
# Initialize or resume the execution state from the agent state.
agent_state = self._load_agent_state(ctx, SequentialAgentState)
start_index = self._get_start_index(agent_state)
pause_invocation = False
resuming_sub_agent = agent_state is not None
for i in range(start_index, len(self.sub_agents)):
sub_agent = self.sub_agents[i]
if not resuming_sub_agent:
# If we are resuming from the current event, it means the same event has
# already been logged, so we should avoid yielding it again.
if ctx.is_resumable:
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
yield self._create_agent_state_event(ctx, agent_state=agent_state)
# Reset the sub-agent's state in the context to ensure that each
# sub-agent starts fresh.
ctx.reset_agent_state(sub_agent.name)
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
@@ -64,11 +82,41 @@ class SequentialAgent(BaseAgent):
if ctx.should_pause_invocation(event):
pause_invocation = True
# Indicates the invocation should pause when receiving signal from
# the current sub_agent.
# Skip the rest of the sub-agents if the invocation is paused.
if pause_invocation:
return
# Reset the flag for the next sub-agent.
resuming_sub_agent = False
if ctx.is_resumable:
yield self._create_agent_state_event(ctx, end_of_agent=True)
def _get_start_index(
self,
agent_state: SequentialAgentState,
) -> int:
"""Calculates the start index for the sub-agent loop."""
if not agent_state:
return 0
if not agent_state.current_sub_agent:
# This means the process was finished.
return len(self.sub_agents)
try:
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
return sub_agent_names.index(agent_state.current_sub_agent)
except ValueError:
# A sub-agent was removed so the agent name is not found.
# For now, we restart from the beginning.
logger.warning(
'Sub-agent %s was removed so the agent name is not found. Restarting'
' from the beginning.',
agent_state.current_sub_agent,
)
return 0
@override
async def _run_live_impl(
self, ctx: InvocationContext
@@ -350,6 +350,8 @@ def _get_current_turn_contents(
# Find the latest event that starts the current turn and process from there
for i in range(len(events) - 1, -1, -1):
event = events[i]
if not event.content:
continue
if event.author == 'user' or _is_other_agent_reply(agent_name, event):
return _get_contents(current_branch, events[i:], agent_name)
+70 -44
View File
@@ -26,6 +26,7 @@ from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.base_agent import BaseAgentState
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.apps.app import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.plugins.plugin_manager import PluginManager
@@ -733,39 +734,6 @@ async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
[e async for e in agent.run_live(parent_ctx)]
@pytest.mark.asyncio
async def test_create_agent_state_event(request: pytest.FixtureRequest):
  """Checks the events built by `_create_agent_state_event`.

  Covers two calls: one carrying only an agent state, and one carrying only
  the end-of-agent flag.
  """
  test_name = request.function.__name__
  agent = _TestingAgent(name=f'{test_name}_test_agent')
  ctx = await _create_parent_invocation_context(
      test_name, agent, branch='test_branch'
  )
  state = BaseAgentState()

  # Checkpoint event: the state is serialized and the agent is not finished.
  checkpoint_event = agent._create_agent_state_event(ctx, state=state)

  assert checkpoint_event.invocation_id == ctx.invocation_id
  assert checkpoint_event.author == agent.name
  assert checkpoint_event.branch == 'test_branch'
  checkpoint_actions = checkpoint_event.actions
  assert checkpoint_actions is not None
  assert checkpoint_actions.agent_state is not None
  assert checkpoint_actions.agent_state == state.model_dump(mode='json')
  assert not checkpoint_actions.end_of_agent

  # Final event: only the end-of-agent flag is set, no state is attached.
  final_event = agent._create_agent_state_event(ctx, end_of_agent=True)

  assert final_event.invocation_id == ctx.invocation_id
  assert final_event.author == agent.name
  assert final_event.branch == 'test_branch'
  final_actions = final_event.actions
  assert final_actions is not None
  assert final_actions.end_of_agent
  assert final_actions.agent_state is None
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
sub_agents: list[BaseAgent] = [
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
@@ -895,7 +863,7 @@ class _TestAgentState(BaseAgentState):
@pytest.mark.asyncio
async def test_load_agent_state_no_resume():
async def test_load_agent_state_not_resumable():
agent = BaseAgent(name='test_agent')
session_service = InMemorySessionService()
session = await session_service.create_session(
@@ -907,14 +875,15 @@ async def test_load_agent_state_no_resume():
session=session,
session_service=session_service,
)
default_state = _TestAgentState(test_field='default')
state, is_resuming = agent._load_agent_state(
ctx, _TestAgentState, default_state
)
# Test case 1: resumability_config is None
state = agent._load_agent_state(ctx, _TestAgentState)
assert state is None
assert not is_resuming
assert state == default_state
# Test case 2: is_resumable is False
ctx.resumability_config = ResumabilityConfig(is_resumable=False)
state = agent._load_agent_state(ctx, _TestAgentState)
assert state is None
@pytest.mark.asyncio
@@ -929,13 +898,70 @@ async def test_load_agent_state_with_resume():
agent=agent,
session=session,
session_service=session_service,
resumability_config=ResumabilityConfig(is_resumable=True),
)
# Test case 1: agent state not in context
state = agent._load_agent_state(ctx, _TestAgentState)
assert state is None
# Test case 2: agent state in context
persisted_state = _TestAgentState(test_field='resumed')
ctx.agent_states[agent.name] = persisted_state.model_dump(mode='json')
state, is_resuming = agent._load_agent_state(
ctx, _TestAgentState, _TestAgentState()
state = agent._load_agent_state(ctx, _TestAgentState)
assert state == persisted_state
@pytest.mark.asyncio
async def test_create_agent_state_event():
agent = BaseAgent(name='test_agent')
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
ctx = InvocationContext(
invocation_id='test_invocation',
agent=agent,
session=session,
session_service=session_service,
)
assert is_resuming
assert state == persisted_state
ctx.branch = 'test_branch'
# Test case 1: with state
state = _TestAgentState(test_field='checkpoint')
event = agent._create_agent_state_event(ctx, agent_state=state)
assert event is not None
assert event.invocation_id == ctx.invocation_id
assert event.author == agent.name
assert event.branch == 'test_branch'
assert event.actions is not None
assert event.actions.agent_state is not None
assert event.actions.agent_state == state.model_dump(mode='json')
assert not event.actions.end_of_agent
# Test case 2: with end_of_agent
event = agent._create_agent_state_event(ctx, end_of_agent=True)
assert event is not None
assert event.invocation_id == ctx.invocation_id
assert event.author == agent.name
assert event.branch == 'test_branch'
assert event.actions is not None
assert event.actions.end_of_agent
assert event.actions.agent_state is None
# Test case 3: with both state and end_of_agent
state = _TestAgentState(test_field='checkpoint')
event = agent._create_agent_state_event(
ctx, agent_state=state, end_of_agent=True
)
assert event is not None
assert event.actions.agent_state == state.model_dump(mode='json')
assert event.actions.end_of_agent
# Test case 4: with neither
event = agent._create_agent_state_event(ctx)
assert event is not None
assert event.actions.agent_state is None
assert not event.actions.end_of_agent
@@ -206,3 +206,22 @@ class TestInvocationContextWithAppResumablity:
assert not mock_invocation_context.should_pause_invocation(
nonpausable_event
)
def test_is_resumable_true(self):
  """Checks that is_resumable is truthy for a config with is_resumable=True."""
  enabled_config = ResumabilityConfig(is_resumable=True)
  ctx = self._create_test_invocation_context(enabled_config)
  assert ctx.is_resumable
def test_is_resumable_false(self):
  """Checks that is_resumable is falsy for a config with is_resumable=False."""
  disabled_config = ResumabilityConfig(is_resumable=False)
  ctx = self._create_test_invocation_context(disabled_config)
  assert not ctx.is_resumable
def test_is_resumable_no_config(self):
  """Checks that is_resumable is falsy when the context is built with None."""
  ctx = self._create_test_invocation_context(None)
  resumable = ctx.is_resumable
  assert not resumable
@@ -219,9 +219,10 @@ async def test_include_contents_none_sequential_agents():
runner = testing_utils.InMemoryRunner(sequential_agent)
events = runner.run("Original user request")
assert len(events) == 2
assert events[0].author == "agent1"
assert events[1].author == "agent2"
simplified_events = [event for event in events if event.content]
assert len(simplified_events) == 2
assert simplified_events[0].author == "agent1"
assert simplified_events[1].author == "agent2"
# Agent1 sees original user request
agent1_contents = testing_utils.simplify_contents(
@@ -19,6 +19,8 @@ from typing import AsyncGenerator
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.agents.sequential_agent import SequentialAgentState
from google.adk.apps import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
@@ -54,7 +56,7 @@ class _TestingAgent(BaseAgent):
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
test_name: str, agent: BaseAgent, resumable: bool = False
) -> InvocationContext:
session_service = InMemorySessionService()
session = await session_service.create_session(
@@ -65,6 +67,7 @@ async def _create_parent_invocation_context(
agent=agent,
session=session,
session_service=session_service,
resumability_config=ResumabilityConfig(is_resumable=resumable),
)
@@ -105,6 +108,78 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest):
assert not events
@pytest.mark.asyncio
async def test_run_async_with_resumability(request: pytest.FixtureRequest):
  """Runs a two-step SequentialAgent with resumability enabled.

  Expects five events in order: a checkpoint naming agent 1, agent 1's
  reply, a checkpoint naming agent 2, agent 2's reply, and a closing event
  flagged end_of_agent.
  """
  test_name = request.function.__name__
  agent_1 = _TestingAgent(name=f'{test_name}_test_agent_1')
  agent_2 = _TestingAgent(name=f'{test_name}_test_agent_2')
  sequential_agent = SequentialAgent(
      name=f'{test_name}_test_agent',
      sub_agents=[agent_1, agent_2],
  )
  parent_ctx = await _create_parent_invocation_context(
      test_name, sequential_agent, resumable=True
  )

  events = [event async for event in sequential_agent.run_async(parent_ctx)]

  # Checkpoint / reply for each sub-agent, then the final checkpoint.
  assert len(events) == 5
  checkpoint_1, reply_1, checkpoint_2, reply_2, final_event = events

  assert checkpoint_1.author == sequential_agent.name
  assert not checkpoint_1.actions.end_of_agent
  assert checkpoint_1.actions.agent_state['current_sub_agent'] == agent_1.name

  assert reply_1.author == agent_1.name
  assert reply_1.content.parts[0].text == f'Hello, async {agent_1.name}!'

  assert checkpoint_2.author == sequential_agent.name
  assert not checkpoint_2.actions.end_of_agent
  assert checkpoint_2.actions.agent_state['current_sub_agent'] == agent_2.name

  assert reply_2.author == agent_2.name
  assert reply_2.content.parts[0].text == f'Hello, async {agent_2.name}!'

  assert final_event.author == sequential_agent.name
  assert final_event.actions.end_of_agent
@pytest.mark.asyncio
async def test_resume_async(request: pytest.FixtureRequest):
  """Resumes a SequentialAgent whose saved state points at its second agent.

  Expects two events: agent 2's reply followed by the SequentialAgent's
  closing event flagged end_of_agent. Agent 1 must not run again.
  """
  test_name = request.function.__name__
  agent_1 = _TestingAgent(name=f'{test_name}_test_agent_1')
  agent_2 = _TestingAgent(name=f'{test_name}_test_agent_2')
  sequential_agent = SequentialAgent(
      name=f'{test_name}_test_agent',
      sub_agents=[agent_1, agent_2],
  )
  parent_ctx = await _create_parent_invocation_context(
      test_name, sequential_agent, resumable=True
  )

  # Seed the context as if a previous run had checkpointed at agent 2.
  saved_state = SequentialAgentState(current_sub_agent=agent_2.name)
  parent_ctx.agent_states[sequential_agent.name] = saved_state.model_dump(
      mode='json'
  )

  events = [event async for event in sequential_agent.run_async(parent_ctx)]

  # Agent 2's reply, then the final checkpoint.
  assert len(events) == 2
  reply_event, final_event = events

  assert reply_event.author == agent_2.name
  assert reply_event.content.parts[0].text == f'Hello, async {agent_2.name}!'

  assert final_event.author == sequential_agent.name
  assert final_event.actions.end_of_agent
@pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest):
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')