You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Set agent_state in invocation context right before yielding the checkpoint event
PiperOrigin-RevId: 816804179
This commit is contained in:
committed by
Copybara-Service
parent
75179243b4
commit
32f2ec3a78
@@ -184,21 +184,19 @@ class BaseAgent(BaseModel):
|
||||
def _create_agent_state_event(
|
||||
self,
|
||||
ctx: InvocationContext,
|
||||
*,
|
||||
agent_state: Optional[BaseAgentState] = None,
|
||||
end_of_agent: bool = False,
|
||||
) -> Event:
|
||||
"""Returns an event with agent state.
|
||||
"""Returns an event with current agent state set in the invocation context.
|
||||
|
||||
Args:
|
||||
ctx: The invocation context.
|
||||
agent_state: The agent state to checkpoint.
|
||||
end_of_agent: Whether the agent is finished running.
|
||||
|
||||
Returns:
|
||||
An event with the current agent state set in the invocation context.
|
||||
"""
|
||||
event_actions = EventActions()
|
||||
if agent_state:
|
||||
event_actions.agent_state = agent_state.model_dump(mode='json')
|
||||
if end_of_agent:
|
||||
if (agent_state := ctx.agent_states.get(self.name)) is not None:
|
||||
event_actions.agent_state = agent_state
|
||||
if ctx.end_of_agents.get(self.name):
|
||||
event_actions.end_of_agent = True
|
||||
return Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
|
||||
@@ -217,10 +217,37 @@ class InvocationContext(BaseModel):
|
||||
and self.resumability_config.is_resumable
|
||||
)
|
||||
|
||||
def reset_agent_state(self, agent_name: str) -> None:
|
||||
"""Resets the state of an agent, allowing it to be re-run."""
|
||||
self.agent_states.pop(agent_name, None)
|
||||
self.end_of_agents.pop(agent_name, None)
|
||||
def set_agent_state(
|
||||
self,
|
||||
agent_name: str,
|
||||
*,
|
||||
agent_state: Optional[BaseAgentState] = None,
|
||||
end_of_agent: bool = False,
|
||||
) -> None:
|
||||
"""Sets the state of an agent in this invocation.
|
||||
|
||||
* If end_of_agent is True, will set the end_of_agent flag to True and
|
||||
clear the agent_state.
|
||||
* Otherwise, if agent_state is not None, will set the agent_state and
|
||||
reset the end_of_agent flag to False.
|
||||
* Otherwise, will clear the agent_state and end_of_agent flag, to allow the
|
||||
agent to re-run.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent.
|
||||
agent_state: The state of the agent. Will be ignored if end_of_agent is
|
||||
True.
|
||||
end_of_agent: Whether the agent has finished running.
|
||||
"""
|
||||
if end_of_agent:
|
||||
self.end_of_agents[agent_name] = True
|
||||
self.agent_states.pop(agent_name, None)
|
||||
elif agent_state is not None:
|
||||
self.agent_states[agent_name] = agent_state.model_dump(mode="json")
|
||||
self.end_of_agents[agent_name] = False
|
||||
else:
|
||||
self.end_of_agents.pop(agent_name, None)
|
||||
self.agent_states.pop(agent_name, None)
|
||||
|
||||
def populate_invocation_agent_states(self) -> None:
|
||||
"""Populates agent states for the current invocation if it is resumable.
|
||||
|
||||
@@ -388,7 +388,8 @@ class LlmAgent(BaseAgent):
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
return
|
||||
|
||||
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
|
||||
@@ -399,7 +400,8 @@ class LlmAgent(BaseAgent):
|
||||
return
|
||||
|
||||
if ctx.is_resumable:
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
|
||||
@@ -91,12 +91,13 @@ class LoopAgent(BaseAgent):
|
||||
current_sub_agent=sub_agent.name,
|
||||
times_looped=times_looped,
|
||||
)
|
||||
yield self._create_agent_state_event(ctx, agent_state=agent_state)
|
||||
ctx.set_agent_state(self.name, agent_state=agent_state)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
# Reset the sub-agent's state in the context to ensure that each
|
||||
# sub-agent starts fresh.
|
||||
if not is_resuming_at_current_agent:
|
||||
ctx.reset_agent_state(sub_agent.name)
|
||||
ctx.set_agent_state(sub_agent.name)
|
||||
is_resuming_at_current_agent = False
|
||||
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
@@ -119,7 +120,8 @@ class LoopAgent(BaseAgent):
|
||||
return
|
||||
|
||||
if ctx.is_resumable:
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
def _get_start_state(
|
||||
self,
|
||||
|
||||
@@ -181,14 +181,15 @@ class ParallelAgent(BaseAgent):
|
||||
|
||||
agent_state = self._load_agent_state(ctx, BaseAgentState)
|
||||
if ctx.is_resumable and agent_state is None:
|
||||
yield self._create_agent_state_event(ctx, agent_state=BaseAgentState())
|
||||
ctx.set_agent_state(self.name, agent_state=BaseAgentState())
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
agent_runs = []
|
||||
# Prepare and collect async generators for each sub-agent.
|
||||
for sub_agent in self.sub_agents:
|
||||
if agent_state is None:
|
||||
# Reset sub-agent state to make sure each sub-agent starts fresh.
|
||||
ctx.reset_agent_state(sub_agent.name)
|
||||
ctx.set_agent_state(sub_agent.name)
|
||||
|
||||
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
|
||||
|
||||
@@ -215,8 +216,11 @@ class ParallelAgent(BaseAgent):
|
||||
return
|
||||
|
||||
# Once all sub-agents are done, mark the ParallelAgent as final.
|
||||
if ctx.is_resumable:
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
if ctx.is_resumable and all(
|
||||
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
|
||||
):
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
finally:
|
||||
for sub_agent_run in agent_runs:
|
||||
|
||||
@@ -70,11 +70,12 @@ class SequentialAgent(BaseAgent):
|
||||
# already been logged, so we should avoid yielding it again.
|
||||
if ctx.is_resumable:
|
||||
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
|
||||
yield self._create_agent_state_event(ctx, agent_state=agent_state)
|
||||
ctx.set_agent_state(self.name, agent_state=agent_state)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
# Reset the sub-agent's state in the context to ensure that each
|
||||
# sub-agent starts fresh.
|
||||
ctx.reset_agent_state(sub_agent.name)
|
||||
ctx.set_agent_state(sub_agent.name)
|
||||
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
@@ -90,7 +91,8 @@ class SequentialAgent(BaseAgent):
|
||||
resuming_sub_agent = False
|
||||
|
||||
if ctx.is_resumable:
|
||||
yield self._create_agent_state_event(ctx, end_of_agent=True)
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
yield self._create_agent_state_event(ctx)
|
||||
|
||||
def _get_start_index(
|
||||
self,
|
||||
|
||||
@@ -929,9 +929,10 @@ async def test_create_agent_state_event():
|
||||
|
||||
ctx.branch = 'test_branch'
|
||||
|
||||
# Test case 1: with state
|
||||
# Test case 1: set agent state in context
|
||||
state = _TestAgentState(test_field='checkpoint')
|
||||
event = agent._create_agent_state_event(ctx, agent_state=state)
|
||||
ctx.set_agent_state(agent.name, agent_state=state)
|
||||
event = agent._create_agent_state_event(ctx)
|
||||
assert event is not None
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
@@ -941,8 +942,9 @@ async def test_create_agent_state_event():
|
||||
assert event.actions.agent_state == state.model_dump(mode='json')
|
||||
assert not event.actions.end_of_agent
|
||||
|
||||
# Test case 2: with end_of_agent
|
||||
event = agent._create_agent_state_event(ctx, end_of_agent=True)
|
||||
# Test case 2: set end_of_agent in context
|
||||
ctx.set_agent_state(agent.name, end_of_agent=True)
|
||||
event = agent._create_agent_state_event(ctx)
|
||||
assert event is not None
|
||||
assert event.invocation_id == ctx.invocation_id
|
||||
assert event.author == agent.name
|
||||
@@ -951,16 +953,8 @@ async def test_create_agent_state_event():
|
||||
assert event.actions.end_of_agent
|
||||
assert event.actions.agent_state is None
|
||||
|
||||
# Test case 3: with both state and end_of_agent
|
||||
state = _TestAgentState(test_field='checkpoint')
|
||||
event = agent._create_agent_state_event(
|
||||
ctx, agent_state=state, end_of_agent=True
|
||||
)
|
||||
assert event is not None
|
||||
assert event.actions.agent_state == state.model_dump(mode='json')
|
||||
assert event.actions.end_of_agent
|
||||
|
||||
# Test case 4: with neither
|
||||
# Test case 3: reset agent state and end_of_agent in context
|
||||
ctx.set_agent_state(agent.name)
|
||||
event = agent._create_agent_state_event(ctx)
|
||||
assert event is not None
|
||||
assert event.actions.agent_state is None
|
||||
|
||||
@@ -347,6 +347,49 @@ class TestInvocationContextWithAppResumablity:
|
||||
assert not invocation_context.agent_states
|
||||
assert not invocation_context.end_of_agents
|
||||
|
||||
def test_set_agent_state_with_end_of_agent_true(self):
|
||||
"""Tests that set_agent_state clears agent_state and sets end_of_agent to True."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
invocation_context.agent_states['agent1'] = {}
|
||||
invocation_context.end_of_agents['agent1'] = False
|
||||
|
||||
# Set state with end_of_agent=True, which should clear the existing
|
||||
# agent_state.
|
||||
invocation_context.set_agent_state('agent1', end_of_agent=True)
|
||||
assert 'agent1' not in invocation_context.agent_states
|
||||
assert invocation_context.end_of_agents['agent1']
|
||||
|
||||
def test_set_agent_state_with_agent_state(self):
|
||||
"""Tests that set_agent_state sets agent_state and sets end_of_agent to False."""
|
||||
agent_state = BaseAgentState()
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
invocation_context.end_of_agents['agent1'] = True
|
||||
|
||||
# Set state with agent_state=agent_state, which should set the agent_state
|
||||
# and reset the end_of_agent flag to False.
|
||||
invocation_context.set_agent_state('agent1', agent_state=agent_state)
|
||||
assert invocation_context.agent_states['agent1'] == agent_state.model_dump(
|
||||
mode='json'
|
||||
)
|
||||
assert invocation_context.end_of_agents['agent1'] is False
|
||||
|
||||
def test_reset_agent_state(self):
|
||||
"""Tests that set_agent_state clears agent_state and end_of_agent."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
invocation_context.agent_states['agent1'] = {}
|
||||
invocation_context.end_of_agents['agent1'] = True
|
||||
|
||||
# Reset state, which should clear the agent_state and end_of_agent flag.
|
||||
invocation_context.set_agent_state('agent1')
|
||||
assert 'agent1' not in invocation_context.agent_states
|
||||
assert 'agent1' not in invocation_context.end_of_agents
|
||||
|
||||
|
||||
class TestFindMatchingFunctionCall:
|
||||
"""Test suite for find_matching_function_call."""
|
||||
|
||||
@@ -52,6 +52,8 @@ class _TestingAgent(BaseAgent):
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
await asyncio.sleep(self.delay)
|
||||
yield self.event(ctx)
|
||||
if ctx.is_resumable:
|
||||
ctx.set_agent_state(self.name, end_of_agent=True)
|
||||
|
||||
|
||||
async def _create_parent_invocation_context(
|
||||
|
||||
Reference in New Issue
Block a user