From b2b80e7fa0d9de01dc15cde5026097efd39c3db8 Mon Sep 17 00:00:00 2001 From: "Xinran (Sherry) Tang" Date: Thu, 25 Sep 2025 15:10:38 -0700 Subject: [PATCH] feat: Pause invocations on long running function calls for resumable apps PiperOrigin-RevId: 811518771 --- src/google/adk/agents/invocation_context.py | 39 ++ src/google/adk/agents/llm_agent.py | 2 + src/google/adk/agents/loop_agent.py | 11 + src/google/adk/agents/parallel_agent.py | 10 + src/google/adk/agents/sequential_agent.py | 9 + .../agents/test_invocation_context.py | 89 ++++ .../runners/test_pause_invocation.py | 472 ++++++++++++++++++ .../runners/test_run_tool_confirmation.py | 90 +++- 8 files changed, 719 insertions(+), 3 deletions(-) create mode 100644 tests/unittests/runners/test_pause_invocation.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 8726df3a..3cd7d9dd 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -260,6 +260,45 @@ class InvocationContext(BaseModel): results = [event for event in results if event.branch == self.branch] return results + def should_pause_invocation(self, event: Event) -> bool: + """Returns whether to pause the invocation right after this event. + + "Pausing" an invocation is different from "ending" an invocation. A paused + invocation can be resumed later, while an ended invocation cannot. + + Pausing the current agent's run will also pause all the agents that + depend on its execution, i.e. the subsequent agents in a workflow, and the + current agent's ancestors, etc. + + Note that parallel sibling agents won't be affected, but their common + ancestors will be paused after all the non-blocking sub-agents finished + running. + + Should meet all following conditions to pause an invocation: + 1. The app is resumable. + 2. The current event has a long running function call. + + Args: + event: The current event. + + Returns: + Whether to pause the invocation right after this event. + """ + if ( + not self.resumability_config + or not self.resumability_config.is_resumable + ): + return False + + if not event.long_running_tool_ids or not event.get_function_calls(): + return False + + for fc in event.get_function_calls(): + if fc.id in event.long_running_tool_ids: + return True + + return False + def new_invocation_context_id() -> str: return "e-" + str(uuid.uuid4()) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index be559222..ae38d6c9 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -341,6 +341,8 @@ class LlmAgent(BaseAgent): async for event in agen: self.__maybe_save_output_to_state(event) yield event + if ctx.should_pause_invocation(event): + return @override async def _run_live_impl( diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 6311945e..75606834 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -70,15 +70,26 @@ class LoopAgent(BaseAgent): while not self.max_iterations or times_looped < self.max_iterations: for sub_agent in self.sub_agents: should_exit = False + pause_invocation = False + async with Aclosing(sub_agent.run_async(ctx)) as agen: async for event in agen: yield event if event.actions.escalate: should_exit = True + if ctx.should_pause_invocation(event): + pause_invocation = True + # Indicates that the loop agent should exist after running this + # sub-agent. if should_exit: return + # Indicates that the invocation should be paused after running this + # sub-agent. + if pause_invocation: + return + times_looped += 1 return diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index b1237a1c..dfe46810 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -181,16 +181,26 @@ class ParallelAgent(BaseAgent): ) for sub_agent in self.sub_agents ] + + pause_invocation = False try: # TODO remove if once Python <3.11 is no longer supported. if sys.version_info >= (3, 11): async with Aclosing(_merge_agent_run(agent_runs)) as agen: async for event in agen: yield event + if ctx.should_pause_invocation(event): + pause_invocation = True else: async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen: async for event in agen: yield event + if ctx.should_pause_invocation(event): + pause_invocation = True + + if pause_invocation: + return + finally: for sub_agent_run in agent_runs: await sub_agent_run.aclose() diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 265ec9c1..4085e72a 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -52,9 +52,18 @@ class SequentialAgent(BaseAgent): self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: for sub_agent in self.sub_agents: + pause_invocation = False + async with Aclosing(sub_agent.run_async(ctx)) as agen: async for event in agen: yield event + if ctx.should_pause_invocation(event): + pause_invocation = True + + # Indicates the invocation should pause when receiving signal from + # the current sub_agent. + if pause_invocation: + return @override async def _run_live_impl( diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index f85bee1f..32a1f9bd 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -16,11 +16,16 @@ from unittest.mock import Mock from google.adk.agents.base_agent import BaseAgent from google.adk.agents.invocation_context import InvocationContext +from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event from google.adk.sessions.base_session_service import BaseSessionService from google.adk.sessions.session import Session +from google.genai.types import FunctionCall +from google.genai.types import Part import pytest +from .. import testing_utils + class TestInvocationContext: """Test suite for InvocationContext.""" @@ -117,3 +122,87 @@ class TestInvocationContext: current_branch=True, ) assert not events + + +class TestInvocationContextWithAppResumablity: + """Test suite for InvocationContext regarding app resumability.""" + + @pytest.fixture + def long_running_function_call(self) -> FunctionCall: + """A long running function call.""" + return FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + + @pytest.fixture + def event_to_pause(self, long_running_function_call) -> Event: + """An event with a long running function call.""" + return Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent( + [Part(function_call=long_running_function_call)] + ), + long_running_tool_ids=[long_running_function_call.id], + ) + + def _create_test_invocation_context( + self, resumability_config + ) -> InvocationContext: + """Create a mock invocation context for testing.""" + ctx = InvocationContext( + session_service=Mock(spec=BaseSessionService), + agent=Mock(spec=BaseAgent), + invocation_id='inv_1', + session=Mock(spec=Session), + resumability_config=resumability_config, + ) + return ctx + + def test_should_pause_invocation_with_resumable_app(self, event_to_pause): + """Tests should_pause_invocation with a resumable app.""" + mock_invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + + assert mock_invocation_context.should_pause_invocation(event_to_pause) + + def test_should_not_pause_invocation_with_non_resumable_app( + self, event_to_pause + ): + """Tests should_pause_invocation with a non-resumable app.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=False) + ) + + assert not invocation_context.should_pause_invocation(event_to_pause) + + def test_should_not_pause_invocation_with_no_long_running_tool_ids( + self, event_to_pause + ): + """Tests should_pause_invocation with no long running tools.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + nonpausable_event = event_to_pause.model_copy( + update={'long_running_tool_ids': []} + ) + + assert not invocation_context.should_pause_invocation(nonpausable_event) + + def test_should_not_pause_invocation_with_no_function_calls( + self, event_to_pause + ): + """Tests should_pause_invocation with a non-model event.""" + mock_invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + nonpausable_event = event_to_pause.model_copy( + update={'content': testing_utils.UserContent('test text part')} + ) + + assert not mock_invocation_context.should_pause_invocation( + nonpausable_event + ) diff --git a/tests/unittests/runners/test_pause_invocation.py b/tests/unittests/runners/test_pause_invocation.py new file mode 100644 index 00000000..97ebd7df --- /dev/null +++ b/tests/unittests/runners/test_pause_invocation.py @@ -0,0 +1,472 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the resumption flow with different agent structures.""" + +import asyncio +from typing import AsyncGenerator + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.apps.app import App +from google.adk.apps.app import ResumabilityConfig +from google.adk.events.event import Event +from google.adk.tools.exit_loop_tool import exit_loop +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.long_running_tool import LongRunningFunctionTool +from google.genai.types import Part +import pytest + +from .. import testing_utils + + +def _transfer_call_part(agent_name: str) -> Part: + return Part.from_function_call( + name="transfer_to_agent", args={"agent_name": agent_name} + ) + + +def test_tool() -> str: + return "" + + +class _TestingAgent(BaseAgent): + """A testing agent that generates an event after a delay.""" + + delay: float = 0 + """The delay before the agent generates an event.""" + + def event(self, ctx: InvocationContext): + return Event( + author=self.name, + branch=ctx.branch, + invocation_id=ctx.invocation_id, + content=testing_utils.ModelContent( + parts=[Part.from_text(text="Delayed message")] + ), + ) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + await asyncio.sleep(self.delay) + yield self.event(ctx) + + +_TRANSFER_RESPONSE_PART = Part.from_function_response( + name="transfer_to_agent", response={"result": None} +) + + +class BasePauseInvocationTest: + """Base class for pausing invocation tests with common fixtures.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + return BaseAgent(name="test_agent") + + @pytest.fixture + def app(self, agent: BaseAgent) -> App: + """Provides an App for the test.""" + return App( + name="InMemoryRunner", # Required for using TestInMemoryRunner. + root_agent=agent, + resumability_config=ResumabilityConfig(is_resumable=True), + ) + + @pytest.fixture + def runner(self, app: App) -> testing_utils.TestInMemoryRunner: + """Provides an in-memory runner for the agent.""" + return testing_utils.TestInMemoryRunner(app=app, app_name=None) + + @staticmethod + def mock_model(responses: list[Part]) -> testing_utils.MockModel: + """Provides a mock model with predefined responses.""" + return testing_utils.MockModel.create(responses=responses) + + +class TestPauseInvocationWithSingleLlmAgent(BasePauseInvocationTest): + """Tests the resumption flow with a single LlmAgent.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + + def test_tool() -> str: + return "" + + return LlmAgent( + name="root_agent", + model=self.mock_model( + responses=[Part.from_function_call(name="test_tool", args={})] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + + @pytest.mark.asyncio + async def test_pause_on_long_running_function_call( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a single LlmAgent pauses on long running function call.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("root_agent", Part.from_function_call(name="test_tool", args={})), + ] + + +class TestPauseInvocationWithSequentialAgent(BasePauseInvocationTest): + """Tests pausing invocation with a SequentialAgent.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + sub_agent1 = LlmAgent( + name="sub_agent_1", + model=self.mock_model( + responses=[Part.from_function_call(name="test_tool", args={})] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + sub_agent2 = LlmAgent( + name="sub_agent_2", + model=self.mock_model( + responses=[Part.from_function_call(name="test_tool", args={})] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + return SequentialAgent( + name="root_agent", + sub_agents=[sub_agent1, sub_agent2], + ) + + @pytest.mark.asyncio + async def test_pause_first_agent_on_long_running_function_call( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a single LlmAgent pauses on long running function call.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("sub_agent_1", Part.from_function_call(name="test_tool", args={})), + ] + + @pytest.mark.asyncio + async def test_pause_second_agent_on_long_running_function_call( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a single LlmAgent pauses on long running function call.""" + # Change the base sequential agent, so that the first agent does not pause. + runner.agent.sub_agents[0].tools = [FunctionTool(func=test_tool)] + runner.agent.sub_agents[0].model = self.mock_model( + responses=[ + Part.from_function_call(name="test_tool", args={}), + Part.from_text(text="model response after tool call"), + ] + ) + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("sub_agent_1", Part.from_function_call(name="test_tool", args={})), + ( + "sub_agent_1", + Part.from_function_response( + name="test_tool", response={"result": ""} + ), + ), + ("sub_agent_1", "model response after tool call"), + ("sub_agent_2", Part.from_function_call(name="test_tool", args={})), + ] + + +class TestPauseInvocationWithParallelAgent(BasePauseInvocationTest): + """Tests pausing invocation with a ParallelAgent.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + sub_agent1 = LlmAgent( + name="sub_agent_1", + model=self.mock_model( + responses=[Part.from_function_call(name="test_tool", args={})] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + sub_agent2 = _TestingAgent( + name="sub_agent_2", + delay=0.5, + ) + return ParallelAgent( + name="root_agent", + sub_agents=[sub_agent1, sub_agent2], + ) + + @pytest.mark.asyncio + async def test_pause_on_long_running_function_call( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a ParallelAgent pauses on long running function call.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("sub_agent_1", Part.from_function_call(name="test_tool", args={})), + ("sub_agent_2", "Delayed message"), + ] + + +class TestPauseInvocationWithNestedParallelAgent(BasePauseInvocationTest): + """Tests pausing invocation with a nested ParallelAgent.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + nested_sub_agent_1 = LlmAgent( + name="nested_sub_agent_1", + model=self.mock_model( + responses=[Part.from_function_call(name="test_tool", args={})] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + nested_sub_agent_2 = _TestingAgent( + name="nested_sub_agent_2", + delay=0.5, + ) + nested_parallel_agent = ParallelAgent( + name="nested_parallel_agent", + sub_agents=[nested_sub_agent_1, nested_sub_agent_2], + ) + sub_agent_1 = _TestingAgent( + name="sub_agent_1", + delay=0.5, + ) + return ParallelAgent( + name="root_agent", + sub_agents=[sub_agent_1, nested_parallel_agent], + ) + + @pytest.mark.asyncio + async def test_pause_on_long_running_function_call_in_nested_agent( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a nested ParallelAgent pauses on long running function call.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ( + "nested_sub_agent_1", + Part.from_function_call(name="test_tool", args={}), + ), + ("sub_agent_1", "Delayed message"), + ("nested_sub_agent_2", "Delayed message"), + ] + + @pytest.mark.asyncio + async def test_pause_on_multiple_long_running_function_calls( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a ParallelAgent pauses on long running function calls.""" + runner.agent.sub_agents[0] = LlmAgent( + name="sub_agent_1", + model=self.mock_model( + responses=[ + Part.from_function_call(name="test_tool", args={}), + Part.from_function_call(name="test_tool", args={}), + ] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + simplified_events = testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) + assert len(simplified_events) == 3 + assert ( + "sub_agent_1", + Part.from_function_call(name="test_tool", args={}), + ) in simplified_events + assert ( + "nested_sub_agent_1", + Part.from_function_call(name="test_tool", args={}), + ) in simplified_events + + +class TestPauseInvocationWithLoopAgent(BasePauseInvocationTest): + """Tests pausing invocation with a LoopAgent.""" + + @pytest.fixture + def agent(self) -> BaseAgent: + """Provides a BaseAgent for the test.""" + sub_agent_1 = LlmAgent( + name="sub_agent_1", + model=self.mock_model( + responses=[ + Part.from_text(text="sub agent 1 response"), + ] + ), + ) + sub_agent_2 = LlmAgent( + name="sub_agent_2", + model=self.mock_model( + responses=[ + Part.from_function_call(name="test_tool", args={}), + ] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + sub_agent_3 = LlmAgent( + name="sub_agent_3", + model=self.mock_model( + responses=[ + Part.from_function_call(name="exit_loop", args={}), + ] + ), + tools=[exit_loop], + ) + return LoopAgent( + name="root_agent", + sub_agents=[sub_agent_1, sub_agent_2, sub_agent_3], + max_iterations=2, + ) + + @pytest.mark.asyncio + async def test_pause_on_long_running_function_call_in_loop( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a LoopAgent pauses on long running function call.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("sub_agent_1", "sub agent 1 response"), + ("sub_agent_2", Part.from_function_call(name="test_tool", args={})), + ] + + +class TestPauseInvocationWithLlmAgentTree(BasePauseInvocationTest): + """Tests the pausing invocation with a tree of LlmAgents.""" + + @pytest.fixture + def agent(self) -> LlmAgent: + """Provides an LlmAgent with sub-agents for the test.""" + sub_llm_agent_1 = LlmAgent( + name="sub_llm_agent_1", + model=self.mock_model( + responses=[ + _transfer_call_part("sub_llm_agent_2"), + "llm response not used", + ] + ), + ) + sub_llm_agent_2 = LlmAgent( + name="sub_llm_agent_2", + model=self.mock_model( + responses=[ + Part.from_function_call(name="test_tool", args={}), + "llm response not used", + ] + ), + tools=[LongRunningFunctionTool(func=test_tool)], + ) + return LlmAgent( + name="root_agent", + model=self.mock_model( + responses=[ + _transfer_call_part("sub_llm_agent_1"), + "llm response not used", + ] + ), + sub_agents=[sub_llm_agent_1, sub_llm_agent_2], + ) + + @pytest.mark.asyncio + async def test_pause_on_transfer_call_part( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a tree of resumable LlmAgents yields checkpoint events.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("root_agent", _transfer_call_part("sub_llm_agent_1")), + ("root_agent", _TRANSFER_RESPONSE_PART), + ("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")), + ("sub_llm_agent_1", _TRANSFER_RESPONSE_PART), + ("sub_llm_agent_2", Part.from_function_call(name="test_tool", args={})), + ] + + +class TestPauseInvocationWithWithTransferLoop(BasePauseInvocationTest): + """Tests the pausing the invocation when the agent transfer forms a loop.""" + + @pytest.fixture + def agent(self) -> LlmAgent: + """Provides an LlmAgent with sub-agents for the test.""" + sub_llm_agent_1 = LlmAgent( + name="sub_llm_agent_1", + model=self.mock_model( + responses=[ + _transfer_call_part("sub_llm_agent_2"), + "llm response not used", + ] + ), + ) + sub_llm_agent_2 = LlmAgent( + name="sub_llm_agent_2", + model=self.mock_model( + responses=[ + _transfer_call_part("root_agent"), + "llm response not used", + ] + ), + ) + return LlmAgent( + name="root_agent", + model=self.mock_model( + responses=[ + _transfer_call_part("sub_llm_agent_1"), + Part.from_function_call(name="test_tool", args={}), + "llm response not used", + ] + ), + sub_agents=[sub_llm_agent_1, sub_llm_agent_2], + tools=[LongRunningFunctionTool(func=test_tool)], + ) + + @pytest.mark.asyncio + async def test_agent_tree_yields_checkpoints( + self, + runner: testing_utils.TestInMemoryRunner, + ): + """Tests that a tree of resumable LlmAgents yields checkpoint events.""" + assert testing_utils.simplify_events( + await runner.run_async_with_new_session("test") + ) == [ + ("root_agent", _transfer_call_part("sub_llm_agent_1")), + ("root_agent", _TRANSFER_RESPONSE_PART), + ("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")), + ("sub_llm_agent_1", _TRANSFER_RESPONSE_PART), + ("sub_llm_agent_2", _transfer_call_part("root_agent")), + ("sub_llm_agent_2", _TRANSFER_RESPONSE_PART), + ("root_agent", Part.from_function_call(name="test_tool", args={})), + ] diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index c89cc22f..4fba2c70 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -19,9 +19,8 @@ from unittest import mock from google.adk.agents.base_agent import BaseAgent from google.adk.agents.llm_agent import LlmAgent -from google.adk.agents.loop_agent import LoopAgent -from google.adk.agents.parallel_agent import ParallelAgent -from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.apps.app import App +from google.adk.apps.app import ResumabilityConfig from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext @@ -363,3 +362,88 @@ class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest): testing_utils.simplify_events(copy.deepcopy(events)) == expected_parts_final ) + + +class TestHITLConfirmationFlowWithResumableApp: + """Tests the HITL confirmation flow with a resumable app.""" + + @pytest.fixture + def tools(self) -> list[FunctionTool]: + """Provides the tools for the agent.""" + return [FunctionTool(func=_test_request_confirmation_function)] + + @pytest.fixture + def llm_responses( + self, tools: list[FunctionTool] + ) -> list[GenerateContentResponse]: + """Provides mock LLM responses for the tests.""" + return [ + _create_llm_response_from_tools(tools), + _create_llm_response_from_text("test llm response after tool call"), + _create_llm_response_from_text( + "test llm response after final tool call" + ), + ] + + @pytest.fixture + def mock_model( + self, llm_responses: list[GenerateContentResponse] + ) -> testing_utils.MockModel: + """Provides a mock model with predefined responses.""" + return testing_utils.MockModel(responses=llm_responses) + + @pytest.fixture + def agent( + self, mock_model: testing_utils.MockModel, tools: list[FunctionTool] + ) -> LlmAgent: + """Provides a single LlmAgent for the test.""" + return LlmAgent(name="root_agent", model=mock_model, tools=tools) + + @pytest.fixture + def runner(self, agent: LlmAgent) -> testing_utils.TestInMemoryRunner: + """Provides an in-memory runner for the agent.""" + # Mark the app as resumable. So that the invocation will be paused after the + # long running tool call. + app = App( + name="InMemoryRunner", # Required for using TestInMemoryRunner. + resumability_config=ResumabilityConfig(is_resumable=True), + root_agent=agent, + ) + return testing_utils.TestInMemoryRunner(app=app, app_name=None) + + @pytest.mark.asyncio + async def test_pause_on_request_confirmation( + self, + runner: testing_utils.TestInMemoryRunner, + agent: LlmAgent, + ): + """Tests HITL flow where all tool calls are confirmed.""" + events = await runner.run_async_with_new_session("test user query") + + # Verify that the invocation is paused after the long running tool call. + # So that no intermediate function response and llm response is generated. + assert testing_utils.simplify_events(copy.deepcopy(events)) == [ + ( + agent.name, + Part(function_call=FunctionCall(name=agent.tools[0].name, args={})), + ), + ( + agent.name, + Part( + function_call=FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": { + "name": agent.tools[0].name, + "id": mock.ANY, + "args": {}, + }, + "toolConfirmation": { + "hint": "test hint for request_confirmation", + "confirmed": False, + }, + }, + ) + ), + ), + ]