diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 939334a3..9f473de8 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -184,21 +184,19 @@ class BaseAgent(BaseModel): def _create_agent_state_event( self, ctx: InvocationContext, - *, - agent_state: Optional[BaseAgentState] = None, - end_of_agent: bool = False, ) -> Event: - """Returns an event with agent state. + """Returns an event with current agent state set in the invocation context. Args: ctx: The invocation context. - agent_state: The agent state to checkpoint. - end_of_agent: Whether the agent is finished running. + + Returns: + An event with the current agent state set in the invocation context. """ event_actions = EventActions() - if agent_state: - event_actions.agent_state = agent_state.model_dump(mode='json') - if end_of_agent: + if (agent_state := ctx.agent_states.get(self.name)) is not None: + event_actions.agent_state = agent_state + if ctx.end_of_agents.get(self.name): event_actions.end_of_agent = True return Event( invocation_id=ctx.invocation_id, diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 66ecbaf4..de18e586 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -217,10 +217,37 @@ class InvocationContext(BaseModel): and self.resumability_config.is_resumable ) - def reset_agent_state(self, agent_name: str) -> None: - """Resets the state of an agent, allowing it to be re-run.""" - self.agent_states.pop(agent_name, None) - self.end_of_agents.pop(agent_name, None) + def set_agent_state( + self, + agent_name: str, + *, + agent_state: Optional[BaseAgentState] = None, + end_of_agent: bool = False, + ) -> None: + """Sets the state of an agent in this invocation. + + * If end_of_agent is True, will set the end_of_agent flag to True and + clear the agent_state. + * Otherwise, if agent_state is not None, will set the agent_state and + reset the end_of_agent flag to False. + * Otherwise, will clear the agent_state and end_of_agent flag, to allow the + agent to re-run. + + Args: + agent_name: The name of the agent. + agent_state: The state of the agent. Will be ignored if end_of_agent is + True. + end_of_agent: Whether the agent has finished running. + """ + if end_of_agent: + self.end_of_agents[agent_name] = True + self.agent_states.pop(agent_name, None) + elif agent_state is not None: + self.agent_states[agent_name] = agent_state.model_dump(mode="json") + self.end_of_agents[agent_name] = False + else: + self.end_of_agents.pop(agent_name, None) + self.agent_states.pop(agent_name, None) def populate_invocation_agent_states(self) -> None: """Populates agent states for the current invocation if it is resumable. diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 21c82774..3bfbfea0 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -388,7 +388,8 @@ class LlmAgent(BaseAgent): async for event in agen: yield event - yield self._create_agent_state_event(ctx, end_of_agent=True) + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) return async with Aclosing(self._llm_flow.run_async(ctx)) as agen: @@ -399,7 +400,8 @@ class LlmAgent(BaseAgent): return if ctx.is_resumable: - yield self._create_agent_state_event(ctx, end_of_agent=True) + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) @override async def _run_live_impl( diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 9965f18f..6b8b7435 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -91,12 +91,13 @@ class LoopAgent(BaseAgent): current_sub_agent=sub_agent.name, times_looped=times_looped, ) - yield self._create_agent_state_event(ctx, agent_state=agent_state) + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) # Reset the sub-agent's state in the context to ensure that each # sub-agent starts fresh. if not is_resuming_at_current_agent: - ctx.reset_agent_state(sub_agent.name) + ctx.set_agent_state(sub_agent.name) is_resuming_at_current_agent = False async with Aclosing(sub_agent.run_async(ctx)) as agen: @@ -119,7 +120,8 @@ class LoopAgent(BaseAgent): return if ctx.is_resumable: - yield self._create_agent_state_event(ctx, end_of_agent=True) + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) def _get_start_state( self, diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index 45dabfea..69ccd8a4 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -181,14 +181,15 @@ class ParallelAgent(BaseAgent): agent_state = self._load_agent_state(ctx, BaseAgentState) if ctx.is_resumable and agent_state is None: - yield self._create_agent_state_event(ctx, agent_state=BaseAgentState()) + ctx.set_agent_state(self.name, agent_state=BaseAgentState()) + yield self._create_agent_state_event(ctx) agent_runs = [] # Prepare and collect async generators for each sub-agent. for sub_agent in self.sub_agents: if agent_state is None: # Reset sub-agent state to make sure each sub-agent starts fresh. - ctx.reset_agent_state(sub_agent.name) + ctx.set_agent_state(sub_agent.name) sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx) @@ -215,8 +216,11 @@ class ParallelAgent(BaseAgent): return # Once all sub-agents are done, mark the ParallelAgent as final. - if ctx.is_resumable: - yield self._create_agent_state_event(ctx, end_of_agent=True) + if ctx.is_resumable and all( + ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents + ): + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) finally: for sub_agent_run in agent_runs: diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 5cc5b654..417008f9 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -70,11 +70,12 @@ class SequentialAgent(BaseAgent): # already been logged, so we should avoid yielding it again. if ctx.is_resumable: agent_state = SequentialAgentState(current_sub_agent=sub_agent.name) - yield self._create_agent_state_event(ctx, agent_state=agent_state) + ctx.set_agent_state(self.name, agent_state=agent_state) + yield self._create_agent_state_event(ctx) # Reset the sub-agent's state in the context to ensure that each # sub-agent starts fresh. - ctx.reset_agent_state(sub_agent.name) + ctx.set_agent_state(sub_agent.name) async with Aclosing(sub_agent.run_async(ctx)) as agen: async for event in agen: @@ -90,7 +91,8 @@ class SequentialAgent(BaseAgent): resuming_sub_agent = False if ctx.is_resumable: - yield self._create_agent_state_event(ctx, end_of_agent=True) + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) def _get_start_index( self, diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 176fd5f8..663179f6 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -929,9 +929,10 @@ async def test_create_agent_state_event(): ctx.branch = 'test_branch' - # Test case 1: with state + # Test case 1: set agent state in context state = _TestAgentState(test_field='checkpoint') - event = agent._create_agent_state_event(ctx, agent_state=state) + ctx.set_agent_state(agent.name, agent_state=state) + event = agent._create_agent_state_event(ctx) assert event is not None assert event.invocation_id == ctx.invocation_id assert event.author == agent.name @@ -941,8 +942,9 @@ async def test_create_agent_state_event(): assert event.actions.agent_state == state.model_dump(mode='json') assert not event.actions.end_of_agent - # Test case 2: with end_of_agent - event = agent._create_agent_state_event(ctx, end_of_agent=True) + # Test case 2: set end_of_agent in context + ctx.set_agent_state(agent.name, end_of_agent=True) + event = agent._create_agent_state_event(ctx) assert event is not None assert event.invocation_id == ctx.invocation_id assert event.author == agent.name @@ -951,16 +953,8 @@ async def test_create_agent_state_event(): assert event.actions.end_of_agent assert event.actions.agent_state is None - # Test case 3: with both state and end_of_agent - state = _TestAgentState(test_field='checkpoint') - event = agent._create_agent_state_event( - ctx, agent_state=state, end_of_agent=True - ) - assert event is not None - assert event.actions.agent_state == state.model_dump(mode='json') - assert event.actions.end_of_agent - - # Test case 4: with neither + # Test case 3: reset agent state and end_of_agent in context + ctx.set_agent_state(agent.name) event = agent._create_agent_state_event(ctx) assert event is not None assert event.actions.agent_state is None diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index d2933483..2cdf38c9 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -347,6 +347,49 @@ class TestInvocationContextWithAppResumablity: assert not invocation_context.agent_states assert not invocation_context.end_of_agents + def test_set_agent_state_with_end_of_agent_true(self): + """Tests that set_agent_state clears agent_state and sets end_of_agent to True.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + invocation_context.agent_states['agent1'] = {} + invocation_context.end_of_agents['agent1'] = False + + # Set state with end_of_agent=True, which should clear the existing + # agent_state. + invocation_context.set_agent_state('agent1', end_of_agent=True) + assert 'agent1' not in invocation_context.agent_states + assert invocation_context.end_of_agents['agent1'] + + def test_set_agent_state_with_agent_state(self): + """Tests that set_agent_state sets agent_state and sets end_of_agent to False.""" + agent_state = BaseAgentState() + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + invocation_context.end_of_agents['agent1'] = True + + # Set state with agent_state=agent_state, which should set the agent_state + # and reset the end_of_agent flag to False. + invocation_context.set_agent_state('agent1', agent_state=agent_state) + assert invocation_context.agent_states['agent1'] == agent_state.model_dump( + mode='json' + ) + assert invocation_context.end_of_agents['agent1'] is False + + def test_reset_agent_state(self): + """Tests that set_agent_state clears agent_state and end_of_agent.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + invocation_context.agent_states['agent1'] = {} + invocation_context.end_of_agents['agent1'] = True + + # Reset state, which should clear the agent_state and end_of_agent flag. + invocation_context.set_agent_state('agent1') + assert 'agent1' not in invocation_context.agent_states + assert 'agent1' not in invocation_context.end_of_agents + class TestFindMatchingFunctionCall: """Test suite for find_matching_function_call.""" diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index 30965808..5b6c046f 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -52,6 +52,8 @@ class _TestingAgent(BaseAgent): ) -> AsyncGenerator[Event, None]: await asyncio.sleep(self.delay) yield self.event(ctx) + if ctx.is_resumable: + ctx.set_agent_state(self.name, end_of_agent=True) async def _create_parent_invocation_context(