From a17f3b2e6d2d48c433b42e27763f3d6df80243ca Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 15 Oct 2025 15:20:57 -0700 Subject: [PATCH] feat: Allow custom request and event converters in A2aAgentExecutor This change introduces type aliases for request and event conversion functions: - `A2ARequestToADKRunArgsConverter`: For converting A2A `RequestContext` to an `ADKRunArgs` Pydantic model. - `AdkEventToA2AEventsConverter`: For converting ADK `Event` to a list of A2A `A2AEvent` objects. The `convert_a2a_request_to_adk_run_args` function now returns a structured `ADKRunArgs` model instead of a generic dictionary, improving type safety. These converter types can now be provided via the `A2aAgentExecutorConfig` to customize the conversion logic used by the `A2aAgentExecutor`. The executor defaults to the existing `convert_a2a_request_to_adk_run_args` and `convert_event_to_a2a_events` functions if no custom converters are specified. This allows users to inject their own logic for handling request and event conversions, for example, to add custom metadata or transform data types, without modifying the core executor. PiperOrigin-RevId: 819934960 --- .../adk/a2a/converters/event_converter.py | 29 + .../adk/a2a/converters/request_converter.py | 65 +- .../adk/a2a/executor/a2a_agent_executor.py | 44 +- .../a2a/converters/test_request_converter.py | 106 +- .../a2a/executor/test_a2a_agent_executor.py | 1083 ++++++++--------- 5 files changed, 692 insertions(+), 635 deletions(-) diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 5e941a78..df824763 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Callable from datetime import datetime from datetime import timezone import logging @@ -57,6 +58,34 @@ DEFAULT_ERROR_MESSAGE = "An error occurred during processing" logger = logging.getLogger("google_adk." + __name__) +AdkEventToA2AEventsConverter = Callable[ + [ + Event, + InvocationContext, + Optional[str], + Optional[str], + GenAIPartToA2APartConverter, + ], + List[A2AEvent], +] +"""A callable that converts an ADK Event into a list of A2A events. + +This interface allows for custom logic to map ADK's event structure to the +event structure expected by the A2A server. + +Args: + event: The source ADK Event to convert. + invocation_context: The context of the ADK agent invocation. + task_id: The ID of the A2A task being processed. + context_id: The context ID from the A2A request. + part_converter: A function to convert GenAI content parts to A2A + parts. + +Returns: + A list of A2A events. +""" + + def _serialize_metadata_value(value: Any) -> str: """Safely serializes metadata values to string format. diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 78a6d78e..92166c3f 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -14,8 +14,12 @@ from __future__ import annotations +from collections.abc import Callable import sys from typing import Any +from typing import Optional + +from pydantic import BaseModel try: from a2a.server.agent_execution import RequestContext @@ -35,6 +39,39 @@ from .part_converter import A2APartToGenAIPartConverter from .part_converter import convert_a2a_part_to_genai_part +@a2a_experimental +class AgentRunRequest(BaseModel): + """Data model for arguments passed to the ADK runner.""" + + user_id: Optional[str] = None + session_id: Optional[str] = None + invocation_id: Optional[str] = None + new_message: Optional[genai_types.Content] = None + state_delta: Optional[dict[str, Any]] = None + run_config: Optional[RunConfig] = None + + +A2ARequestToAgentRunRequestConverter = Callable[ + [ + RequestContext, + A2APartToGenAIPartConverter, + ], + AgentRunRequest, +] +"""A callable that converts an A2A RequestContext to RunnerRequest for ADK runner. + +This interface allows for custom logic to map an incoming A2A RequestContext to the +structured arguments expected by the ADK runner's `run_async` method. + +Args: + request: The incoming request context from the A2A server. + part_converter: A function to convert A2A content parts to GenAI parts. + +Returns: + An RunnerRequest object containing the keyword arguments for ADK runner's run_async method. +""" + + def _get_user_id(request: RequestContext) -> str: # Get user from call context if available (auth is enabled on a2a server) if ( @@ -49,20 +86,32 @@ def _get_user_id(request: RequestContext) -> str: @a2a_experimental -def convert_a2a_request_to_adk_run_args( +def convert_a2a_request_to_agent_run_request( request: RequestContext, part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, -) -> dict[str, Any]: +) -> AgentRunRequest: + """Converts an A2A RequestContext to an AgentRunRequest model. + + Args: + request: The incoming request context from the A2A server. + part_converter: A function to convert A2A content parts to GenAI parts. + + Returns: + A AgentRunRequest object ready to be used as arguments for the ADK runner. + + Raises: + ValueError: If the request message is None. + """ if not request.message: raise ValueError('Request message cannot be None') - return { - 'user_id': _get_user_id(request), - 'session_id': request.context_id, - 'new_message': genai_types.Content( + return AgentRunRequest( + user_id=_get_user_id(request), + session_id=request.context_id, + new_message=genai_types.Content( role='user', parts=[part_converter(part) for part in request.message.parts], ), - 'run_config': RunConfig(), - } + run_config=RunConfig(), + ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 4cb92843..608a8188 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -18,7 +18,6 @@ from datetime import datetime from datetime import timezone import inspect import logging -from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -52,12 +51,15 @@ from google.adk.runners import Runner from pydantic import BaseModel from typing_extensions import override +from ..converters.event_converter import AdkEventToA2AEventsConverter from ..converters.event_converter import convert_event_to_a2a_events from ..converters.part_converter import A2APartToGenAIPartConverter from ..converters.part_converter import convert_a2a_part_to_genai_part from ..converters.part_converter import convert_genai_part_to_a2a_part from ..converters.part_converter import GenAIPartToA2APartConverter -from ..converters.request_converter import convert_a2a_request_to_adk_run_args +from ..converters.request_converter import A2ARequestToAgentRunRequestConverter +from ..converters.request_converter import AgentRunRequest +from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental from .task_result_aggregator import TaskResultAggregator @@ -75,6 +77,10 @@ class A2aAgentExecutorConfig(BaseModel): gen_ai_part_converter: GenAIPartToA2APartConverter = ( convert_genai_part_to_a2a_part ) + request_converter: A2ARequestToAgentRunRequestConverter = ( + convert_a2a_request_to_agent_run_request + ) + event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events @a2a_experimental @@ -192,19 +198,20 @@ class A2aAgentExecutor(AgentExecutor): # Resolve the runner instance runner = await self._resolve_runner() - # Convert the a2a request to ADK run args - run_args = convert_a2a_request_to_adk_run_args( - context, self._config.a2a_part_converter + # Convert the a2a request to AgentRunRequest + run_request = self._config.request_converter( + context, + self._config.a2a_part_converter, ) # ensure the session exists - session = await self._prepare_session(context, run_args, runner) + session = await self._prepare_session(context, run_request, runner) # create invocation context invocation_context = runner._new_invocation_context( session=session, - new_message=run_args['new_message'], - run_config=run_args['run_config'], + new_message=run_request.new_message, + run_config=run_request.run_config, ) # publish the task working event @@ -219,16 +226,16 @@ class A2aAgentExecutor(AgentExecutor): final=False, metadata={ _get_adk_metadata_key('app_name'): runner.app_name, - _get_adk_metadata_key('user_id'): run_args['user_id'], - _get_adk_metadata_key('session_id'): run_args['session_id'], + _get_adk_metadata_key('user_id'): run_request.user_id, + _get_adk_metadata_key('session_id'): run_request.session_id, }, ) ) task_result_aggregator = TaskResultAggregator() - async with Aclosing(runner.run_async(**run_args)) as agen: + async with Aclosing(runner.run_async(**vars(run_request))) as agen: async for adk_event in agen: - for a2a_event in convert_event_to_a2a_events( + for a2a_event in self._config.event_converter( adk_event, invocation_context, context.task_id, @@ -284,12 +291,15 @@ class A2aAgentExecutor(AgentExecutor): ) async def _prepare_session( - self, context: RequestContext, run_args: dict[str, Any], runner: Runner + self, + context: RequestContext, + run_request: AgentRunRequest, + runner: Runner, ): - session_id = run_args['session_id'] + session_id = run_request.session_id # create a new session if not exists - user_id = run_args['user_id'] + user_id = run_request.user_id session = await runner.session_service.get_session( app_name=runner.app_name, user_id=user_id, @@ -302,7 +312,7 @@ class A2aAgentExecutor(AgentExecutor): state={}, session_id=session_id, ) - # Update run_args with the new session_id - run_args['session_id'] = session.id + # Update run_request with the new session_id + run_request.session_id = session.id return session diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index 115b2312..360afbd3 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -27,7 +27,7 @@ pytestmark = pytest.mark.skipif( try: from a2a.server.agent_execution import RequestContext from google.adk.a2a.converters.request_converter import _get_user_id - from google.adk.a2a.converters.request_converter import convert_a2a_request_to_adk_run_args + from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request from google.adk.runners import RunConfig from google.genai import types as genai_types except ImportError as e: @@ -143,11 +143,11 @@ class TestGetUserId: assert result == "A2A_USER_None" -class TestConvertA2aRequestToAdkRunArgs: - """Test cases for convert_a2a_request_to_adk_run_args function.""" +class TestConvertA2aRequestToAgentRunRequest: + """Test cases for convert_a2a_request_to_agent_run_request function.""" def test_convert_a2a_request_basic(self): - """Test basic conversion of A2A request to ADK run args.""" + """Test basic conversion of A2A request to ADK AgentRunRequest.""" # Arrange mock_part1 = Mock() mock_part2 = Mock() @@ -173,16 +173,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.side_effect = [mock_genai_part1, mock_genai_part2] # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "test_user" - assert result["session_id"] == "test_context_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "test_user" + assert result.session_id == "test_context_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part1, mock_genai_part2] + assert isinstance(result.run_config, RunConfig) # Verify calls assert mock_convert_part.call_count == 2 @@ -197,7 +199,7 @@ class TestConvertA2aRequestToAdkRunArgs: # Act & Assert with pytest.raises(ValueError, match="Request message cannot be None"): - convert_a2a_request_to_adk_run_args(request) + convert_a2a_request_to_agent_run_request(request) def test_convert_a2a_request_empty_parts(self): """Test conversion with empty parts list.""" @@ -212,16 +214,18 @@ class TestConvertA2aRequestToAdkRunArgs: request.call_context = None # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_test_context_123" - assert result["session_id"] == "test_context_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_test_context_123" + assert result.session_id == "test_context_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [] + assert isinstance(result.run_config, RunConfig) # Verify convert_part wasn't called mock_convert_part.assert_not_called() @@ -244,16 +248,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_None" - assert result["session_id"] is None - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_None" + assert result.session_id is None + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) def test_convert_a2a_request_no_auth(self): """Test conversion when no authentication is available.""" @@ -273,16 +279,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_session_123" - assert result["session_id"] == "session_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_session_123" + assert result.session_id == "session_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) class TestIntegration: @@ -312,16 +320,18 @@ class TestIntegration: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "auth_user" # Should use authenticated user - assert result["session_id"] == "mysession" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "auth_user" # Should use authenticated user + assert result.session_id == "mysession" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) def test_end_to_end_conversion_with_fallback_user(self): """Test end-to-end conversion with fallback user ID.""" @@ -341,15 +351,17 @@ class TestIntegration: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None assert ( - result["user_id"] == "A2A_USER_test_session_456" + result.user_id == "A2A_USER_test_session_456" ) # Should fallback to context ID - assert result["session_id"] == "test_session_456" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.session_id == "test_session_456" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 11d6b3a7..4bcc7a91 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -31,10 +31,13 @@ try: from a2a.types import Message from a2a.types import TaskState from a2a.types import TextPart + from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.events.event import Event + from google.adk.runners import RunConfig from google.adk.runners import Runner + from google.genai.types import Content except ImportError as e: if sys.version_info < (3, 10): # Imports are not needed since tests will be skipped due to pytestmark. @@ -58,9 +61,13 @@ class TestA2aAgentExecutor: self.mock_a2a_part_converter = Mock() self.mock_gen_ai_part_converter = Mock() + self.mock_request_converter = Mock() + self.mock_event_converter = Mock() self.mock_config = A2aAgentExecutorConfig( a2a_part_converter=self.mock_a2a_part_converter, gen_ai_part_converter=self.mock_gen_ai_part_converter, + request_converter=self.mock_request_converter, + event_converter=self.mock_event_converter, ) self.executor = A2aAgentExecutor( runner=self.mock_runner, config=self.mock_config @@ -84,71 +91,73 @@ class TestA2aAgentExecutor: async def test_execute_success_new_task(self): """Test successful execution of a new task.""" # Setup - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] - self.mock_runner.run_async = mock_run_async + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) - # Execute - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + mock_invocation_context, + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify working event was enqueued - working_event = self.mock_event_queue.enqueue_event.call_args_list[1][ - 0 - ][0] - assert working_event.status.state == TaskState.working - assert working_event.final == False + # Verify working event was enqueued + working_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_execute_no_message_error(self): @@ -164,73 +173,76 @@ class TestA2aAgentExecutor: self.mock_context.current_task = Mock() self.mock_context.task_id = "existing-task-id" - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) - # Execute - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) - # Verify no submitted event (first call should be working event) - working_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert working_event.status.state == TaskState.working - assert working_event.final == False + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + mock_invocation_context, + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify no submitted event (first call should be working event) + working_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_prepare_session_new_session(self): """Test session preparation when session doesn't exist.""" - run_args = { - "user_id": "test-user", - "session_id": None, - "new_message": Mock(), - "run_config": Mock(), - } + run_args = AgentRunRequest( + user_id="test-user", + session_id=None, + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) # Mock session service self.mock_runner.session_service.get_session = AsyncMock(return_value=None) @@ -247,18 +259,18 @@ class TestA2aAgentExecutor: # Verify session was created assert result == mock_session - assert run_args["session_id"] is not None + assert run_args.session_id is not None self.mock_runner.session_service.create_session.assert_called_once() @pytest.mark.asyncio async def test_prepare_session_existing_session(self): """Test session preparation when session exists.""" - run_args = { - "user_id": "test-user", - "session_id": "existing-session", - "new_message": Mock(), - "run_config": Mock(), - } + run_args = AgentRunRequest( + user_id="test-user", + session_id="existing-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) # Mock session service mock_session = Mock() @@ -397,63 +409,55 @@ class TestA2aAgentExecutor: executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + self.mock_event_converter.return_value = [] - # Execute - await executor.execute(self.mock_context, self.mock_event_queue) + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_execute_with_async_callable_runner(self): @@ -464,63 +468,55 @@ class TestA2aAgentExecutor: executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + self.mock_event_converter.return_value = [] - # Execute - await executor.execute(self.mock_context, self.mock_event_queue) + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_handle_request_integration(self): @@ -529,83 +525,75 @@ class TestA2aAgentExecutor: self.mock_context.task_id = "test-task-id" # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.working + # Mock the task_status_message property to return None by default + mock_aggregator.task_status_message = None + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify working event was enqueued + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + assert len(working_events) >= 1 - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] + # Verify aggregator processed events + assert mock_aggregator.process_event.call_count == len(mock_events) - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - mock_aggregator.task_state = TaskState.working - # Mock the task_status_message property to return None by default - mock_aggregator.task_status_message = None - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify working event was enqueued - working_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "status") - and call[0][0].status.state == TaskState.working - ] - assert len(working_events) >= 1 - - # Verify aggregator processed events - assert mock_aggregator.process_event.call_count == len(mock_events) - - # Verify final event has message field from aggregator and state is completed when aggregator state is working - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert ( - final_event.status.message == mock_aggregator.task_status_message - ) - # When aggregator state is working but no message, final event should be working - assert final_event.status.state == TaskState.working + # Verify final event has message field from aggregator and state is completed when aggregator state is working + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == mock_aggregator.task_status_message + # When aggregator state is working but no message, final event should be working + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_cancel_with_task_id(self): @@ -637,31 +625,26 @@ class TestA2aAgentExecutor: None # Make sure it goes through submitted event creation ) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.side_effect = Exception("Test error") + self.mock_request_converter.side_effect = Exception("Test error") - # Execute (should not raise since we catch the exception) - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) - # Verify both submitted and failure events were enqueued - # First call should be submitted event, last should be failure event - assert self.mock_event_queue.enqueue_event.call_count >= 2 + # Verify both submitted and failure events were enqueued + # First call should be submitted event, last should be failure event + assert self.mock_event_queue.enqueue_event.call_count >= 2 - # Check submitted event (first) - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Check submitted event (first) + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Check failure event (last) - failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert failure_event.status.state == TaskState.failed - assert failure_event.final == True + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True @pytest.mark.asyncio async def test_handle_request_with_aggregator_message(self): @@ -680,69 +663,63 @@ class TestA2aAgentExecutor: test_message.parts = [Mock(spec=TextPart)] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.completed + # Mock the task_status_message property to return a test message + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) - - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - mock_aggregator.task_state = TaskState.completed - # Mock the task_status_message property to return a test message - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify final event has message field from aggregator - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message - # When aggregator state is completed (not working), final event should be completed - assert final_event.status.state == TaskState.completed + # Verify final event has message field from aggregator + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == test_message + # When aggregator state is completed (not working), final event should be completed + assert final_event.status.state == TaskState.completed @pytest.mark.asyncio async def test_handle_request_with_non_working_aggregator_state(self): @@ -761,69 +738,63 @@ class TestA2aAgentExecutor: test_message.parts = [Mock(spec=TextPart)] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with failed state - should preserve failed state + mock_aggregator.task_state = TaskState.failed + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) - - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with failed state - should preserve failed state - mock_aggregator.task_state = TaskState.failed - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify final event preserves the non-working state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message - # When aggregator state is failed (not working), final event should keep failed state - assert final_event.status.state == TaskState.failed + # Verify final event preserves the non-working state + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == test_message + # When aggregator state is failed (not working), final event should keep failed state + assert final_event.status.state == TaskState.failed @pytest.mark.asyncio async def test_handle_request_with_working_state_publishes_artifact_and_completed( @@ -846,84 +817,77 @@ class TestA2aAgentExecutor: test_message.parts = [Part(root=TextPart(text="test content"))] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with working state - should publish artifact update and completed status + mock_aggregator.task_state = TaskState.working + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify artifact update event was published + artifact_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "artifact") and call[0][0].last_chunk == True + ] + assert len(artifact_events) == 1 + artifact_event = artifact_events[0] + assert artifact_event.task_id == "test-task-id" + assert artifact_event.context_id == "test-context-id" + # Check that artifact parts correspond to message parts + assert len(artifact_event.artifact.parts) == len(test_message.parts) + assert artifact_event.artifact.parts == test_message.parts - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with working state - should publish artifact update and completed status - mock_aggregator.task_state = TaskState.working - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify artifact update event was published - artifact_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "artifact") - and call[0][0].last_chunk == True - ] - assert len(artifact_events) == 1 - artifact_event = artifact_events[0] - assert artifact_event.task_id == "test-task-id" - assert artifact_event.context_id == "test-context-id" - # Check that artifact parts correspond to message parts - assert len(artifact_event.artifact.parts) == len(test_message.parts) - assert artifact_event.artifact.parts == test_message.parts - - # Verify final status event was published with completed state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.state == TaskState.completed - assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + # Verify final status event was published with completed state + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.state == TaskState.completed + assert final_event.task_id == "test-task-id" + assert final_event.context_id == "test-context-id" @pytest.mark.asyncio async def test_handle_request_with_non_working_state_publishes_status_only( @@ -946,76 +910,69 @@ class TestA2aAgentExecutor: test_message.parts = [Part(root=TextPart(text="test content"))] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with auth_required state - should publish only status event + mock_aggregator.task_state = TaskState.auth_required + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify no artifact update event was published + artifact_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "artifact") and call[0][0].last_chunk == True + ] + assert len(artifact_events) == 0 - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with auth_required state - should publish only status event - mock_aggregator.task_state = TaskState.auth_required - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify no artifact update event was published - artifact_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "artifact") - and call[0][0].last_chunk == True - ] - assert len(artifact_events) == 0 - - # Verify final status event was published with the actual state and message - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.state == TaskState.auth_required - assert final_event.status.message == test_message - assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + # Verify final status event was published with the actual state and message + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.state == TaskState.auth_required + assert final_event.status.message == test_message + assert final_event.task_id == "test-task-id" + assert final_event.context_id == "test-context-id"