You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Modify runner to support resuming an invocation (optionally with a function response)
PiperOrigin-RevId: 813008406
This commit is contained in:
committed by
Copybara-Service
parent
f005414895
commit
fbf75761bb
@@ -34,6 +34,7 @@ from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.session import Session
|
||||
from .active_streaming_tool import ActiveStreamingTool
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent import BaseAgentState
|
||||
from .context_cache_config import ContextCacheConfig
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
from .run_config import RunConfig
|
||||
@@ -221,6 +222,37 @@ class InvocationContext(BaseModel):
|
||||
self.agent_states.pop(agent_name, None)
|
||||
self.end_of_agents.pop(agent_name, None)
|
||||
|
||||
def populate_invocation_agent_states(self) -> None:
|
||||
"""Populates agent states for the current invocation if it is resumable.
|
||||
|
||||
For history events that contain agent state information, set the
|
||||
agent_state and end_of_agent of the agent that generated the event.
|
||||
|
||||
For non-workflow agents, also set an initial agent_state if it has
|
||||
already generated some contents.
|
||||
"""
|
||||
if not self.is_resumable:
|
||||
return
|
||||
for event in self._get_events(current_invocation=True):
|
||||
if event.actions.end_of_agent:
|
||||
self.end_of_agents[event.author] = True
|
||||
# Delete agent_state when it is end
|
||||
self.agent_states.pop(event.author, None)
|
||||
elif event.actions.agent_state is not None:
|
||||
self.agent_states[event.author] = event.actions.agent_state
|
||||
# Invalidate the end_of_agent flag
|
||||
self.end_of_agents[event.author] = False
|
||||
elif (
|
||||
event.author != "user"
|
||||
and event.content
|
||||
and not self.agent_states.get(event.author)
|
||||
):
|
||||
# If the agent has generated some contents but its agent_state is not
|
||||
# set, set its agent_state to an empty agent_state.
|
||||
self.agent_states[event.author] = BaseAgentState()
|
||||
# Invalidate the end_of_agent flag
|
||||
self.end_of_agents[event.author] = False
|
||||
|
||||
def increment_llm_call_count(
|
||||
self,
|
||||
):
|
||||
|
||||
+122
-12
@@ -29,6 +29,7 @@ from google.genai import types
|
||||
|
||||
from .agents.active_streaming_tool import ActiveStreamingTool
|
||||
from .agents.base_agent import BaseAgent
|
||||
from .agents.base_agent import BaseAgentState
|
||||
from .agents.context_cache_config import ContextCacheConfig
|
||||
from .agents.invocation_context import InvocationContext
|
||||
from .agents.invocation_context import new_invocation_context_id
|
||||
@@ -272,7 +273,8 @@ class Runner:
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
new_message: types.Content,
|
||||
invocation_id: Optional[str] = None,
|
||||
new_message: Optional[types.Content] = None,
|
||||
state_delta: Optional[dict[str, Any]] = None,
|
||||
run_config: Optional[RunConfig] = None,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
@@ -281,6 +283,8 @@ class Runner:
|
||||
Args:
|
||||
user_id: The user ID of the session.
|
||||
session_id: The session ID of the session.
|
||||
invocation_id: The invocation ID of the session, set this to resume an
|
||||
interrupted invocation.
|
||||
new_message: A new message to append to the session.
|
||||
state_delta: Optional state changes to apply to the session.
|
||||
run_config: The run config for the agent.
|
||||
@@ -289,15 +293,17 @@ class Runner:
|
||||
The events generated by the agent.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session is not found.
|
||||
ValueError: If the session is not found; If both invocation_id and
|
||||
new_message are None.
|
||||
"""
|
||||
run_config = run_config or RunConfig()
|
||||
|
||||
if not new_message.role:
|
||||
if new_message and not new_message.role:
|
||||
new_message.role = 'user'
|
||||
|
||||
async def _run_with_trace(
|
||||
new_message: types.Content,
|
||||
new_message: Optional[types.Content] = None,
|
||||
invocation_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
with tracer.start_as_current_span('invocation'):
|
||||
session = await self.session_service.get_session(
|
||||
@@ -305,13 +311,39 @@ class Runner:
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
if not invocation_id and not new_message:
|
||||
raise ValueError('Both invocation_id and new_message are None.')
|
||||
|
||||
invocation_context = await self._setup_context_for_new_invocation(
|
||||
session=session,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
state_delta=state_delta,
|
||||
)
|
||||
if invocation_id:
|
||||
if (
|
||||
not self.resumability_config
|
||||
or not self.resumability_config.is_resumable
|
||||
):
|
||||
raise ValueError(
|
||||
f'invocation_id: {invocation_id} is provided but the app is not'
|
||||
' resumable.'
|
||||
)
|
||||
invocation_context = await self._setup_context_for_resumed_invocation(
|
||||
session=session,
|
||||
new_message=new_message,
|
||||
invocation_id=invocation_id,
|
||||
run_config=run_config,
|
||||
state_delta=state_delta,
|
||||
)
|
||||
if invocation_context.end_of_agents.get(self.agent.name):
|
||||
# Directly return if the root agent has already ended.
|
||||
# TODO: Handle the case where the invocation-to-resume started from
|
||||
# a sub_agent:
|
||||
# invocation1: root_agent -> sub_agent1
|
||||
# invocation2: sub_agent1 [paused][resume]
|
||||
return
|
||||
else:
|
||||
invocation_context = await self._setup_context_for_new_invocation(
|
||||
session=session,
|
||||
new_message=new_message, # new_message is not None.
|
||||
run_config=run_config,
|
||||
state_delta=state_delta,
|
||||
)
|
||||
|
||||
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
||||
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
|
||||
@@ -329,7 +361,7 @@ class Runner:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async with Aclosing(_run_with_trace(new_message)) as agen:
|
||||
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
@@ -462,6 +494,11 @@ class Runner:
|
||||
author='user',
|
||||
content=new_message,
|
||||
)
|
||||
# If new_message is a function response, find the matching function call
|
||||
# and use its branch as the new event's branch.
|
||||
if function_call := invocation_context._find_matching_function_call(event):
|
||||
event.branch = function_call.branch
|
||||
|
||||
await self.session_service.append_event(session=session, event=event)
|
||||
|
||||
async def run_live(
|
||||
@@ -692,10 +729,82 @@ class Runner:
|
||||
invocation_context.agent = self._find_agent_to_run(session, self.agent)
|
||||
return invocation_context
|
||||
|
||||
async def _setup_context_for_resumed_invocation(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
new_message: Optional[types.Content],
|
||||
invocation_id: Optional[str],
|
||||
run_config: RunConfig,
|
||||
state_delta: Optional[dict[str, Any]],
|
||||
) -> InvocationContext:
|
||||
"""Sets up the context for a resumed invocation.
|
||||
|
||||
Args:
|
||||
session: The session to setup the invocation context for.
|
||||
new_message: The new message to process and append to the session.
|
||||
invocation_id: The invocation id to resume.
|
||||
run_config: The run config of the agent.
|
||||
state_delta: Optional state changes to apply to the session.
|
||||
|
||||
Returns:
|
||||
The invocation context for the resumed invocation.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session has no events to resume; If no user message is
|
||||
available for resuming the invocation; Or if the app is not resumable.
|
||||
"""
|
||||
if not session.events:
|
||||
raise ValueError(f'Session {session.id} has no events to resume.')
|
||||
|
||||
# Step 1: Maybe retrive a previous user message for the invocation.
|
||||
user_message = new_message or self._find_user_message_for_invocation(
|
||||
session.events, invocation_id
|
||||
)
|
||||
if not user_message:
|
||||
raise ValueError(
|
||||
f'No user message available for resuming invocation: {invocation_id}'
|
||||
)
|
||||
# Step 2: Create invocation context.
|
||||
invocation_context = self._new_invocation_context(
|
||||
session,
|
||||
new_message=user_message,
|
||||
run_config=run_config,
|
||||
invocation_id=invocation_id,
|
||||
)
|
||||
# Step 3: Maybe handle new message.
|
||||
if new_message:
|
||||
await self._handle_new_message(
|
||||
session=session,
|
||||
new_message=user_message,
|
||||
invocation_context=invocation_context,
|
||||
run_config=run_config,
|
||||
state_delta=state_delta,
|
||||
)
|
||||
# Step 4: Populate agent states for the current invocation.
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
return invocation_context
|
||||
|
||||
def _find_user_message_for_invocation(
|
||||
self, events: list[Event], invocation_id: str
|
||||
) -> Optional[types.Content]:
|
||||
"""Finds the user message that started a specific invocation."""
|
||||
for event in events:
|
||||
if (
|
||||
event.invocation_id == invocation_id
|
||||
and event.author == 'user'
|
||||
and event.content
|
||||
and event.content.parts
|
||||
and event.content.parts[0].text
|
||||
):
|
||||
return event.content
|
||||
return None
|
||||
|
||||
def _new_invocation_context(
|
||||
self,
|
||||
session: Session,
|
||||
*,
|
||||
invocation_id: Optional[str] = None,
|
||||
new_message: Optional[types.Content] = None,
|
||||
live_request_queue: Optional[LiveRequestQueue] = None,
|
||||
run_config: Optional[RunConfig] = None,
|
||||
@@ -704,6 +813,7 @@ class Runner:
|
||||
|
||||
Args:
|
||||
session: The session for the context.
|
||||
invocation_id: The invocation id for the context.
|
||||
new_message: The new message for the context.
|
||||
live_request_queue: The live request queue for the context.
|
||||
run_config: The run config for the context.
|
||||
@@ -712,7 +822,7 @@ class Runner:
|
||||
The new invocation context.
|
||||
"""
|
||||
run_config = run_config or RunConfig()
|
||||
invocation_id = new_invocation_context_id()
|
||||
invocation_id = invocation_id or new_invocation_context_id()
|
||||
|
||||
if run_config.support_cfc and isinstance(self.agent, LlmAgent):
|
||||
model_name = self.agent.canonical_model.model
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.base_agent import BaseAgentState
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.apps import ResumabilityConfig
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.sessions.base_session_service import BaseSessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai.types import Content
|
||||
@@ -227,6 +229,124 @@ class TestInvocationContextWithAppResumablity:
|
||||
invocation_context = self._create_test_invocation_context(None)
|
||||
assert not invocation_context.is_resumable
|
||||
|
||||
def test_populate_invocation_agent_states_not_resumable(self):
|
||||
"""Tests that populate_invocation_agent_states does nothing if not resumable."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=False)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(end_of_agent=True, agent_state=None),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert not invocation_context.agent_states
|
||||
assert not invocation_context.end_of_agents
|
||||
|
||||
def test_populate_invocation_agent_states_end_of_agent(self):
|
||||
"""Tests that populate_invocation_agent_states handles end_of_agent."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(end_of_agent=True, agent_state=None),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert not invocation_context.agent_states
|
||||
assert invocation_context.end_of_agents == {'agent1': True}
|
||||
|
||||
def test_populate_invocation_agent_states_with_agent_state(self):
|
||||
"""Tests that populate_invocation_agent_states handles agent_state."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(
|
||||
end_of_agent=False,
|
||||
agent_state=BaseAgentState().model_dump(mode='json'),
|
||||
),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert invocation_context.agent_states == {'agent1': {}}
|
||||
assert invocation_context.end_of_agents == {'agent1': False}
|
||||
|
||||
def test_populate_invocation_agent_states_with_agent_state_and_end_of_agent(
|
||||
self,
|
||||
):
|
||||
"""Tests that populate_invocation_agent_states handles agent_state and end_of_agent."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(
|
||||
end_of_agent=True,
|
||||
agent_state=BaseAgentState().model_dump(mode='json'),
|
||||
),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
# When both agent_state and end_of_agent are set, agent_state should be
|
||||
# cleared, as end_of_agent is of a higher priority.
|
||||
assert not invocation_context.agent_states
|
||||
assert invocation_context.end_of_agents == {'agent1': True}
|
||||
|
||||
def test_populate_invocation_agent_states_with_content_no_state(self):
|
||||
"""Tests that populate_invocation_agent_states creates default state."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(end_of_agent=False, agent_state=None),
|
||||
content=Content(role='model', parts=[Part(text='hi')]),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert invocation_context.agent_states == {'agent1': BaseAgentState()}
|
||||
assert invocation_context.end_of_agents == {'agent1': False}
|
||||
|
||||
def test_populate_invocation_agent_states_user_message_event(self):
|
||||
"""Tests that populate_invocation_agent_states ignores user message events for default state."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='user',
|
||||
actions=EventActions(end_of_agent=False, agent_state=None),
|
||||
content=Content(role='user', parts=[Part(text='hi')]),
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert not invocation_context.agent_states
|
||||
assert not invocation_context.end_of_agents
|
||||
|
||||
def test_populate_invocation_agent_states_no_content(self):
|
||||
"""Tests that populate_invocation_agent_states ignores events with no content if no state."""
|
||||
invocation_context = self._create_test_invocation_context(
|
||||
ResumabilityConfig(is_resumable=True)
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv_1',
|
||||
author='agent1',
|
||||
actions=EventActions(end_of_agent=None, agent_state=None),
|
||||
content=None,
|
||||
)
|
||||
invocation_context.session.events = [event]
|
||||
invocation_context.populate_invocation_agent_states()
|
||||
assert not invocation_context.agent_states
|
||||
assert not invocation_context.end_of_agents
|
||||
|
||||
|
||||
class TestFindMatchingFunctionCall:
|
||||
"""Test suite for find_matching_function_call."""
|
||||
|
||||
@@ -19,6 +19,8 @@ from unittest import mock
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgentState
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
|
||||
@@ -186,6 +188,7 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
|
||||
ask_for_confirmation_function_call_id = (
|
||||
events[1].content.parts[0].function_call.id
|
||||
)
|
||||
invocation_id = events[1].invocation_id
|
||||
user_confirmation = testing_utils.UserContent(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
@@ -209,6 +212,8 @@ class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
|
||||
),
|
||||
(agent.name, "test llm response after final tool call"),
|
||||
]
|
||||
for event in events:
|
||||
assert event.invocation_id != invocation_id
|
||||
assert (
|
||||
testing_utils.simplify_events(copy.deepcopy(events))
|
||||
== expected_parts_final
|
||||
@@ -321,6 +326,7 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
|
||||
ask_for_confirmation_function_call_id = (
|
||||
events[1].content.parts[0].function_call.id
|
||||
)
|
||||
invocation_id = events[1].invocation_id
|
||||
custom_payload = {
|
||||
"test_custom_payload": {
|
||||
"int_field": 123,
|
||||
@@ -358,6 +364,8 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
|
||||
),
|
||||
(agent.name, "test llm response after final tool call"),
|
||||
]
|
||||
for event in events:
|
||||
assert event.invocation_id != invocation_id
|
||||
assert (
|
||||
testing_utils.simplify_events(copy.deepcopy(events))
|
||||
== expected_parts_final
|
||||
@@ -380,9 +388,6 @@ class TestHITLConfirmationFlowWithResumableApp:
|
||||
return [
|
||||
_create_llm_response_from_tools(tools),
|
||||
_create_llm_response_from_text("test llm response after tool call"),
|
||||
_create_llm_response_from_text(
|
||||
"test llm response after final tool call"
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
@@ -412,7 +417,7 @@ class TestHITLConfirmationFlowWithResumableApp:
|
||||
return testing_utils.InMemoryRunner(app=app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_pause_on_request_confirmation(
|
||||
async def test_pause_and_resume_on_request_confirmation(
|
||||
self,
|
||||
runner: testing_utils.InMemoryRunner,
|
||||
agent: LlmAgent,
|
||||
@@ -449,3 +454,189 @@ class TestHITLConfirmationFlowWithResumableApp:
|
||||
),
|
||||
),
|
||||
]
|
||||
ask_for_confirmation_function_call_id = (
|
||||
events[1].content.parts[0].function_call.id
|
||||
)
|
||||
invocation_id = events[1].invocation_id
|
||||
user_confirmation = testing_utils.UserContent(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
id=ask_for_confirmation_function_call_id,
|
||||
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
response={"confirmed": True},
|
||||
)
|
||||
)
|
||||
)
|
||||
events = await runner.run_async(
|
||||
user_confirmation, invocation_id=invocation_id
|
||||
)
|
||||
expected_parts_final = [
|
||||
(
|
||||
agent.name,
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name=agent.tools[0].name,
|
||||
response={"result": "confirmed=True"},
|
||||
)
|
||||
),
|
||||
),
|
||||
(agent.name, "test llm response after tool call"),
|
||||
(agent.name, testing_utils.END_OF_AGENT),
|
||||
]
|
||||
for event in events:
|
||||
assert event.invocation_id == invocation_id
|
||||
assert (
|
||||
testing_utils.simplify_resumable_app_events(copy.deepcopy(events))
|
||||
== expected_parts_final
|
||||
)
|
||||
|
||||
|
||||
class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp:
|
||||
"""Tests the HITL confirmation flow with a resumable sequential agent app."""
|
||||
|
||||
@pytest.fixture
|
||||
def tools(self) -> list[FunctionTool]:
|
||||
"""Provides the tools for the agent."""
|
||||
return [FunctionTool(func=_test_request_confirmation_function)]
|
||||
|
||||
@pytest.fixture
|
||||
def llm_responses(
|
||||
self, tools: list[FunctionTool]
|
||||
) -> list[GenerateContentResponse]:
|
||||
"""Provides mock LLM responses for the tests."""
|
||||
return [
|
||||
_create_llm_response_from_tools(tools),
|
||||
_create_llm_response_from_text("test llm response after tool call"),
|
||||
_create_llm_response_from_text("test llm response from second agent"),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(
|
||||
self, llm_responses: list[GenerateContentResponse]
|
||||
) -> testing_utils.MockModel:
|
||||
"""Provides a mock model with predefined responses."""
|
||||
return testing_utils.MockModel(responses=llm_responses)
|
||||
|
||||
@pytest.fixture
|
||||
def agent(
|
||||
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
|
||||
) -> SequentialAgent:
|
||||
"""Provides a single LlmAgent for the test."""
|
||||
return SequentialAgent(
|
||||
name="root_agent",
|
||||
sub_agents=[
|
||||
LlmAgent(name="agent1", model=mock_model, tools=tools),
|
||||
LlmAgent(name="agent2", model=mock_model, tools=[]),
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self, agent: SequentialAgent) -> testing_utils.InMemoryRunner:
|
||||
"""Provides an in-memory runner for the agent."""
|
||||
# Mark the app as resumable. So that the invocation will be paused after the
|
||||
# long running tool call.
|
||||
app = App(
|
||||
name="test_app",
|
||||
resumability_config=ResumabilityConfig(is_resumable=True),
|
||||
root_agent=agent,
|
||||
)
|
||||
return testing_utils.InMemoryRunner(app=app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_and_resume_on_request_confirmation(
|
||||
self,
|
||||
runner: testing_utils.InMemoryRunner,
|
||||
agent: SequentialAgent,
|
||||
):
|
||||
"""Tests HITL flow where all tool calls are confirmed."""
|
||||
events = runner.run("test user query")
|
||||
sub_agent1 = agent.sub_agents[0]
|
||||
sub_agent2 = agent.sub_agents[1]
|
||||
|
||||
# Verify that the invocation is paused after the long running tool call.
|
||||
# So that no intermediate function response and llm response is generated.
|
||||
# And the second sub agent is not started.
|
||||
assert testing_utils.simplify_resumable_app_events(
|
||||
copy.deepcopy(events)
|
||||
) == [
|
||||
(
|
||||
agent.name,
|
||||
SequentialAgentState(current_sub_agent=sub_agent1.name).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
),
|
||||
(
|
||||
sub_agent1.name,
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=sub_agent1.tools[0].name, args={}
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
sub_agent1.name,
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
args={
|
||||
"originalFunctionCall": {
|
||||
"name": sub_agent1.tools[0].name,
|
||||
"id": mock.ANY,
|
||||
"args": {},
|
||||
},
|
||||
"toolConfirmation": {
|
||||
"hint": "test hint for request_confirmation",
|
||||
"confirmed": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
),
|
||||
),
|
||||
]
|
||||
ask_for_confirmation_function_call_id = (
|
||||
events[2].content.parts[0].function_call.id
|
||||
)
|
||||
invocation_id = events[2].invocation_id
|
||||
|
||||
# Resume the invocation and confirm the tool call from sub_agent1, and
|
||||
# sub_agent2 will continue.
|
||||
user_confirmation = testing_utils.UserContent(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
id=ask_for_confirmation_function_call_id,
|
||||
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
|
||||
response={"confirmed": True},
|
||||
)
|
||||
)
|
||||
)
|
||||
events = await runner.run_async(
|
||||
user_confirmation, invocation_id=invocation_id
|
||||
)
|
||||
expected_parts_final = [
|
||||
(
|
||||
sub_agent1.name,
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name=sub_agent1.tools[0].name,
|
||||
response={"result": "confirmed=True"},
|
||||
)
|
||||
),
|
||||
),
|
||||
(sub_agent1.name, "test llm response after tool call"),
|
||||
(sub_agent1.name, testing_utils.END_OF_AGENT),
|
||||
(
|
||||
agent.name,
|
||||
SequentialAgentState(current_sub_agent=sub_agent2.name).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
),
|
||||
(sub_agent2.name, "test llm response from second agent"),
|
||||
(sub_agent2.name, testing_utils.END_OF_AGENT),
|
||||
(agent.name, testing_utils.END_OF_AGENT),
|
||||
]
|
||||
for event in events:
|
||||
assert event.invocation_id == invocation_id
|
||||
assert (
|
||||
testing_utils.simplify_resumable_app_events(copy.deepcopy(events))
|
||||
== expected_parts_final
|
||||
)
|
||||
|
||||
@@ -273,11 +273,14 @@ class InMemoryRunner:
|
||||
)
|
||||
)
|
||||
|
||||
async def run_async(self, new_message: types.ContentUnion) -> list[Event]:
|
||||
async def run_async(
|
||||
self, new_message: types.ContentUnion, invocation_id: Optional[str] = None
|
||||
) -> list[Event]:
|
||||
events = []
|
||||
async for event in self.runner.run_async(
|
||||
user_id=self.session.user_id,
|
||||
session_id=self.session.id,
|
||||
invocation_id=invocation_id,
|
||||
new_message=get_user_content(new_message),
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
Reference in New Issue
Block a user