diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 28d513d2..939334a3 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -163,44 +163,41 @@ class BaseAgent(BaseModel): self, ctx: InvocationContext, state_type: Type[AgentState], - default_state: AgentState, - ) -> tuple[AgentState, bool]: + ) -> Optional[AgentState]: """Loads the agent state from the invocation context, handling resumption. Args: ctx: The invocation context. state_type: The type of the agent state. - default_state: The default state to use if not resuming. Returns: - tuple[AgentState, bool]: The current state and a boolean indicating if - resuming. + The current state if resuming, otherwise None. """ + if not ctx.is_resumable: + return None + if self.name not in ctx.agent_states: - return default_state, False + return None else: - return state_type.model_validate(ctx.agent_states.get(self.name)), True + return state_type.model_validate(ctx.agent_states.get(self.name)) def _create_agent_state_event( self, ctx: InvocationContext, *, - state: Optional[BaseAgentState] = None, + agent_state: Optional[BaseAgentState] = None, end_of_agent: bool = False, ) -> Event: - """Creates an event for agent state. + """Returns an event with agent state. Args: ctx: The invocation context. - state: The agent state to checkpoint. + agent_state: The agent state to checkpoint. end_of_agent: Whether the agent is finished running. - - Returns: - An Event object representing the checkpoint. """ event_actions = EventActions() - if state: - event_actions.agent_state = state.model_dump(mode='json') + if agent_state: + event_actions.agent_state = agent_state.model_dump(mode='json') if end_of_agent: event_actions.end_of_agent = True return Event( diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 3cd7d9dd..27c4b6db 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -208,6 +208,14 @@ class InvocationContext(BaseModel): of this invocation. """ + @property + def is_resumable(self) -> bool: + """Returns whether the current invocation is resumable.""" + return ( + self.resumability_config is not None + and self.resumability_config.is_resumable + ) + def reset_agent_state(self, agent_name: str) -> None: """Resets the state of an agent, allowing it to be re-run.""" self.agent_states.pop(agent_name, None) @@ -284,10 +292,7 @@ class InvocationContext(BaseModel): Returns: Whether to pause the invocation right after this event. """ - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): + if not self.is_resumable: return False if not event.long_running_tool_ids or not event.get_function_calls(): diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 077412c3..5cc5b654 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -16,6 +16,7 @@ from __future__ import annotations +import logging from typing import AsyncGenerator from typing import ClassVar from typing import Type @@ -32,6 +33,8 @@ from .invocation_context import InvocationContext from .llm_agent import LlmAgent from .sequential_agent_config import SequentialAgentConfig +logger = logging.getLogger('google_adk.' + __name__) + @experimental class SequentialAgentState(BaseAgentState): @@ -51,12 +54,27 @@ class SequentialAgent(BaseAgent): async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: - # Skip if there is no sub-agent. if not self.sub_agents: return - for sub_agent in self.sub_agents: - pause_invocation = False + # Initialize or resume the execution state from the agent state. + agent_state = self._load_agent_state(ctx, SequentialAgentState) + start_index = self._get_start_index(agent_state) + + pause_invocation = False + resuming_sub_agent = agent_state is not None + for i in range(start_index, len(self.sub_agents)): + sub_agent = self.sub_agents[i] + if not resuming_sub_agent: + # If we are resuming from the current event, it means the same event has + # already been logged, so we should avoid yielding it again. + if ctx.is_resumable: + agent_state = SequentialAgentState(current_sub_agent=sub_agent.name) + yield self._create_agent_state_event(ctx, agent_state=agent_state) + + # Reset the sub-agent's state in the context to ensure that each + # sub-agent starts fresh. + ctx.reset_agent_state(sub_agent.name) async with Aclosing(sub_agent.run_async(ctx)) as agen: async for event in agen: @@ -64,11 +82,41 @@ class SequentialAgent(BaseAgent): if ctx.should_pause_invocation(event): pause_invocation = True - # Indicates the invocation should pause when receiving signal from - # the current sub_agent. + # Skip the rest of the sub-agents if the invocation is paused. if pause_invocation: return + # Reset the flag for the next sub-agent. + resuming_sub_agent = False + + if ctx.is_resumable: + yield self._create_agent_state_event(ctx, end_of_agent=True) + + def _get_start_index( + self, + agent_state: SequentialAgentState, + ) -> int: + """Calculates the start index for the sub-agent loop.""" + if not agent_state: + return 0 + + if not agent_state.current_sub_agent: + # This means the process was finished. + return len(self.sub_agents) + + try: + sub_agent_names = [sub_agent.name for sub_agent in self.sub_agents] + return sub_agent_names.index(agent_state.current_sub_agent) + except ValueError: + # A sub-agent was removed so the agent name is not found. + # For now, we restart from the beginning. + logger.warning( + 'Sub-agent %s was removed so the agent name is not found. Restarting' + ' from the beginning.', + agent_state.current_sub_agent, + ) + return 0 + @override async def _run_live_impl( self, ctx: InvocationContext diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index d6355ca3..b13c7cc8 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -350,6 +350,8 @@ def _get_current_turn_contents( # Find the latest event that starts the current turn and process from there for i in range(len(events) - 1, -1, -1): event = events[i] + if not event.content: + continue if event.author == 'user' or _is_other_agent_reply(agent_name, event): return _get_contents(current_branch, events[i:], agent_name) diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index e7ccc451..176fd5f8 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -26,6 +26,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgentState from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext +from google.adk.apps.app import ResumabilityConfig from google.adk.events.event import Event from google.adk.plugins.base_plugin import BasePlugin from google.adk.plugins.plugin_manager import PluginManager @@ -733,39 +734,6 @@ async def test_run_live_incomplete_agent(request: pytest.FixtureRequest): [e async for e in agent.run_live(parent_ctx)] -@pytest.mark.asyncio -async def test_create_agent_state_event(request: pytest.FixtureRequest): - # Arrange - agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') - ctx = await _create_parent_invocation_context( - request.function.__name__, agent, branch='test_branch' - ) - state = BaseAgentState() - - # Act - event = agent._create_agent_state_event(ctx, state=state) - - # Assert - assert event.invocation_id == ctx.invocation_id - assert event.author == agent.name - assert event.branch == 'test_branch' - assert event.actions is not None - assert event.actions.agent_state is not None - assert event.actions.agent_state == state.model_dump(mode='json') - assert not event.actions.end_of_agent - - # Act - event = agent._create_agent_state_event(ctx, end_of_agent=True) - - # Assert - assert event.invocation_id == ctx.invocation_id - assert event.author == agent.name - assert event.branch == 'test_branch' - assert event.actions is not None - assert event.actions.end_of_agent - assert event.actions.agent_state is None - - def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest): sub_agents: list[BaseAgent] = [ _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), @@ -895,7 +863,7 @@ class _TestAgentState(BaseAgentState): @pytest.mark.asyncio -async def test_load_agent_state_no_resume(): +async def test_load_agent_state_not_resumable(): agent = BaseAgent(name='test_agent') session_service = InMemorySessionService() session = await session_service.create_session( @@ -907,14 +875,15 @@ async def test_load_agent_state_no_resume(): session=session, session_service=session_service, ) - default_state = _TestAgentState(test_field='default') - state, is_resuming = agent._load_agent_state( - ctx, _TestAgentState, default_state - ) + # Test case 1: resumability_config is None + state = agent._load_agent_state(ctx, _TestAgentState) + assert state is None - assert not is_resuming - assert state == default_state + # Test case 2: is_resumable is False + ctx.resumability_config = ResumabilityConfig(is_resumable=False) + state = agent._load_agent_state(ctx, _TestAgentState) + assert state is None @pytest.mark.asyncio @@ -929,13 +898,70 @@ async def test_load_agent_state_with_resume(): agent=agent, session=session, session_service=session_service, + resumability_config=ResumabilityConfig(is_resumable=True), ) + + # Test case 1: agent state not in context + state = agent._load_agent_state(ctx, _TestAgentState) + assert state is None + + # Test case 2: agent state in context persisted_state = _TestAgentState(test_field='resumed') ctx.agent_states[agent.name] = persisted_state.model_dump(mode='json') - state, is_resuming = agent._load_agent_state( - ctx, _TestAgentState, _TestAgentState() + state = agent._load_agent_state(ctx, _TestAgentState) + assert state == persisted_state + + +@pytest.mark.asyncio +async def test_create_agent_state_event(): + agent = BaseAgent(name='test_agent') + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + ctx = InvocationContext( + invocation_id='test_invocation', + agent=agent, + session=session, + session_service=session_service, ) - assert is_resuming - assert state == persisted_state + ctx.branch = 'test_branch' + + # Test case 1: with state + state = _TestAgentState(test_field='checkpoint') + event = agent._create_agent_state_event(ctx, agent_state=state) + assert event is not None + assert event.invocation_id == ctx.invocation_id + assert event.author == agent.name + assert event.branch == 'test_branch' + assert event.actions is not None + assert event.actions.agent_state is not None + assert event.actions.agent_state == state.model_dump(mode='json') + assert not event.actions.end_of_agent + + # Test case 2: with end_of_agent + event = agent._create_agent_state_event(ctx, end_of_agent=True) + assert event is not None + assert event.invocation_id == ctx.invocation_id + assert event.author == agent.name + assert event.branch == 'test_branch' + assert event.actions is not None + assert event.actions.end_of_agent + assert event.actions.agent_state is None + + # Test case 3: with both state and end_of_agent + state = _TestAgentState(test_field='checkpoint') + event = agent._create_agent_state_event( + ctx, agent_state=state, end_of_agent=True + ) + assert event is not None + assert event.actions.agent_state == state.model_dump(mode='json') + assert event.actions.end_of_agent + + # Test case 4: with neither + event = agent._create_agent_state_event(ctx) + assert event is not None + assert event.actions.agent_state is None + assert not event.actions.end_of_agent diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 32a1f9bd..6379a760 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -206,3 +206,22 @@ class TestInvocationContextWithAppResumablity: assert not mock_invocation_context.should_pause_invocation( nonpausable_event ) + + def test_is_resumable_true(self): + """Tests that is_resumable is True when resumability is enabled.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + assert invocation_context.is_resumable + + def test_is_resumable_false(self): + """Tests that is_resumable is False when resumability is disabled.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=False) + ) + assert not invocation_context.is_resumable + + def test_is_resumable_no_config(self): + """Tests that is_resumable is False when no resumability config is set.""" + invocation_context = self._create_test_invocation_context(None) + assert not invocation_context.is_resumable diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py index d4d76cf4..851474fc 100644 --- a/tests/unittests/agents/test_llm_agent_include_contents.py +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -219,9 +219,10 @@ async def test_include_contents_none_sequential_agents(): runner = testing_utils.InMemoryRunner(sequential_agent) events = runner.run("Original user request") - assert len(events) == 2 - assert events[0].author == "agent1" - assert events[1].author == "agent2" + simplified_events = [event for event in events if event.content] + assert len(simplified_events) == 2 + assert simplified_events[0].author == "agent1" + assert simplified_events[1].author == "agent2" # Agent1 sees original user request agent1_contents = testing_utils.simplify_contents( diff --git a/tests/unittests/agents/test_sequential_agent.py b/tests/unittests/agents/test_sequential_agent.py index 56af33f5..9703e0ca 100644 --- a/tests/unittests/agents/test_sequential_agent.py +++ b/tests/unittests/agents/test_sequential_agent.py @@ -19,6 +19,8 @@ from typing import AsyncGenerator from google.adk.agents.base_agent import BaseAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.agents.sequential_agent import SequentialAgentState +from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.genai import types @@ -54,7 +56,7 @@ class _TestingAgent(BaseAgent): async def _create_parent_invocation_context( - test_name: str, agent: BaseAgent + test_name: str, agent: BaseAgent, resumable: bool = False ) -> InvocationContext: session_service = InMemorySessionService() session = await session_service.create_session( @@ -65,6 +67,7 @@ async def _create_parent_invocation_context( agent=agent, session=session, session_service=session_service, + resumability_config=ResumabilityConfig(is_resumable=resumable), ) @@ -105,6 +108,78 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest): assert not events +@pytest.mark.asyncio +async def test_run_async_with_resumability(request: pytest.FixtureRequest): + agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') + agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent, resumable=True + ) + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # 5 events: + # 1. SequentialAgent checkpoint event for agent 1 + # 2. Agent 1 event + # 3. SequentialAgent checkpoint event for agent 2 + # 4. Agent 2 event + # 5. SequentialAgent final checkpoint event + assert len(events) == 5 + assert events[0].author == sequential_agent.name + assert not events[0].actions.end_of_agent + assert events[0].actions.agent_state['current_sub_agent'] == agent_1.name + + assert events[1].author == agent_1.name + assert events[1].content.parts[0].text == f'Hello, async {agent_1.name}!' + + assert events[2].author == sequential_agent.name + assert not events[2].actions.end_of_agent + assert events[2].actions.agent_state['current_sub_agent'] == agent_2.name + + assert events[3].author == agent_2.name + assert events[3].content.parts[0].text == f'Hello, async {agent_2.name}!' + + assert events[4].author == sequential_agent.name + assert events[4].actions.end_of_agent + + +@pytest.mark.asyncio +async def test_resume_async(request: pytest.FixtureRequest): + agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1') + agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2') + sequential_agent = SequentialAgent( + name=f'{request.function.__name__}_test_agent', + sub_agents=[ + agent_1, + agent_2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, sequential_agent, resumable=True + ) + parent_ctx.agent_states[sequential_agent.name] = SequentialAgentState( + current_sub_agent=agent_2.name + ).model_dump(mode='json') + + events = [e async for e in sequential_agent.run_async(parent_ctx)] + + # 2 events: + # 1. Agent 2 event + # 2. SequentialAgent final checkpoint event + assert len(events) == 2 + assert events[0].author == agent_2.name + assert events[0].content.parts[0].text == f'Hello, async {agent_2.name}!' + + assert events[1].author == sequential_agent.name + assert events[1].actions.end_of_agent + + @pytest.mark.asyncio async def test_run_live(request: pytest.FixtureRequest): agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')