From fbf75761bb8d89a70b32c43bbd3fa2f48b81d67c Mon Sep 17 00:00:00 2001 From: "Xinran (Sherry) Tang" Date: Mon, 29 Sep 2025 17:34:44 -0700 Subject: [PATCH] feat: Modify runner to support resuming an invocation (optionally with a function response) PiperOrigin-RevId: 813008406 --- src/google/adk/agents/invocation_context.py | 32 +++ src/google/adk/runners.py | 134 ++++++++++-- .../agents/test_invocation_context.py | 120 +++++++++++ .../runners/test_run_tool_confirmation.py | 199 +++++++++++++++++- tests/unittests/testing_utils.py | 5 +- 5 files changed, 473 insertions(+), 17 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 646fdf1a..66ecbaf4 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -34,6 +34,7 @@ from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session from .active_streaming_tool import ActiveStreamingTool from .base_agent import BaseAgent +from .base_agent import BaseAgentState from .context_cache_config import ContextCacheConfig from .live_request_queue import LiveRequestQueue from .run_config import RunConfig @@ -221,6 +222,37 @@ class InvocationContext(BaseModel): self.agent_states.pop(agent_name, None) self.end_of_agents.pop(agent_name, None) + def populate_invocation_agent_states(self) -> None: + """Populates agent states for the current invocation if it is resumable. + + For history events that contain agent state information, set the + agent_state and end_of_agent of the agent that generated the event. + + For non-workflow agents, also set an initial agent_state if it has + already generated some contents. 
+ """ + if not self.is_resumable: + return + for event in self._get_events(current_invocation=True): + if event.actions.end_of_agent: + self.end_of_agents[event.author] = True + # Delete agent_state when it is end + self.agent_states.pop(event.author, None) + elif event.actions.agent_state is not None: + self.agent_states[event.author] = event.actions.agent_state + # Invalidate the end_of_agent flag + self.end_of_agents[event.author] = False + elif ( + event.author != "user" + and event.content + and not self.agent_states.get(event.author) + ): + # If the agent has generated some contents but its agent_state is not + # set, set its agent_state to an empty agent_state. + self.agent_states[event.author] = BaseAgentState() + # Invalidate the end_of_agent flag + self.end_of_agents[event.author] = False + def increment_llm_call_count( self, ): diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 85714636..e7066a0c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -29,6 +29,7 @@ from google.genai import types from .agents.active_streaming_tool import ActiveStreamingTool from .agents.base_agent import BaseAgent +from .agents.base_agent import BaseAgentState from .agents.context_cache_config import ContextCacheConfig from .agents.invocation_context import InvocationContext from .agents.invocation_context import new_invocation_context_id @@ -272,7 +273,8 @@ class Runner: *, user_id: str, session_id: str, - new_message: types.Content, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, ) -> AsyncGenerator[Event, None]: @@ -281,6 +283,8 @@ class Runner: Args: user_id: The user ID of the session. session_id: The session ID of the session. + invocation_id: The invocation ID of the session, set this to resume an + interrupted invocation. new_message: A new message to append to the session. 
state_delta: Optional state changes to apply to the session. run_config: The run config for the agent. @@ -289,15 +293,17 @@ class Runner: The events generated by the agent. Raises: - ValueError: If the session is not found. + ValueError: If the session is not found; If both invocation_id and + new_message are None. """ run_config = run_config or RunConfig() - if not new_message.role: + if new_message and not new_message.role: new_message.role = 'user' async def _run_with_trace( - new_message: types.Content, + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: with tracer.start_as_current_span('invocation'): session = await self.session_service.get_session( @@ -305,13 +311,39 @@ class Runner: ) if not session: raise ValueError(f'Session not found: {session_id}') + if not invocation_id and not new_message: + raise ValueError('Both invocation_id and new_message are None.') - invocation_context = await self._setup_context_for_new_invocation( - session=session, - new_message=new_message, - run_config=run_config, - state_delta=state_delta, - ) + if invocation_id: + if ( + not self.resumability_config + or not self.resumability_config.is_resumable + ): + raise ValueError( + f'invocation_id: {invocation_id} is provided but the app is not' + ' resumable.' + ) + invocation_context = await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, + ) + if invocation_context.end_of_agents.get(self.agent.name): + # Directly return if the root agent has already ended. + # TODO: Handle the case where the invocation-to-resume started from + # a sub_agent: + # invocation1: root_agent -> sub_agent1 + # invocation2: sub_agent1 [paused][resume] + return + else: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, # new_message is not None. 
+ run_config=run_config, + state_delta=state_delta, + ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_async(ctx)) as agen: @@ -329,7 +361,7 @@ class Runner: async for event in agen: yield event - async with Aclosing(_run_with_trace(new_message)) as agen: + async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: async for event in agen: yield event @@ -462,6 +494,11 @@ class Runner: author='user', content=new_message, ) + # If new_message is a function response, find the matching function call + # and use its branch as the new event's branch. + if function_call := invocation_context._find_matching_function_call(event): + event.branch = function_call.branch + await self.session_service.append_event(session=session, event=event) async def run_live( @@ -692,10 +729,82 @@ class Runner: invocation_context.agent = self._find_agent_to_run(session, self.agent) return invocation_context + async def _setup_context_for_resumed_invocation( + self, + *, + session: Session, + new_message: Optional[types.Content], + invocation_id: Optional[str], + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> InvocationContext: + """Sets up the context for a resumed invocation. + + Args: + session: The session to setup the invocation context for. + new_message: The new message to process and append to the session. + invocation_id: The invocation id to resume. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + + Returns: + The invocation context for the resumed invocation. + + Raises: + ValueError: If the session has no events to resume; If no user message is + available for resuming the invocation; Or if the app is not resumable. + """ + if not session.events: + raise ValueError(f'Session {session.id} has no events to resume.') + + # Step 1: Maybe retrieve a previous user message for the invocation. 
+ user_message = new_message or self._find_user_message_for_invocation( + session.events, invocation_id + ) + if not user_message: + raise ValueError( + f'No user message available for resuming invocation: {invocation_id}' + ) + # Step 2: Create invocation context. + invocation_context = self._new_invocation_context( + session, + new_message=user_message, + run_config=run_config, + invocation_id=invocation_id, + ) + # Step 3: Maybe handle new message. + if new_message: + await self._handle_new_message( + session=session, + new_message=user_message, + invocation_context=invocation_context, + run_config=run_config, + state_delta=state_delta, + ) + # Step 4: Populate agent states for the current invocation. + invocation_context.populate_invocation_agent_states() + return invocation_context + + def _find_user_message_for_invocation( + self, events: list[Event], invocation_id: str + ) -> Optional[types.Content]: + """Finds the user message that started a specific invocation.""" + for event in events: + if ( + event.invocation_id == invocation_id + and event.author == 'user' + and event.content + and event.content.parts + and event.content.parts[0].text + ): + return event.content + return None + def _new_invocation_context( self, session: Session, *, + invocation_id: Optional[str] = None, new_message: Optional[types.Content] = None, live_request_queue: Optional[LiveRequestQueue] = None, run_config: Optional[RunConfig] = None, @@ -704,6 +813,7 @@ class Runner: Args: session: The session for the context. + invocation_id: The invocation id for the context. new_message: The new message for the context. live_request_queue: The live request queue for the context. run_config: The run config for the context. @@ -712,7 +822,7 @@ class Runner: The new invocation context. 
""" run_config = run_config or RunConfig() - invocation_id = new_invocation_context_id() + invocation_id = invocation_id or new_invocation_context_id() if run_config.support_cfc and isinstance(self.agent, LlmAgent): model_name = self.agent.canonical_model.model diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 79664542..d2933483 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -15,9 +15,11 @@ from unittest.mock import Mock from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.base_agent import BaseAgentState from google.adk.agents.invocation_context import InvocationContext from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.sessions.base_session_service import BaseSessionService from google.adk.sessions.session import Session from google.genai.types import Content @@ -227,6 +229,124 @@ class TestInvocationContextWithAppResumablity: invocation_context = self._create_test_invocation_context(None) assert not invocation_context.is_resumable + def test_populate_invocation_agent_states_not_resumable(self): + """Tests that populate_invocation_agent_states does nothing if not resumable.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=False) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions(end_of_agent=True, agent_state=None), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert not invocation_context.agent_states + assert not invocation_context.end_of_agents + + def test_populate_invocation_agent_states_end_of_agent(self): + """Tests that populate_invocation_agent_states handles end_of_agent.""" + invocation_context = 
self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions(end_of_agent=True, agent_state=None), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert not invocation_context.agent_states + assert invocation_context.end_of_agents == {'agent1': True} + + def test_populate_invocation_agent_states_with_agent_state(self): + """Tests that populate_invocation_agent_states handles agent_state.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions( + end_of_agent=False, + agent_state=BaseAgentState().model_dump(mode='json'), + ), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert invocation_context.agent_states == {'agent1': {}} + assert invocation_context.end_of_agents == {'agent1': False} + + def test_populate_invocation_agent_states_with_agent_state_and_end_of_agent( + self, + ): + """Tests that populate_invocation_agent_states handles agent_state and end_of_agent.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions( + end_of_agent=True, + agent_state=BaseAgentState().model_dump(mode='json'), + ), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + # When both agent_state and end_of_agent are set, agent_state should be + # cleared, as end_of_agent is of a higher priority. 
+ assert not invocation_context.agent_states + assert invocation_context.end_of_agents == {'agent1': True} + + def test_populate_invocation_agent_states_with_content_no_state(self): + """Tests that populate_invocation_agent_states creates default state.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions(end_of_agent=False, agent_state=None), + content=Content(role='model', parts=[Part(text='hi')]), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert invocation_context.agent_states == {'agent1': BaseAgentState()} + assert invocation_context.end_of_agents == {'agent1': False} + + def test_populate_invocation_agent_states_user_message_event(self): + """Tests that populate_invocation_agent_states ignores user message events for default state.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='user', + actions=EventActions(end_of_agent=False, agent_state=None), + content=Content(role='user', parts=[Part(text='hi')]), + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert not invocation_context.agent_states + assert not invocation_context.end_of_agents + + def test_populate_invocation_agent_states_no_content(self): + """Tests that populate_invocation_agent_states ignores events with no content if no state.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + event = Event( + invocation_id='inv_1', + author='agent1', + actions=EventActions(end_of_agent=None, agent_state=None), + content=None, + ) + invocation_context.session.events = [event] + invocation_context.populate_invocation_agent_states() + assert not invocation_context.agent_states + assert 
not invocation_context.end_of_agents + class TestFindMatchingFunctionCall: """Test suite for find_matching_function_call.""" diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index 97e93f15..0896a4ca 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -19,6 +19,8 @@ from unittest import mock from google.adk.agents.base_agent import BaseAgent from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.agents.sequential_agent import SequentialAgentState from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME @@ -186,6 +188,7 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest): ask_for_confirmation_function_call_id = ( events[1].content.parts[0].function_call.id ) + invocation_id = events[1].invocation_id user_confirmation = testing_utils.UserContent( Part( function_response=FunctionResponse( @@ -209,6 +212,8 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest): ), (agent.name, "test llm response after final tool call"), ] + for event in events: + assert event.invocation_id != invocation_id assert ( testing_utils.simplify_events(copy.deepcopy(events)) == expected_parts_final @@ -321,6 +326,7 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest): ask_for_confirmation_function_call_id = ( events[1].content.parts[0].function_call.id ) + invocation_id = events[1].invocation_id custom_payload = { "test_custom_payload": { "int_field": 123, @@ -358,6 +364,8 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest): ), (agent.name, "test llm response after final tool call"), ] + for event in events: + assert event.invocation_id != invocation_id assert ( 
testing_utils.simplify_events(copy.deepcopy(events)) == expected_parts_final @@ -380,9 +388,6 @@ class TestHITLConfirmationFlowWithResumableApp: return [ _create_llm_response_from_tools(tools), _create_llm_response_from_text("test llm response after tool call"), - _create_llm_response_from_text( - "test llm response after final tool call" - ), ] @pytest.fixture @@ -412,7 +417,7 @@ class TestHITLConfirmationFlowWithResumableApp: return testing_utils.InMemoryRunner(app=app) @pytest.mark.asyncio - def test_pause_on_request_confirmation( + async def test_pause_and_resume_on_request_confirmation( self, runner: testing_utils.InMemoryRunner, agent: LlmAgent, @@ -449,3 +454,189 @@ class TestHITLConfirmationFlowWithResumableApp: ), ), ] + ask_for_confirmation_function_call_id = ( + events[1].content.parts[0].function_call.id + ) + invocation_id = events[1].invocation_id + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events = await runner.run_async( + user_confirmation, invocation_id=invocation_id + ) + expected_parts_final = [ + ( + agent.name, + Part( + function_response=FunctionResponse( + name=agent.tools[0].name, + response={"result": "confirmed=True"}, + ) + ), + ), + (agent.name, "test llm response after tool call"), + (agent.name, testing_utils.END_OF_AGENT), + ] + for event in events: + assert event.invocation_id == invocation_id + assert ( + testing_utils.simplify_resumable_app_events(copy.deepcopy(events)) + == expected_parts_final + ) + + +class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp: + """Tests the HITL confirmation flow with a resumable sequential agent app.""" + + @pytest.fixture + def tools(self) -> list[FunctionTool]: + """Provides the tools for the agent.""" + return [FunctionTool(func=_test_request_confirmation_function)] + + @pytest.fixture + def 
llm_responses( + self, tools: list[FunctionTool] + ) -> list[GenerateContentResponse]: + """Provides mock LLM responses for the tests.""" + return [ + _create_llm_response_from_tools(tools), + _create_llm_response_from_text("test llm response after tool call"), + _create_llm_response_from_text("test llm response from second agent"), + ] + + @pytest.fixture + def mock_model( + self, llm_responses: list[GenerateContentResponse] + ) -> testing_utils.MockModel: + """Provides a mock model with predefined responses.""" + return testing_utils.MockModel(responses=llm_responses) + + @pytest.fixture + def agent( + self, mock_model: testing_utils.MockModel, tools: list[FunctionTool] + ) -> SequentialAgent: + """Provides a single LlmAgent for the test.""" + return SequentialAgent( + name="root_agent", + sub_agents=[ + LlmAgent(name="agent1", model=mock_model, tools=tools), + LlmAgent(name="agent2", model=mock_model, tools=[]), + ], + ) + + @pytest.fixture + def runner(self, agent: SequentialAgent) -> testing_utils.InMemoryRunner: + """Provides an in-memory runner for the agent.""" + # Mark the app as resumable. So that the invocation will be paused after the + # long running tool call. + app = App( + name="test_app", + resumability_config=ResumabilityConfig(is_resumable=True), + root_agent=agent, + ) + return testing_utils.InMemoryRunner(app=app) + + @pytest.mark.asyncio + async def test_pause_and_resume_on_request_confirmation( + self, + runner: testing_utils.InMemoryRunner, + agent: SequentialAgent, + ): + """Tests HITL flow where all tool calls are confirmed.""" + events = runner.run("test user query") + sub_agent1 = agent.sub_agents[0] + sub_agent2 = agent.sub_agents[1] + + # Verify that the invocation is paused after the long running tool call. + # So that no intermediate function response and llm response is generated. + # And the second sub agent is not started. 
+ assert testing_utils.simplify_resumable_app_events( + copy.deepcopy(events) + ) == [ + ( + agent.name, + SequentialAgentState(current_sub_agent=sub_agent1.name).model_dump( + mode="json" + ), + ), + ( + sub_agent1.name, + Part( + function_call=FunctionCall( + name=sub_agent1.tools[0].name, args={} + ) + ), + ), + ( + sub_agent1.name, + Part( + function_call=FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": { + "name": sub_agent1.tools[0].name, + "id": mock.ANY, + "args": {}, + }, + "toolConfirmation": { + "hint": "test hint for request_confirmation", + "confirmed": False, + }, + }, + ) + ), + ), + ] + ask_for_confirmation_function_call_id = ( + events[2].content.parts[0].function_call.id + ) + invocation_id = events[2].invocation_id + + # Resume the invocation and confirm the tool call from sub_agent1, and + # sub_agent2 will continue. + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events = await runner.run_async( + user_confirmation, invocation_id=invocation_id + ) + expected_parts_final = [ + ( + sub_agent1.name, + Part( + function_response=FunctionResponse( + name=sub_agent1.tools[0].name, + response={"result": "confirmed=True"}, + ) + ), + ), + (sub_agent1.name, "test llm response after tool call"), + (sub_agent1.name, testing_utils.END_OF_AGENT), + ( + agent.name, + SequentialAgentState(current_sub_agent=sub_agent2.name).model_dump( + mode="json" + ), + ), + (sub_agent2.name, "test llm response from second agent"), + (sub_agent2.name, testing_utils.END_OF_AGENT), + (agent.name, testing_utils.END_OF_AGENT), + ] + for event in events: + assert event.invocation_id == invocation_id + assert ( + testing_utils.simplify_resumable_app_events(copy.deepcopy(events)) + == expected_parts_final + ) diff --git 
a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 07b097ae..4b0a6ed1 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -273,11 +273,14 @@ class InMemoryRunner: ) ) - async def run_async(self, new_message: types.ContentUnion) -> list[Event]: + async def run_async( + self, new_message: types.ContentUnion, invocation_id: Optional[str] = None + ) -> list[Event]: events = [] async for event in self.runner.run_async( user_id=self.session.user_id, session_id=self.session.id, + invocation_id=invocation_id, new_message=get_user_content(new_message), ): events.append(event)