feat: Set agent_state in invocation context right before yielding the checkpoint event

PiperOrigin-RevId: 816804179
This commit is contained in:
Xinran (Sherry) Tang
2025-10-08 12:01:27 -07:00
committed by Copybara-Service
parent 75179243b4
commit 32f2ec3a78
9 changed files with 113 additions and 39 deletions
+7 -9
View File
@@ -184,21 +184,19 @@ class BaseAgent(BaseModel):
def _create_agent_state_event(
self,
ctx: InvocationContext,
*,
agent_state: Optional[BaseAgentState] = None,
end_of_agent: bool = False,
) -> Event:
"""Returns an event with agent state.
"""Returns an event with current agent state set in the invocation context.
Args:
ctx: The invocation context.
agent_state: The agent state to checkpoint.
end_of_agent: Whether the agent is finished running.
Returns:
An event with the current agent state set in the invocation context.
"""
event_actions = EventActions()
if agent_state:
event_actions.agent_state = agent_state.model_dump(mode='json')
if end_of_agent:
if (agent_state := ctx.agent_states.get(self.name)) is not None:
event_actions.agent_state = agent_state
if ctx.end_of_agents.get(self.name):
event_actions.end_of_agent = True
return Event(
invocation_id=ctx.invocation_id,
+31 -4
View File
@@ -217,10 +217,37 @@ class InvocationContext(BaseModel):
and self.resumability_config.is_resumable
)
def reset_agent_state(self, agent_name: str) -> None:
"""Resets the state of an agent, allowing it to be re-run."""
self.agent_states.pop(agent_name, None)
self.end_of_agents.pop(agent_name, None)
def set_agent_state(
self,
agent_name: str,
*,
agent_state: Optional[BaseAgentState] = None,
end_of_agent: bool = False,
) -> None:
"""Sets the state of an agent in this invocation.
* If end_of_agent is True, will set the end_of_agent flag to True and
clear the agent_state.
* Otherwise, if agent_state is not None, will set the agent_state and
reset the end_of_agent flag to False.
* Otherwise, will clear the agent_state and end_of_agent flag, to allow the
agent to re-run.
Args:
agent_name: The name of the agent.
agent_state: The state of the agent. Will be ignored if end_of_agent is
True.
end_of_agent: Whether the agent has finished running.
"""
if end_of_agent:
self.end_of_agents[agent_name] = True
self.agent_states.pop(agent_name, None)
elif agent_state is not None:
self.agent_states[agent_name] = agent_state.model_dump(mode="json")
self.end_of_agents[agent_name] = False
else:
self.end_of_agents.pop(agent_name, None)
self.agent_states.pop(agent_name, None)
def populate_invocation_agent_states(self) -> None:
"""Populates agent states for the current invocation if it is resumable.
+4 -2
View File
@@ -388,7 +388,8 @@ class LlmAgent(BaseAgent):
async for event in agen:
yield event
yield self._create_agent_state_event(ctx, end_of_agent=True)
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
return
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
@@ -399,7 +400,8 @@ class LlmAgent(BaseAgent):
return
if ctx.is_resumable:
yield self._create_agent_state_event(ctx, end_of_agent=True)
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
@override
async def _run_live_impl(
+5 -3
View File
@@ -91,12 +91,13 @@ class LoopAgent(BaseAgent):
current_sub_agent=sub_agent.name,
times_looped=times_looped,
)
yield self._create_agent_state_event(ctx, agent_state=agent_state)
ctx.set_agent_state(self.name, agent_state=agent_state)
yield self._create_agent_state_event(ctx)
# Reset the sub-agent's state in the context to ensure that each
# sub-agent starts fresh.
if not is_resuming_at_current_agent:
ctx.reset_agent_state(sub_agent.name)
ctx.set_agent_state(sub_agent.name)
is_resuming_at_current_agent = False
async with Aclosing(sub_agent.run_async(ctx)) as agen:
@@ -119,7 +120,8 @@ class LoopAgent(BaseAgent):
return
if ctx.is_resumable:
yield self._create_agent_state_event(ctx, end_of_agent=True)
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
def _get_start_state(
self,
+8 -4
View File
@@ -181,14 +181,15 @@ class ParallelAgent(BaseAgent):
agent_state = self._load_agent_state(ctx, BaseAgentState)
if ctx.is_resumable and agent_state is None:
yield self._create_agent_state_event(ctx, agent_state=BaseAgentState())
ctx.set_agent_state(self.name, agent_state=BaseAgentState())
yield self._create_agent_state_event(ctx)
agent_runs = []
# Prepare and collect async generators for each sub-agent.
for sub_agent in self.sub_agents:
if agent_state is None:
# Reset sub-agent state to make sure each sub-agent starts fresh.
ctx.reset_agent_state(sub_agent.name)
ctx.set_agent_state(sub_agent.name)
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
@@ -215,8 +216,11 @@ class ParallelAgent(BaseAgent):
return
# Once all sub-agents are done, mark the ParallelAgent as final.
if ctx.is_resumable:
yield self._create_agent_state_event(ctx, end_of_agent=True)
if ctx.is_resumable and all(
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
):
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
finally:
for sub_agent_run in agent_runs:
+5 -3
View File
@@ -70,11 +70,12 @@ class SequentialAgent(BaseAgent):
# already been logged, so we should avoid yielding it again.
if ctx.is_resumable:
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
yield self._create_agent_state_event(ctx, agent_state=agent_state)
ctx.set_agent_state(self.name, agent_state=agent_state)
yield self._create_agent_state_event(ctx)
# Reset the sub-agent's state in the context to ensure that each
# sub-agent starts fresh.
ctx.reset_agent_state(sub_agent.name)
ctx.set_agent_state(sub_agent.name)
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
@@ -90,7 +91,8 @@ class SequentialAgent(BaseAgent):
resuming_sub_agent = False
if ctx.is_resumable:
yield self._create_agent_state_event(ctx, end_of_agent=True)
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
def _get_start_index(
self,
+8 -14
View File
@@ -929,9 +929,10 @@ async def test_create_agent_state_event():
ctx.branch = 'test_branch'
# Test case 1: with state
# Test case 1: set agent state in context
state = _TestAgentState(test_field='checkpoint')
event = agent._create_agent_state_event(ctx, agent_state=state)
ctx.set_agent_state(agent.name, agent_state=state)
event = agent._create_agent_state_event(ctx)
assert event is not None
assert event.invocation_id == ctx.invocation_id
assert event.author == agent.name
@@ -941,8 +942,9 @@ async def test_create_agent_state_event():
assert event.actions.agent_state == state.model_dump(mode='json')
assert not event.actions.end_of_agent
# Test case 2: with end_of_agent
event = agent._create_agent_state_event(ctx, end_of_agent=True)
# Test case 2: set end_of_agent in context
ctx.set_agent_state(agent.name, end_of_agent=True)
event = agent._create_agent_state_event(ctx)
assert event is not None
assert event.invocation_id == ctx.invocation_id
assert event.author == agent.name
@@ -951,16 +953,8 @@ async def test_create_agent_state_event():
assert event.actions.end_of_agent
assert event.actions.agent_state is None
# Test case 3: with both state and end_of_agent
state = _TestAgentState(test_field='checkpoint')
event = agent._create_agent_state_event(
ctx, agent_state=state, end_of_agent=True
)
assert event is not None
assert event.actions.agent_state == state.model_dump(mode='json')
assert event.actions.end_of_agent
# Test case 4: with neither
# Test case 3: reset agent state and end_of_agent in context
ctx.set_agent_state(agent.name)
event = agent._create_agent_state_event(ctx)
assert event is not None
assert event.actions.agent_state is None
@@ -347,6 +347,49 @@ class TestInvocationContextWithAppResumablity:
assert not invocation_context.agent_states
assert not invocation_context.end_of_agents
def test_set_agent_state_with_end_of_agent_true(self):
"""Tests that set_agent_state clears agent_state and sets end_of_agent to True."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
invocation_context.agent_states['agent1'] = {}
invocation_context.end_of_agents['agent1'] = False
# Set state with end_of_agent=True, which should clear the existing
# agent_state.
invocation_context.set_agent_state('agent1', end_of_agent=True)
assert 'agent1' not in invocation_context.agent_states
assert invocation_context.end_of_agents['agent1']
def test_set_agent_state_with_agent_state(self):
"""Tests that set_agent_state sets agent_state and sets end_of_agent to False."""
agent_state = BaseAgentState()
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
invocation_context.end_of_agents['agent1'] = True
# Set state with agent_state=agent_state, which should set the agent_state
# and reset the end_of_agent flag to False.
invocation_context.set_agent_state('agent1', agent_state=agent_state)
assert invocation_context.agent_states['agent1'] == agent_state.model_dump(
mode='json'
)
assert invocation_context.end_of_agents['agent1'] is False
def test_reset_agent_state(self):
"""Tests that set_agent_state clears agent_state and end_of_agent."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
invocation_context.agent_states['agent1'] = {}
invocation_context.end_of_agents['agent1'] = True
# Reset state, which should clear the agent_state and end_of_agent flag.
invocation_context.set_agent_state('agent1')
assert 'agent1' not in invocation_context.agent_states
assert 'agent1' not in invocation_context.end_of_agents
class TestFindMatchingFunctionCall:
"""Test suite for find_matching_function_call."""
@@ -52,6 +52,8 @@ class _TestingAgent(BaseAgent):
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self.event(ctx)
if ctx.is_resumable:
ctx.set_agent_state(self.name, end_of_agent=True)
async def _create_parent_invocation_context(