From 9645cee3a6dd0581bd37c45934ccb2eed1184053 Mon Sep 17 00:00:00 2001 From: Kacper Jawoszek Date: Thu, 21 Aug 2025 10:42:46 -0700 Subject: [PATCH] chore: add test for parallel agent to verify correct handling of exceptions PiperOrigin-RevId: 797825924 --- tests/unittests/agents/test_parallel_agent.py | 89 +++++++++++++++---- 1 file changed, 70 insertions(+), 19 deletions(-) diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index 53800985..3b0168a8 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent): delay: float = 0 """The delay before the agent generates an event.""" - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - await asyncio.sleep(self.delay) - yield Event( + def event(self, ctx: InvocationContext): + return Event( author=self.name, branch=ctx.branch, invocation_id=ctx.invocation_id, @@ -47,6 +43,13 @@ class _TestingAgent(BaseAgent): ), ) + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + await asyncio.sleep(self.delay) + yield self.event(ctx) + async def _create_parent_invocation_context( test_name: str, agent: BaseAgent @@ -137,7 +140,7 @@ async def test_run_async_branches(request: pytest.FixtureRequest): assert events[2].branch != events[0].branch -class _TestingAgentWithMultipleEvents(BaseAgent): +class _TestingAgentWithMultipleEvents(_TestingAgent): """Mock agent for testing.""" @override @@ -145,18 +148,11 @@ class _TestingAgentWithMultipleEvents(BaseAgent): self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: for _ in range(0, 3): - event = Event( - author=self.name, - branch=ctx.branch, - invocation_id=ctx.invocation_id, - content=types.Content( - parts=[types.Part(text=f'Hello, async {self.name}!')] - ), - ) - yield event - # Check that the event was processed by the consumer. - assert event.custom_metadata is not None - assert event.custom_metadata['processed'] + event = self.event(ctx) + yield event + # Check that the event was processed by the consumer. + assert event.custom_metadata is not None + assert event.custom_metadata['processed'] @pytest.mark.asyncio @@ -186,3 +182,58 @@ async def test_generating_one_event_per_agent_at_once( async for event in agen: event.custom_metadata = {'processed': True} # Asserts on event are done in _TestingAgentWithMultipleEvents. + + +class _TestingAgentWithException(_TestingAgent): + """Mock agent for testing.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield self.event(ctx) + raise Exception() + + +class _TestingAgentInfiniteEvents(_TestingAgent): + """Mock agent for testing.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + while True: + yield self.event(ctx) + + +@pytest.mark.asyncio +async def test_stop_agent_if_sub_agent_fails( + request: pytest.FixtureRequest, +): + # This test is to verify that the parallel agent and subagents will all stop + # processing and throw exception to top level runner in case of exception. + agent1 = _TestingAgentWithException( + name=f'{request.function.__name__}_test_agent_1' + ) + agent2 = _TestingAgentInfiniteEvents( + name=f'{request.function.__name__}_test_agent_2' + ) + parallel_agent = ParallelAgent( + name=f'{request.function.__name__}_test_parallel_agent', + sub_agents=[ + agent1, + agent2, + ], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, parallel_agent + ) + + agen = parallel_agent.run_async(parent_ctx) + # We expect to receive an exception from one of subagents. + # The exception should be propagated to root agent and other subagents. + # Otherwise we'll have an infinite loop. + with pytest.raises(Exception): + async for _ in agen: + # The infinite agent could iterate a few times depending on scheduling. + pass