chore: add test for parallel agent to verify correct handling of exceptions

PiperOrigin-RevId: 797825924
This commit is contained in:
Kacper Jawoszek
2025-08-21 10:42:46 -07:00
committed by Copybara-Service
parent 70f50db653
commit 9645cee3a6
+70 -19
View File
@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
delay: float = 0
"""The delay before the agent generates an event."""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield Event(
def event(self, ctx: InvocationContext):
return Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
@@ -47,6 +43,13 @@ class _TestingAgent(BaseAgent):
),
)
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self.event(ctx)
async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
@@ -137,7 +140,7 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
assert events[2].branch != events[0].branch
class _TestingAgentWithMultipleEvents(BaseAgent):
class _TestingAgentWithMultipleEvents(_TestingAgent):
"""Mock agent for testing."""
@override
@@ -145,18 +148,11 @@ class _TestingAgentWithMultipleEvents(BaseAgent):
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for _ in range(0, 3):
event = Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'Hello, async {self.name}!')]
),
)
yield event
# Check that the event was processed by the consumer.
assert event.custom_metadata is not None
assert event.custom_metadata['processed']
event = self.event(ctx)
yield event
# Check that the event was processed by the consumer.
assert event.custom_metadata is not None
assert event.custom_metadata['processed']
@pytest.mark.asyncio
@@ -186,3 +182,58 @@ async def test_generating_one_event_per_agent_at_once(
async for event in agen:
event.custom_metadata = {'processed': True}
# Asserts on event are done in _TestingAgentWithMultipleEvents.
class _TestingAgentWithException(_TestingAgent):
"""Mock agent for testing."""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield self.event(ctx)
raise Exception()
class _TestingAgentInfiniteEvents(_TestingAgent):
"""Mock agent for testing."""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
while True:
yield self.event(ctx)
@pytest.mark.asyncio
async def test_stop_agent_if_sub_agent_fails(
request: pytest.FixtureRequest,
):
# This test is to verify that the parallel agent and subagents will all stop
# processing and throw exception to top level runner in case of exception.
agent1 = _TestingAgentWithException(
name=f'{request.function.__name__}_test_agent_1'
)
agent2 = _TestingAgentInfiniteEvents(
name=f'{request.function.__name__}_test_agent_2'
)
parallel_agent = ParallelAgent(
name=f'{request.function.__name__}_test_parallel_agent',
sub_agents=[
agent1,
agent2,
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent
)
agen = parallel_agent.run_async(parent_ctx)
# We expect to receive an exception from one of subagents.
# The exception should be propagated to root agent and other subagents.
# Otherwise we'll have an infinite loop.
with pytest.raises(Exception):
async for _ in agen:
# The infinite agent could iterate a few times depending on scheduling.
pass