feat: Modify runner to support resuming an invocation (optionally with a function response)

PiperOrigin-RevId: 813008406
This commit is contained in:
Xinran (Sherry) Tang
2025-09-29 17:34:44 -07:00
committed by Copybara-Service
parent f005414895
commit fbf75761bb
5 changed files with 473 additions and 17 deletions
@@ -34,6 +34,7 @@ from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from .active_streaming_tool import ActiveStreamingTool
from .base_agent import BaseAgent
from .base_agent import BaseAgentState
from .context_cache_config import ContextCacheConfig
from .live_request_queue import LiveRequestQueue
from .run_config import RunConfig
@@ -221,6 +222,37 @@ class InvocationContext(BaseModel):
self.agent_states.pop(agent_name, None)
self.end_of_agents.pop(agent_name, None)
def populate_invocation_agent_states(self) -> None:
  """Rebuilds agent_states / end_of_agents from the current invocation.

  No-op when the invocation is not resumable. Otherwise the events of the
  current invocation are replayed in order and applied per event author:

  * ``end_of_agent`` set: the author is flagged as ended and any stored
    agent_state is dropped. This takes priority over an agent_state carried
    by the same event.
  * ``agent_state`` present: that state is stored and the ended flag is
    reset to False.
  * otherwise, for a non-user event with content whose author has no truthy
    stored state: an empty ``BaseAgentState`` is stored and the ended flag
    is reset to False.
  """
  if not self.is_resumable:
    return

  for event in self._get_events(current_invocation=True):
    author = event.author
    actions = event.actions

    if actions.end_of_agent:
      self.end_of_agents[author] = True
      # An ended agent keeps no resumable state.
      self.agent_states.pop(author, None)
      continue

    if actions.agent_state is not None:
      self.agent_states[author] = actions.agent_state
      # A state update means the agent is not ended; reset the flag.
      self.end_of_agents[author] = False
      continue

    # Plain content from an agent that has no recorded state yet: give it an
    # empty state to mark that it has already generated some contents.
    produced_content = author != "user" and event.content
    if produced_content and not self.agent_states.get(author):
      self.agent_states[author] = BaseAgentState()
      self.end_of_agents[author] = False
def increment_llm_call_count(
self,
):
+122 -12
View File
@@ -29,6 +29,7 @@ from google.genai import types
from .agents.active_streaming_tool import ActiveStreamingTool
from .agents.base_agent import BaseAgent
from .agents.base_agent import BaseAgentState
from .agents.context_cache_config import ContextCacheConfig
from .agents.invocation_context import InvocationContext
from .agents.invocation_context import new_invocation_context_id
@@ -272,7 +273,8 @@ class Runner:
*,
user_id: str,
session_id: str,
new_message: types.Content,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
) -> AsyncGenerator[Event, None]:
@@ -281,6 +283,8 @@ class Runner:
Args:
user_id: The user ID of the session.
session_id: The session ID of the session.
invocation_id: The invocation ID of the session, set this to resume an
interrupted invocation.
new_message: A new message to append to the session.
state_delta: Optional state changes to apply to the session.
run_config: The run config for the agent.
@@ -289,15 +293,17 @@ class Runner:
The events generated by the agent.
Raises:
ValueError: If the session is not found.
ValueError: If the session is not found; If both invocation_id and
new_message are None.
"""
run_config = run_config or RunConfig()
if not new_message.role:
if new_message and not new_message.role:
new_message.role = 'user'
async def _run_with_trace(
new_message: types.Content,
new_message: Optional[types.Content] = None,
invocation_id: Optional[str] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
@@ -305,13 +311,39 @@ class Runner:
)
if not session:
raise ValueError(f'Session not found: {session_id}')
if not invocation_id and not new_message:
raise ValueError('Both invocation_id and new_message are None.')
invocation_context = await self._setup_context_for_new_invocation(
session=session,
new_message=new_message,
run_config=run_config,
state_delta=state_delta,
)
if invocation_id:
if (
not self.resumability_config
or not self.resumability_config.is_resumable
):
raise ValueError(
f'invocation_id: {invocation_id} is provided but the app is not'
' resumable.'
)
invocation_context = await self._setup_context_for_resumed_invocation(
session=session,
new_message=new_message,
invocation_id=invocation_id,
run_config=run_config,
state_delta=state_delta,
)
if invocation_context.end_of_agents.get(self.agent.name):
# Directly return if the root agent has already ended.
# TODO: Handle the case where the invocation-to-resume started from
# a sub_agent:
# invocation1: root_agent -> sub_agent1
# invocation2: sub_agent1 [paused][resume]
return
else:
invocation_context = await self._setup_context_for_new_invocation(
session=session,
new_message=new_message, # new_message is not None.
run_config=run_config,
state_delta=state_delta,
)
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
@@ -329,7 +361,7 @@ class Runner:
async for event in agen:
yield event
async with Aclosing(_run_with_trace(new_message)) as agen:
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
async for event in agen:
yield event
@@ -462,6 +494,11 @@ class Runner:
author='user',
content=new_message,
)
# If new_message is a function response, find the matching function call
# and use its branch as the new event's branch.
if function_call := invocation_context._find_matching_function_call(event):
event.branch = function_call.branch
await self.session_service.append_event(session=session, event=event)
async def run_live(
@@ -692,10 +729,82 @@ class Runner:
invocation_context.agent = self._find_agent_to_run(session, self.agent)
return invocation_context
async def _setup_context_for_resumed_invocation(
self,
*,
session: Session,
new_message: Optional[types.Content],
invocation_id: Optional[str],
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
) -> InvocationContext:
"""Sets up the context for a resumed invocation.
Args:
session: The session to setup the invocation context for.
new_message: The new message to process and append to the session.
invocation_id: The invocation id to resume.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
Returns:
The invocation context for the resumed invocation.
Raises:
ValueError: If the session has no events to resume; If no user message is
available for resuming the invocation; Or if the app is not resumable.
"""
if not session.events:
raise ValueError(f'Session {session.id} has no events to resume.')
# Step 1: Maybe retrive a previous user message for the invocation.
user_message = new_message or self._find_user_message_for_invocation(
session.events, invocation_id
)
if not user_message:
raise ValueError(
f'No user message available for resuming invocation: {invocation_id}'
)
# Step 2: Create invocation context.
invocation_context = self._new_invocation_context(
session,
new_message=user_message,
run_config=run_config,
invocation_id=invocation_id,
)
# Step 3: Maybe handle new message.
if new_message:
await self._handle_new_message(
session=session,
new_message=user_message,
invocation_context=invocation_context,
run_config=run_config,
state_delta=state_delta,
)
# Step 4: Populate agent states for the current invocation.
invocation_context.populate_invocation_agent_states()
return invocation_context
def _find_user_message_for_invocation(
self, events: list[Event], invocation_id: str
) -> Optional[types.Content]:
"""Finds the user message that started a specific invocation."""
for event in events:
if (
event.invocation_id == invocation_id
and event.author == 'user'
and event.content
and event.content.parts
and event.content.parts[0].text
):
return event.content
return None
def _new_invocation_context(
self,
session: Session,
*,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: Optional[RunConfig] = None,
@@ -704,6 +813,7 @@ class Runner:
Args:
session: The session for the context.
invocation_id: The invocation id for the context.
new_message: The new message for the context.
live_request_queue: The live request queue for the context.
run_config: The run config for the context.
@@ -712,7 +822,7 @@ class Runner:
The new invocation context.
"""
run_config = run_config or RunConfig()
invocation_id = new_invocation_context_id()
invocation_id = invocation_id or new_invocation_context_id()
if run_config.support_cfc and isinstance(self.agent, LlmAgent):
model_name = self.agent.canonical_model.model
@@ -15,9 +15,11 @@
from unittest.mock import Mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.base_agent import BaseAgentState
from google.adk.agents.invocation_context import InvocationContext
from google.adk.apps import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.base_session_service import BaseSessionService
from google.adk.sessions.session import Session
from google.genai.types import Content
@@ -227,6 +229,124 @@ class TestInvocationContextWithAppResumablity:
invocation_context = self._create_test_invocation_context(None)
assert not invocation_context.is_resumable
def test_populate_invocation_agent_states_not_resumable(self):
  """With is_resumable=False, populating leaves both state maps empty."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=False)
  )
  ctx.session.events = [
      Event(
          invocation_id='inv_1',
          author='agent1',
          actions=EventActions(end_of_agent=True, agent_state=None),
      )
  ]

  ctx.populate_invocation_agent_states()

  # The end_of_agent event is ignored when the context is not resumable.
  assert not ctx.agent_states
  assert not ctx.end_of_agents
def test_populate_invocation_agent_states_end_of_agent(self):
  """An end_of_agent event marks its author as ended and stores no state."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  ctx.session.events = [
      Event(
          invocation_id='inv_1',
          author='agent1',
          actions=EventActions(end_of_agent=True, agent_state=None),
      )
  ]

  ctx.populate_invocation_agent_states()

  assert not ctx.agent_states
  assert ctx.end_of_agents == {'agent1': True}
def test_populate_invocation_agent_states_with_agent_state(self):
  """An event carrying agent_state stores it and resets end_of_agent."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  state_actions = EventActions(
      end_of_agent=False,
      agent_state=BaseAgentState().model_dump(mode='json'),
  )
  ctx.session.events = [
      Event(invocation_id='inv_1', author='agent1', actions=state_actions)
  ]

  ctx.populate_invocation_agent_states()

  # The dumped BaseAgentState is expected to round-trip as an empty dict.
  assert ctx.agent_states == {'agent1': {}}
  assert ctx.end_of_agents == {'agent1': False}
def test_populate_invocation_agent_states_with_agent_state_and_end_of_agent(
    self,
):
  """end_of_agent wins when an event carries both it and an agent_state."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  both_actions = EventActions(
      end_of_agent=True,
      agent_state=BaseAgentState().model_dump(mode='json'),
  )
  ctx.session.events = [
      Event(invocation_id='inv_1', author='agent1', actions=both_actions)
  ]

  ctx.populate_invocation_agent_states()

  # end_of_agent has the higher priority, so the agent_state carried by the
  # same event must not be kept.
  assert not ctx.agent_states
  assert ctx.end_of_agents == {'agent1': True}
def test_populate_invocation_agent_states_with_content_no_state(self):
  """Agent content without agent_state yields a default BaseAgentState."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  ctx.session.events = [
      Event(
          invocation_id='inv_1',
          author='agent1',
          actions=EventActions(end_of_agent=False, agent_state=None),
          content=Content(role='model', parts=[Part(text='hi')]),
      )
  ]

  ctx.populate_invocation_agent_states()

  assert ctx.agent_states == {'agent1': BaseAgentState()}
  assert ctx.end_of_agents == {'agent1': False}
def test_populate_invocation_agent_states_user_message_event(self):
  """A user-authored content event does not create a default agent state."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  ctx.session.events = [
      Event(
          invocation_id='inv_1',
          author='user',
          actions=EventActions(end_of_agent=False, agent_state=None),
          content=Content(role='user', parts=[Part(text='hi')]),
      )
  ]

  ctx.populate_invocation_agent_states()

  assert not ctx.agent_states
  assert not ctx.end_of_agents
def test_populate_invocation_agent_states_no_content(self):
  """An agent event with neither state nor content changes nothing."""
  ctx = self._create_test_invocation_context(
      ResumabilityConfig(is_resumable=True)
  )
  ctx.session.events = [
      Event(
          invocation_id='inv_1',
          author='agent1',
          actions=EventActions(end_of_agent=None, agent_state=None),
          content=None,
      )
  ]

  ctx.populate_invocation_agent_states()

  assert not ctx.agent_states
  assert not ctx.end_of_agents
class TestFindMatchingFunctionCall:
"""Test suite for find_matching_function_call."""
@@ -19,6 +19,8 @@ from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.agents.sequential_agent import SequentialAgentState
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
@@ -186,6 +188,7 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
ask_for_confirmation_function_call_id = (
events[1].content.parts[0].function_call.id
)
invocation_id = events[1].invocation_id
user_confirmation = testing_utils.UserContent(
Part(
function_response=FunctionResponse(
@@ -209,6 +212,8 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
),
(agent.name, "test llm response after final tool call"),
]
for event in events:
assert event.invocation_id != invocation_id
assert (
testing_utils.simplify_events(copy.deepcopy(events))
== expected_parts_final
@@ -321,6 +326,7 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
ask_for_confirmation_function_call_id = (
events[1].content.parts[0].function_call.id
)
invocation_id = events[1].invocation_id
custom_payload = {
"test_custom_payload": {
"int_field": 123,
@@ -358,6 +364,8 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
),
(agent.name, "test llm response after final tool call"),
]
for event in events:
assert event.invocation_id != invocation_id
assert (
testing_utils.simplify_events(copy.deepcopy(events))
== expected_parts_final
@@ -380,9 +388,6 @@ class TestHITLConfirmationFlowWithResumableApp:
return [
_create_llm_response_from_tools(tools),
_create_llm_response_from_text("test llm response after tool call"),
_create_llm_response_from_text(
"test llm response after final tool call"
),
]
@pytest.fixture
@@ -412,7 +417,7 @@ class TestHITLConfirmationFlowWithResumableApp:
return testing_utils.InMemoryRunner(app=app)
@pytest.mark.asyncio
def test_pause_on_request_confirmation(
async def test_pause_and_resume_on_request_confirmation(
self,
runner: testing_utils.InMemoryRunner,
agent: LlmAgent,
@@ -449,3 +454,189 @@ class TestHITLConfirmationFlowWithResumableApp:
),
),
]
ask_for_confirmation_function_call_id = (
events[1].content.parts[0].function_call.id
)
invocation_id = events[1].invocation_id
user_confirmation = testing_utils.UserContent(
Part(
function_response=FunctionResponse(
id=ask_for_confirmation_function_call_id,
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
response={"confirmed": True},
)
)
)
events = await runner.run_async(
user_confirmation, invocation_id=invocation_id
)
expected_parts_final = [
(
agent.name,
Part(
function_response=FunctionResponse(
name=agent.tools[0].name,
response={"result": "confirmed=True"},
)
),
),
(agent.name, "test llm response after tool call"),
(agent.name, testing_utils.END_OF_AGENT),
]
for event in events:
assert event.invocation_id == invocation_id
assert (
testing_utils.simplify_resumable_app_events(copy.deepcopy(events))
== expected_parts_final
)
class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp:
"""Tests the HITL confirmation flow with a resumable sequential agent app.
The root agent is a SequentialAgent with two LlmAgent sub-agents; only the
first sub-agent has a tool, and that tool requests user confirmation.
"""
@pytest.fixture
def tools(self) -> list[FunctionTool]:
"""Provides the single confirmation-requesting FunctionTool for agent1."""
return [FunctionTool(func=_test_request_confirmation_function)]
@pytest.fixture
def llm_responses(
self, tools: list[FunctionTool]
) -> list[GenerateContentResponse]:
"""Provides mock LLM responses for the tests.
Listed in the order the test expects them to surface: the tool call, the
reply of agent1 after the tool call, then the reply of agent2.
"""
return [
_create_llm_response_from_tools(tools),
_create_llm_response_from_text("test llm response after tool call"),
_create_llm_response_from_text("test llm response from second agent"),
]
@pytest.fixture
def mock_model(
self, llm_responses: list[GenerateContentResponse]
) -> testing_utils.MockModel:
"""Provides a mock model with predefined responses."""
return testing_utils.MockModel(responses=llm_responses)
@pytest.fixture
def agent(
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
) -> SequentialAgent:
"""Provides a SequentialAgent with two LlmAgent sub-agents.
Both sub-agents share the same mock model; only agent1 is given tools.
"""
return SequentialAgent(
name="root_agent",
sub_agents=[
LlmAgent(name="agent1", model=mock_model, tools=tools),
LlmAgent(name="agent2", model=mock_model, tools=[]),
],
)
@pytest.fixture
def runner(self, agent: SequentialAgent) -> testing_utils.InMemoryRunner:
"""Provides an in-memory runner for the agent."""
# Mark the app as resumable so that the invocation is paused after the
# long-running tool call.
app = App(
name="test_app",
resumability_config=ResumabilityConfig(is_resumable=True),
root_agent=agent,
)
return testing_utils.InMemoryRunner(app=app)
@pytest.mark.asyncio
async def test_pause_and_resume_on_request_confirmation(
self,
runner: testing_utils.InMemoryRunner,
agent: SequentialAgent,
):
"""Tests pausing on a confirmation request and resuming after approval.
The first run must stop at the confirmation request from agent1. The
second run reuses the same invocation_id with a confirming function
response and must let agent1 finish, then agent2, then the root agent.
"""
events = runner.run("test user query")
sub_agent1 = agent.sub_agents[0]
sub_agent2 = agent.sub_agents[1]
# Verify that the invocation is paused after the long running tool call.
# So that no intermediate function response and llm response is generated.
# And the second sub agent is not started.
assert testing_utils.simplify_resumable_app_events(
copy.deepcopy(events)
) == [
(
agent.name,
SequentialAgentState(current_sub_agent=sub_agent1.name).model_dump(
mode="json"
),
),
(
sub_agent1.name,
Part(
function_call=FunctionCall(
name=sub_agent1.tools[0].name, args={}
)
),
),
(
sub_agent1.name,
Part(
function_call=FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
"originalFunctionCall": {
"name": sub_agent1.tools[0].name,
"id": mock.ANY,
"args": {},
},
"toolConfirmation": {
"hint": "test hint for request_confirmation",
"confirmed": False,
},
},
)
),
),
]
# events[2] is the confirmation request (third entry asserted above); its
# function call id and invocation id are needed to resume.
ask_for_confirmation_function_call_id = (
events[2].content.parts[0].function_call.id
)
invocation_id = events[2].invocation_id
# Resume the invocation and confirm the tool call from sub_agent1, and
# sub_agent2 will continue.
user_confirmation = testing_utils.UserContent(
Part(
function_response=FunctionResponse(
id=ask_for_confirmation_function_call_id,
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
response={"confirmed": True},
)
)
)
events = await runner.run_async(
user_confirmation, invocation_id=invocation_id
)
expected_parts_final = [
(
sub_agent1.name,
Part(
function_response=FunctionResponse(
name=sub_agent1.tools[0].name,
response={"result": "confirmed=True"},
)
),
),
(sub_agent1.name, "test llm response after tool call"),
(sub_agent1.name, testing_utils.END_OF_AGENT),
(
agent.name,
SequentialAgentState(current_sub_agent=sub_agent2.name).model_dump(
mode="json"
),
),
(sub_agent2.name, "test llm response from second agent"),
(sub_agent2.name, testing_utils.END_OF_AGENT),
(agent.name, testing_utils.END_OF_AGENT),
]
# Every event from the resumed run must carry the original invocation id.
for event in events:
assert event.invocation_id == invocation_id
assert (
testing_utils.simplify_resumable_app_events(copy.deepcopy(events))
== expected_parts_final
)
+4 -1
View File
@@ -273,11 +273,14 @@ class InMemoryRunner:
)
)
async def run_async(self, new_message: types.ContentUnion) -> list[Event]:
async def run_async(
self, new_message: types.ContentUnion, invocation_id: Optional[str] = None
) -> list[Event]:
events = []
async for event in self.runner.run_async(
user_id=self.session.user_id,
session_id=self.session.id,
invocation_id=invocation_id,
new_message=get_user_content(new_message),
):
events.append(event)