You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Implement checkpoint and resume logic for SequentialAgent
PiperOrigin-RevId: 811977004
This commit is contained in:
committed by
Copybara-Service
parent
28d44a365a
commit
1ee01cc05a
@@ -163,44 +163,41 @@ class BaseAgent(BaseModel):
|
||||
self,
|
||||
ctx: InvocationContext,
|
||||
state_type: Type[AgentState],
|
||||
default_state: AgentState,
|
||||
) -> tuple[AgentState, bool]:
|
||||
) -> Optional[AgentState]:
|
||||
"""Loads the agent state from the invocation context, handling resumption.
|
||||
|
||||
Args:
|
||||
ctx: The invocation context.
|
||||
state_type: The type of the agent state.
|
||||
default_state: The default state to use if not resuming.
|
||||
|
||||
Returns:
|
||||
tuple[AgentState, bool]: The current state and a boolean indicating if
|
||||
resuming.
|
||||
The current state if resuming, otherwise None.
|
||||
"""
|
||||
if not ctx.is_resumable:
|
||||
return None
|
||||
|
||||
if self.name not in ctx.agent_states:
|
||||
return default_state, False
|
||||
return None
|
||||
else:
|
||||
return state_type.model_validate(ctx.agent_states.get(self.name)), True
|
||||
return state_type.model_validate(ctx.agent_states.get(self.name))
|
||||
|
||||
def _create_agent_state_event(
|
||||
self,
|
||||
ctx: InvocationContext,
|
||||
*,
|
||||
state: Optional[BaseAgentState] = None,
|
||||
agent_state: Optional[BaseAgentState] = None,
|
||||
end_of_agent: bool = False,
|
||||
) -> Event:
|
||||
"""Creates an event for agent state.
|
||||
"""Returns an event with agent state.
|
||||
|
||||
Args:
|
||||
ctx: The invocation context.
|
||||
state: The agent state to checkpoint.
|
||||
agent_state: The agent state to checkpoint.
|
||||
end_of_agent: Whether the agent is finished running.
|
||||
|
||||
Returns:
|
||||
An Event object representing the checkpoint.
|
||||
"""
|
||||
event_actions = EventActions()
|
||||
if state:
|
||||
event_actions.agent_state = state.model_dump(mode='json')
|
||||
if agent_state:
|
||||
event_actions.agent_state = agent_state.model_dump(mode='json')
|
||||
if end_of_agent:
|
||||
event_actions.end_of_agent = True
|
||||
return Event(
|
||||
|
||||
@@ -208,6 +208,14 @@ class InvocationContext(BaseModel):
|
||||
of this invocation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_resumable(self) -> bool:
|
||||
"""Returns whether the current invocation is resumable."""
|
||||
return (
|
||||
self.resumability_config is not None
|
||||
and self.resumability_config.is_resumable
|
||||
)
|
||||
|
||||
def reset_agent_state(self, agent_name: str) -> None:
|
||||
"""Resets the state of an agent, allowing it to be re-run."""
|
||||
self.agent_states.pop(agent_name, None)
|
||||
@@ -284,10 +292,7 @@ class InvocationContext(BaseModel):
|
||||
Returns:
|
||||
Whether to pause the invocation right after this event.
|
||||
"""
|
||||
if (
|
||||
not self.resumability_config
|
||||
or not self.resumability_config.is_resumable
|
||||
):
|
||||
if not self.is_resumable:
|
||||
return False
|
||||
|
||||
if not event.long_running_tool_ids or not event.get_function_calls():
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import ClassVar
|
||||
from typing import Type
|
||||
@@ -32,6 +33,8 @@ from .invocation_context import InvocationContext
|
||||
from .llm_agent import LlmAgent
|
||||
from .sequential_agent_config import SequentialAgentConfig
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
@experimental
|
||||
class SequentialAgentState(BaseAgentState):
|
||||
@@ -51,12 +54,27 @@ class SequentialAgent(BaseAgent):
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Skip if there is no sub-agent.
|
||||
if not self.sub_agents:
|
||||
return
|
||||
|
||||
for sub_agent in self.sub_agents:
|
||||
pause_invocation = False
|
||||
# Initialize or resume the execution state from the agent state.
|
||||
agent_state = self._load_agent_state(ctx, SequentialAgentState)
|
||||
start_index = self._get_start_index(agent_state)
|
||||
|
||||
pause_invocation = False
|
||||
resuming_sub_agent = agent_state is not None
|
||||
for i in range(start_index, len(self.sub_agents)):
|
||||
sub_agent = self.sub_agents[i]
|
||||
if not resuming_sub_agent:
|
||||
# If we are resuming from the current event, it means the same event has
|
||||
# already been logged, so we should avoid yielding it again.
|
||||
if ctx.is_resumable:
|
||||
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
|
||||
yield self._create_agent_state_event(ctx, agent_state=agent_state)
|
||||
|
||||
# Reset the sub-agent's state in the context to ensure that each
|
||||
# sub-agent starts fresh.
|
||||
ctx.reset_agent_state(sub_agent.name)
|
||||
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
@@ -64,11 +82,41 @@ class SequentialAgent(BaseAgent):
|
||||
if ctx.should_pause_invocation(event):
|
||||
pause_invocation = True
|
||||
|
||||
# Indicates the invocation should pause when receiving signal from
|
||||
# the current sub_agent.
|
||||
# Skip the rest of the sub-agents if the invocation is paused.
|
||||
if pause_invocation:
|
||||
return
|
||||
|
||||
# Reset the flag for the next sub-agent.
|
||||
resuming_sub_agent = False
|
||||
|
||||
if ctx.is_resumable:
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
|
||||
def _get_start_index(
|
||||
self,
|
||||
agent_state: SequentialAgentState,
|
||||
) -> int:
|
||||
"""Calculates the start index for the sub-agent loop."""
|
||||
if not agent_state:
|
||||
return 0
|
||||
|
||||
if not agent_state.current_sub_agent:
|
||||
# This means the process was finished.
|
||||
return len(self.sub_agents)
|
||||
|
||||
try:
|
||||
sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents]
|
||||
return sub_agent_names.index(agent_state.current_sub_agent)
|
||||
except ValueError:
|
||||
# A sub-agent was removed so the agent name is not found.
|
||||
# For now, we restart from the beginning.
|
||||
logger.warning(
|
||||
'Sub-agent %s was removed so the agent name is not found. Restarting'
|
||||
' from the beginning.',
|
||||
agent_state.current_sub_agent,
|
||||
)
|
||||
return 0
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
|
||||
@@ -350,6 +350,8 @@ def _get_current_turn_contents(
|
||||
# Find the latest event that starts the current turn and process from there
|
||||
for i in range(len(events) - 1, -1, -1):
|
||||
event = events[i]
|
||||
if not event.content:
|
||||
continue
|
||||
if event.author == 'user' or _is_other_agent_reply(agent_name, event):
|
||||
return _get_contents(current_branch, events[i:], agent_name)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.base_agent import BaseAgentState
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.plugins.plugin_manager import PluginManager
|
||||
@@ -733,39 +734,6 @@ async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
|
||||
[e async for e in agent.run_live(parent_ctx)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_state_event(request: pytest.FixtureRequest):
|
||||
# Arrange
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='test_branch'
|
||||
)
|
||||
state = BaseAgentState()
|
||||
|
||||
# Act
|
||||
event = agent._create_agent_state_event(ctx, state=state)
|
||||
|
||||
# Assert
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
assert event.branch == 'test_branch'
|
||||
assert event.actions is not None
|
||||
assert event.actions.agent_state is not None
|
||||
assert event.actions.agent_state == state.model_dump(mode='json')
|
||||
assert not event.actions.end_of_agent
|
||||
|
||||
# Act
|
||||
event = agent._create_agent_state_event(ctx, end_of_agent=True)
|
||||
|
||||
# Assert
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
assert event.branch == 'test_branch'
|
||||
assert event.actions is not None
|
||||
assert event.actions.end_of_agent
|
||||
assert event.actions.agent_state is None
|
||||
|
||||
|
||||
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
|
||||
sub_agents: list[BaseAgent] = [
|
||||
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
|
||||
@@ -895,7 +863,7 @@ class _TestAgentState(BaseAgentState):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_agent_state_no_resume():
|
||||
async def test_load_agent_state_not_resumable():
|
||||
agent = BaseAgent(name='test_agent')
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
@@ -907,14 +875,15 @@ async def test_load_agent_state_no_resume():
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
default_state = _TestAgentState(test_field='default')
|
||||
|
||||
state, is_resuming = agent._load_agent_state(
|
||||
ctx, _TestAgentState, default_state
|
||||
)
|
||||
# Test case 1: resumability_config is None
|
||||
state = agent._load_agent_state(ctx, _TestAgentState)
|
||||
assert state is None
|
||||
|
||||
assert not is_resuming
|
||||
assert state == default_state
|
||||
# Test case 2: is_resumable is False
|
||||
ctx.resumability_config = ResumabilityConfig(is_resumable=False)
|
||||
state = agent._load_agent_state(ctx, _TestAgentState)
|
||||
assert state is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -929,13 +898,70 @@ async def test_load_agent_state_with_resume():
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
resumability_config=ResumabilityConfig(is_resumable=True),
|
||||
)
|
||||
|
||||
# Test case 1: agent state not in context
|
||||
state = agent._load_agent_state(ctx, _TestAgentState)
|
||||
assert state is None
|
||||
|
||||
# Test case 2: agent state in context
|
||||
persisted_state = _TestAgentState(test_field='resumed')
|
||||
ctx.agent_states[agent.name] = persisted_state.model_dump(mode='json')
|
||||
|
||||
state, is_resuming = agent._load_agent_state(
|
||||
ctx, _TestAgentState, _TestAgentState()
|
||||
state = agent._load_agent_state(ctx, _TestAgentState)
|
||||
assert state == persisted_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_state_event():
|
||||
agent = BaseAgent(name='test_agent')
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
ctx = InvocationContext(
|
||||
invocation_id='test_invocation',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
assert is_resuming
|
||||
assert state == persisted_state
|
||||
ctx.branch = 'test_branch'
|
||||
|
||||
# Test case 1: with state
|
||||
state = _TestAgentState(test_field='checkpoint')
|
||||
event = agent._create_agent_state_event(ctx, agent_state=state)
|
||||
assert event is not None
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
assert event.branch == 'test_branch'
|
||||
assert event.actions is not None
|
||||
assert event.actions.agent_state is not None
|
||||
assert event.actions.agent_state == state.model_dump(mode='json')
|
||||
assert not event.actions.end_of_agent
|
||||
|
||||
# Test case 2: with end_of_agent
|
||||
event = agent._create_agent_state_event(ctx, end_of_agent=True)
|
||||
assert event is not None
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
assert event.branch == 'test_branch'
|
||||
assert event.actions is not None
|
||||
assert event.actions.end_of_agent
|
||||
assert event.actions.agent_state is None
|
||||
|
||||
# Test case 3: with both state and end_of_agent
|
||||
state = _TestAgentState(test_field='checkpoint')
|
||||
event = agent._create_agent_state_event(
|
||||
ctx, agent_state=state, end_of_agent=True
|
||||
)
|
||||
assert event is not None
|
||||
assert event.actions.agent_state == state.model_dump(mode='json')
|
||||
assert event.actions.end_of_agent
|
||||
|
||||
# Test case 4: with neither
|
||||
event = agent._create_agent_state_event(ctx)
|
||||
assert event is not None
|
||||
assert event.actions.agent_state is None
|
||||
assert not event.actions.end_of_agent
|
||||
|
||||
@@ -206,3 +206,22 @@ class TestInvocationContextWithAppResumablity:
|
||||
assert not mock_invocation_context.should_pause_invocation(
|
||||
nonpausable_event
|
||||
)
|
||||
|
||||
def test_is_resumable_true(self):
|
||||
"""Tests that is_resumable is True when resumability is enabled."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
assert invocation_context.is_resumable
|
||||
|
||||
def test_is_resumable_false(self):
|
||||
"""Tests that is_resumable is False when resumability is disabled."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=False)
|
||||
)
|
||||
assert not invocation_context.is_resumable
|
||||
|
||||
def test_is_resumable_no_config(self):
|
||||
"""Tests that is_resumable is False when no resumability config is set."""
|
||||
invocation_context = self._create_test_invocation_context(None)
|
||||
assert not invocation_context.is_resumable
|
||||
|
||||
@@ -219,9 +219,10 @@ async def test_include_contents_none_sequential_agents():
|
||||
runner = testing_utils.InMemoryRunner(sequential_agent)
|
||||
events = runner.run("Original user request")
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[0].author == "agent1"
|
||||
assert events[1].author == "agent2"
|
||||
simplified_events = [event for event in events if event.content]
|
||||
assert len(simplified_events) == 2
|
||||
assert simplified_events[0].author == "agent1"
|
||||
assert simplified_events[1].author == "agent2"
|
||||
|
||||
# Agent1 sees original user request
|
||||
agent1_contents = testing_utils.simplify_contents(
|
||||
|
||||
@@ -19,6 +19,8 @@ from typing import AsyncGenerator
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgentState
|
||||
from google.adk.apps import ResumabilityConfig
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
@@ -54,7 +56,7 @@ class _TestingAgent(BaseAgent):
|
||||
|
||||
|
||||
async def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
test_name: str, agent: BaseAgent, resumable: bool = False
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
@@ -65,6 +67,7 @@ async def _create_parent_invocation_context(
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
resumability_config=ResumabilityConfig(is_resumable=resumable),
|
||||
)
|
||||
|
||||
|
||||
@@ -105,6 +108,78 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest):
|
||||
assert not events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_resumability(request: pytest.FixtureRequest):
|
||||
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
|
||||
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
sequential_agent = SequentialAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
sub_agents=[
|
||||
agent_1,
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent, resumable=True
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_async(parent_ctx)]
|
||||
|
||||
# 5 events:
|
||||
# 1. SequentialAgent checkpoint event for agent 1
|
||||
# 2. Agent 1 event
|
||||
# 3. SequentialAgent checkpoint event for agent 2
|
||||
# 4. Agent 2 event
|
||||
# 5. SequentialAgent final checkpoint event
|
||||
assert len(events) == 5
|
||||
assert events[0].author == sequential_agent.name
|
||||
assert not events[0].actions.end_of_agent
|
||||
assert events[0].actions.agent_state['current_sub_agent'] == agent_1.name
|
||||
|
||||
assert events[1].author == agent_1.name
|
||||
assert events[1].content.parts[0].text == f'Hello, async {agent_1.name}!'
|
||||
|
||||
assert events[2].author == sequential_agent.name
|
||||
assert not events[2].actions.end_of_agent
|
||||
assert events[2].actions.agent_state['current_sub_agent'] == agent_2.name
|
||||
|
||||
assert events[3].author == agent_2.name
|
||||
assert events[3].content.parts[0].text == f'Hello, async {agent_2.name}!'
|
||||
|
||||
assert events[4].author == sequential_agent.name
|
||||
assert events[4].actions.end_of_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_async(request: pytest.FixtureRequest):
|
||||
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
|
||||
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
sequential_agent = SequentialAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
sub_agents=[
|
||||
agent_1,
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = await _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent, resumable=True
|
||||
)
|
||||
parent_ctx.agent_states[sequential_agent.name] = SequentialAgentState(
|
||||
current_sub_agent=agent_2.name
|
||||
).model_dump(mode='json')
|
||||
|
||||
events = [e async for e in sequential_agent.run_async(parent_ctx)]
|
||||
|
||||
# 2 events:
|
||||
# 1. Agent 2 event
|
||||
# 2. SequentialAgent final checkpoint event
|
||||
assert len(events) == 2
|
||||
assert events[0].author == agent_2.name
|
||||
assert events[0].content.parts[0].text == f'Hello, async {agent_2.name}!'
|
||||
|
||||
assert events[1].author == sequential_agent.name
|
||||
assert events[1].actions.end_of_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
|
||||
|
||||
Reference in New Issue
Block a user