diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 5e941a78..df824763 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Callable from datetime import datetime from datetime import timezone import logging @@ -57,6 +58,34 @@ DEFAULT_ERROR_MESSAGE = "An error occurred during processing" logger = logging.getLogger("google_adk." + __name__) +AdkEventToA2AEventsConverter = Callable[ + [ + Event, + InvocationContext, + Optional[str], + Optional[str], + GenAIPartToA2APartConverter, + ], + List[A2AEvent], +] +"""A callable that converts an ADK Event into a list of A2A events. + +This interface allows for custom logic to map ADK's event structure to the +event structure expected by the A2A server. + +Args: + event: The source ADK Event to convert. + invocation_context: The context of the ADK agent invocation. + task_id: The ID of the A2A task being processed. + context_id: The context ID from the A2A request. + part_converter: A function to convert GenAI content parts to A2A + parts. + +Returns: + A list of A2A events. +""" + + def _serialize_metadata_value(value: Any) -> str: """Safely serializes metadata values to string format. diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 78a6d78e..92166c3f 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -14,8 +14,12 @@ from __future__ import annotations +from collections.abc import Callable import sys from typing import Any +from typing import Optional + +from pydantic import BaseModel try: from a2a.server.agent_execution import RequestContext @@ -35,6 +39,39 @@ from .part_converter import A2APartToGenAIPartConverter from .part_converter import convert_a2a_part_to_genai_part +@a2a_experimental +class AgentRunRequest(BaseModel): + """Data model for arguments passed to the ADK runner.""" + + user_id: Optional[str] = None + session_id: Optional[str] = None + invocation_id: Optional[str] = None + new_message: Optional[genai_types.Content] = None + state_delta: Optional[dict[str, Any]] = None + run_config: Optional[RunConfig] = None + + +A2ARequestToAgentRunRequestConverter = Callable[ + [ + RequestContext, + A2APartToGenAIPartConverter, + ], + AgentRunRequest, +] +"""A callable that converts an A2A RequestContext to RunnerRequest for ADK runner. + +This interface allows for custom logic to map an incoming A2A RequestContext to the +structured arguments expected by the ADK runner's `run_async` method. + +Args: + request: The incoming request context from the A2A server. + part_converter: A function to convert A2A content parts to GenAI parts. + +Returns: + An RunnerRequest object containing the keyword arguments for ADK runner's run_async method. +""" + + def _get_user_id(request: RequestContext) -> str: # Get user from call context if available (auth is enabled on a2a server) if ( @@ -49,20 +86,32 @@ def _get_user_id(request: RequestContext) -> str: @a2a_experimental -def convert_a2a_request_to_adk_run_args( +def convert_a2a_request_to_agent_run_request( request: RequestContext, part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, -) -> dict[str, Any]: +) -> AgentRunRequest: + """Converts an A2A RequestContext to an AgentRunRequest model. + + Args: + request: The incoming request context from the A2A server. + part_converter: A function to convert A2A content parts to GenAI parts. + + Returns: + A AgentRunRequest object ready to be used as arguments for the ADK runner. + + Raises: + ValueError: If the request message is None. + """ if not request.message: raise ValueError('Request message cannot be None') - return { - 'user_id': _get_user_id(request), - 'session_id': request.context_id, - 'new_message': genai_types.Content( + return AgentRunRequest( + user_id=_get_user_id(request), + session_id=request.context_id, + new_message=genai_types.Content( role='user', parts=[part_converter(part) for part in request.message.parts], ), - 'run_config': RunConfig(), - } + run_config=RunConfig(), + ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 4cb92843..608a8188 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -18,7 +18,6 @@ from datetime import datetime from datetime import timezone import inspect import logging -from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -52,12 +51,15 @@ from google.adk.runners import Runner from pydantic import BaseModel from typing_extensions import override +from ..converters.event_converter import AdkEventToA2AEventsConverter from ..converters.event_converter import convert_event_to_a2a_events from ..converters.part_converter import A2APartToGenAIPartConverter from ..converters.part_converter import convert_a2a_part_to_genai_part from ..converters.part_converter import convert_genai_part_to_a2a_part from ..converters.part_converter import GenAIPartToA2APartConverter -from ..converters.request_converter import convert_a2a_request_to_adk_run_args +from ..converters.request_converter import A2ARequestToAgentRunRequestConverter +from ..converters.request_converter import AgentRunRequest +from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental from .task_result_aggregator import TaskResultAggregator @@ -75,6 +77,10 @@ class A2aAgentExecutorConfig(BaseModel): gen_ai_part_converter: GenAIPartToA2APartConverter = ( convert_genai_part_to_a2a_part ) + request_converter: A2ARequestToAgentRunRequestConverter = ( + convert_a2a_request_to_agent_run_request + ) + event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events @a2a_experimental @@ -192,19 +198,20 @@ class A2aAgentExecutor(AgentExecutor): # Resolve the runner instance runner = await self._resolve_runner() - # Convert the a2a request to ADK run args - run_args = convert_a2a_request_to_adk_run_args( - context, self._config.a2a_part_converter + # Convert the a2a request to AgentRunRequest + run_request = self._config.request_converter( + context, + self._config.a2a_part_converter, ) # ensure the session exists - session = await self._prepare_session(context, run_args, runner) + session = await self._prepare_session(context, run_request, runner) # create invocation context invocation_context = runner._new_invocation_context( session=session, - new_message=run_args['new_message'], - run_config=run_args['run_config'], + new_message=run_request.new_message, + run_config=run_request.run_config, ) # publish the task working event @@ -219,16 +226,16 @@ class A2aAgentExecutor(AgentExecutor): final=False, metadata={ _get_adk_metadata_key('app_name'): runner.app_name, - _get_adk_metadata_key('user_id'): run_args['user_id'], - _get_adk_metadata_key('session_id'): run_args['session_id'], + _get_adk_metadata_key('user_id'): run_request.user_id, + _get_adk_metadata_key('session_id'): run_request.session_id, }, ) ) task_result_aggregator = TaskResultAggregator() - async with Aclosing(runner.run_async(**run_args)) as agen: + async with Aclosing(runner.run_async(**vars(run_request))) as agen: async for adk_event in agen: - for a2a_event in convert_event_to_a2a_events( + for a2a_event in self._config.event_converter( adk_event, invocation_context, context.task_id, @@ -284,12 +291,15 @@ class A2aAgentExecutor(AgentExecutor): ) async def _prepare_session( - self, context: RequestContext, run_args: dict[str, Any], runner: Runner + self, + context: RequestContext, + run_request: AgentRunRequest, + runner: Runner, ): - session_id = run_args['session_id'] + session_id = run_request.session_id # create a new session if not exists - user_id = run_args['user_id'] + user_id = run_request.user_id session = await runner.session_service.get_session( app_name=runner.app_name, user_id=user_id, @@ -302,7 +312,7 @@ class A2aAgentExecutor(AgentExecutor): state={}, session_id=session_id, ) - # Update run_args with the new session_id - run_args['session_id'] = session.id + # Update run_request with the new session_id + run_request.session_id = session.id return session diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index 115b2312..360afbd3 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -27,7 +27,7 @@ pytestmark = pytest.mark.skipif( try: from a2a.server.agent_execution import RequestContext from google.adk.a2a.converters.request_converter import _get_user_id - from google.adk.a2a.converters.request_converter import convert_a2a_request_to_adk_run_args + from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request from google.adk.runners import RunConfig from google.genai import types as genai_types except ImportError as e: @@ -143,11 +143,11 @@ class TestGetUserId: assert result == "A2A_USER_None" -class TestConvertA2aRequestToAdkRunArgs: - """Test cases for convert_a2a_request_to_adk_run_args function.""" +class TestConvertA2aRequestToAgentRunRequest: + """Test cases for convert_a2a_request_to_agent_run_request function.""" def test_convert_a2a_request_basic(self): - """Test basic conversion of A2A request to ADK run args.""" + """Test basic conversion of A2A request to ADK AgentRunRequest.""" # Arrange mock_part1 = Mock() mock_part2 = Mock() @@ -173,16 +173,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.side_effect = [mock_genai_part1, mock_genai_part2] # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "test_user" - assert result["session_id"] == "test_context_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "test_user" + assert result.session_id == "test_context_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part1, mock_genai_part2] + assert isinstance(result.run_config, RunConfig) # Verify calls assert mock_convert_part.call_count == 2 @@ -197,7 +199,7 @@ class TestConvertA2aRequestToAdkRunArgs: # Act & Assert with pytest.raises(ValueError, match="Request message cannot be None"): - convert_a2a_request_to_adk_run_args(request) + convert_a2a_request_to_agent_run_request(request) def test_convert_a2a_request_empty_parts(self): """Test conversion with empty parts list.""" @@ -212,16 +214,18 @@ class TestConvertA2aRequestToAdkRunArgs: request.call_context = None # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_test_context_123" - assert result["session_id"] == "test_context_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_test_context_123" + assert result.session_id == "test_context_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [] + assert isinstance(result.run_config, RunConfig) # Verify convert_part wasn't called mock_convert_part.assert_not_called() @@ -244,16 +248,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_None" - assert result["session_id"] is None - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_None" + assert result.session_id is None + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) def test_convert_a2a_request_no_auth(self): """Test conversion when no authentication is available.""" @@ -273,16 +279,18 @@ class TestConvertA2aRequestToAdkRunArgs: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "A2A_USER_session_123" - assert result["session_id"] == "session_123" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "A2A_USER_session_123" + assert result.session_id == "session_123" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) class TestIntegration: @@ -312,16 +320,18 @@ class TestIntegration: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None - assert result["user_id"] == "auth_user" # Should use authenticated user - assert result["session_id"] == "mysession" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.user_id == "auth_user" # Should use authenticated user + assert result.session_id == "mysession" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) def test_end_to_end_conversion_with_fallback_user(self): """Test end-to-end conversion with fallback user ID.""" @@ -341,15 +351,17 @@ class TestIntegration: mock_convert_part.return_value = mock_genai_part # Act - result = convert_a2a_request_to_adk_run_args(request, mock_convert_part) + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) # Assert assert result is not None assert ( - result["user_id"] == "A2A_USER_test_session_456" + result.user_id == "A2A_USER_test_session_456" ) # Should fallback to context ID - assert result["session_id"] == "test_session_456" - assert isinstance(result["new_message"], genai_types.Content) - assert result["new_message"].role == "user" - assert result["new_message"].parts == [mock_genai_part] - assert isinstance(result["run_config"], RunConfig) + assert result.session_id == "test_session_456" + assert isinstance(result.new_message, genai_types.Content) + assert result.new_message.role == "user" + assert result.new_message.parts == [mock_genai_part] + assert isinstance(result.run_config, RunConfig) diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 11d6b3a7..4bcc7a91 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -31,10 +31,13 @@ try: from a2a.types import Message from a2a.types import TaskState from a2a.types import TextPart + from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.events.event import Event + from google.adk.runners import RunConfig from google.adk.runners import Runner + from google.genai.types import Content except ImportError as e: if sys.version_info < (3, 10): # Imports are not needed since tests will be skipped due to pytestmark. @@ -58,9 +61,13 @@ class TestA2aAgentExecutor: self.mock_a2a_part_converter = Mock() self.mock_gen_ai_part_converter = Mock() + self.mock_request_converter = Mock() + self.mock_event_converter = Mock() self.mock_config = A2aAgentExecutorConfig( a2a_part_converter=self.mock_a2a_part_converter, gen_ai_part_converter=self.mock_gen_ai_part_converter, + request_converter=self.mock_request_converter, + event_converter=self.mock_event_converter, ) self.executor = A2aAgentExecutor( runner=self.mock_runner, config=self.mock_config @@ -84,71 +91,73 @@ class TestA2aAgentExecutor: async def test_execute_success_new_task(self): """Test successful execution of a new task.""" # Setup - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] - self.mock_runner.run_async = mock_run_async + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) - # Execute - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + mock_invocation_context, + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify working event was enqueued - working_event = self.mock_event_queue.enqueue_event.call_args_list[1][ - 0 - ][0] - assert working_event.status.state == TaskState.working - assert working_event.final == False + # Verify working event was enqueued + working_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_execute_no_message_error(self): @@ -164,73 +173,76 @@ class TestA2aAgentExecutor: self.mock_context.current_task = Mock() self.mock_context.task_id = "existing-task-id" - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) - # Execute - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) - # Verify no submitted event (first call should be working event) - working_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert working_event.status.state == TaskState.working - assert working_event.final == False + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + mock_invocation_context, + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify no submitted event (first call should be working event) + working_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_prepare_session_new_session(self): """Test session preparation when session doesn't exist.""" - run_args = { - "user_id": "test-user", - "session_id": None, - "new_message": Mock(), - "run_config": Mock(), - } + run_args = AgentRunRequest( + user_id="test-user", + session_id=None, + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) # Mock session service self.mock_runner.session_service.get_session = AsyncMock(return_value=None) @@ -247,18 +259,18 @@ class TestA2aAgentExecutor: # Verify session was created assert result == mock_session - assert run_args["session_id"] is not None + assert run_args.session_id is not None self.mock_runner.session_service.create_session.assert_called_once() @pytest.mark.asyncio async def test_prepare_session_existing_session(self): """Test session preparation when session exists.""" - run_args = { - "user_id": "test-user", - "session_id": "existing-session", - "new_message": Mock(), - "run_config": Mock(), - } + run_args = AgentRunRequest( + user_id="test-user", + session_id="existing-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) # Mock session service mock_session = Mock() @@ -397,63 +409,55 @@ class TestA2aAgentExecutor: executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + self.mock_event_converter.return_value = [] - # Execute - await executor.execute(self.mock_context, self.mock_event_queue) + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_execute_with_async_callable_runner(self): @@ -464,63 +468,55 @@ class TestA2aAgentExecutor: executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session - ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) - # Mock agent run with proper async generator - mock_event = Mock(spec=Event) + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) - async def mock_run_async(**kwargs): - async for item in self._create_async_generator([mock_event]): - yield item + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item - self.mock_runner.run_async = mock_run_async + self.mock_runner.run_async = mock_run_async - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [] + self.mock_event_converter.return_value = [] - # Execute - await executor.execute(self.mock_context, self.mock_event_queue) + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) - # Verify task submitted event was enqueued - assert self.mock_event_queue.enqueue_event.call_count >= 3 - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Verify final event was enqueued with proper message field - final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert final_event.final == True - # The TaskResultAggregator is created with default state (working), and since no messages - # are processed, it will publish a status event with the current state - assert hasattr(final_event.status, "message") - assert final_event.status.state == TaskState.working + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), and since no messages + # are processed, it will publish a status event with the current state + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_handle_request_integration(self): @@ -529,83 +525,75 @@ class TestA2aAgentExecutor: self.mock_context.task_id = "test-task-id" # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.working + # Mock the task_status_message property to return None by default + mock_aggregator.task_status_message = None + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify working event was enqueued + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + assert len(working_events) >= 1 - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] + # Verify aggregator processed events + assert mock_aggregator.process_event.call_count == len(mock_events) - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - mock_aggregator.task_state = TaskState.working - # Mock the task_status_message property to return None by default - mock_aggregator.task_status_message = None - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify working event was enqueued - working_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "status") - and call[0][0].status.state == TaskState.working - ] - assert len(working_events) >= 1 - - # Verify aggregator processed events - assert mock_aggregator.process_event.call_count == len(mock_events) - - # Verify final event has message field from aggregator and state is completed when aggregator state is working - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert ( - final_event.status.message == mock_aggregator.task_status_message - ) - # When aggregator state is working but no message, final event should be working - assert final_event.status.state == TaskState.working + # Verify final event has message field from aggregator and state is completed when aggregator state is working + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == mock_aggregator.task_status_message + # When aggregator state is working but no message, final event should be working + assert final_event.status.state == TaskState.working @pytest.mark.asyncio async def test_cancel_with_task_id(self): @@ -637,31 +625,26 @@ class TestA2aAgentExecutor: None # Make sure it goes through submitted event creation ) - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.side_effect = Exception("Test error") + self.mock_request_converter.side_effect = Exception("Test error") - # Execute (should not raise since we catch the exception) - await self.executor.execute(self.mock_context, self.mock_event_queue) + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) - # Verify both submitted and failure events were enqueued - # First call should be submitted event, last should be failure event - assert self.mock_event_queue.enqueue_event.call_count >= 2 + # Verify both submitted and failure events were enqueued + # First call should be submitted event, last should be failure event + assert self.mock_event_queue.enqueue_event.call_count >= 2 - # Check submitted event (first) - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ - 0 - ][0] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Check submitted event (first) + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False - # Check failure event (last) - failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ - 0 - ] - assert failure_event.status.state == TaskState.failed - assert failure_event.final == True + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True @pytest.mark.asyncio async def test_handle_request_with_aggregator_message(self): @@ -680,69 +663,63 @@ class TestA2aAgentExecutor: test_message.parts = [Mock(spec=TextPart)] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.completed + # Mock the task_status_message property to return a test message + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) - - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - mock_aggregator.task_state = TaskState.completed - # Mock the task_status_message property to return a test message - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify final event has message field from aggregator - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message - # When aggregator state is completed (not working), final event should be completed - assert final_event.status.state == TaskState.completed + # Verify final event has message field from aggregator + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == test_message + # When aggregator state is completed (not working), final event should be completed + assert final_event.status.state == TaskState.completed @pytest.mark.asyncio async def test_handle_request_with_non_working_aggregator_state(self): @@ -761,69 +738,63 @@ class TestA2aAgentExecutor: test_message.parts = [Mock(spec=TextPart)] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with failed state - should preserve failed state + mock_aggregator.task_state = TaskState.failed + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) - - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with failed state - should preserve failed state - mock_aggregator.task_state = TaskState.failed - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify final event preserves the non-working state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message - # When aggregator state is failed (not working), final event should keep failed state - assert final_event.status.state == TaskState.failed + # Verify final event preserves the non-working state + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == test_message + # When aggregator state is failed (not working), final event should keep failed state + assert final_event.status.state == TaskState.failed @pytest.mark.asyncio async def test_handle_request_with_working_state_publishes_artifact_and_completed( @@ -846,84 +817,77 @@ class TestA2aAgentExecutor: test_message.parts = [Part(root=TextPart(text="test content"))] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with working state - should publish artifact update and completed status + mock_aggregator.task_state = TaskState.working + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify artifact update event was published + artifact_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "artifact") and call[0][0].last_chunk == True + ] + assert len(artifact_events) == 1 + artifact_event = artifact_events[0] + assert artifact_event.task_id == "test-task-id" + assert artifact_event.context_id == "test-context-id" + # Check that artifact parts correspond to message parts + assert len(artifact_event.artifact.parts) == len(test_message.parts) + assert artifact_event.artifact.parts == test_message.parts - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with working state - should publish artifact update and completed status - mock_aggregator.task_state = TaskState.working - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify artifact update event was published - artifact_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "artifact") - and call[0][0].last_chunk == True - ] - assert len(artifact_events) == 1 - artifact_event = artifact_events[0] - assert artifact_event.task_id == "test-task-id" - assert artifact_event.context_id == "test-context-id" - # Check that artifact parts correspond to message parts - assert len(artifact_event.artifact.parts) == len(test_message.parts) - assert artifact_event.artifact.parts == test_message.parts - - # Verify final status event was published with completed state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.state == TaskState.completed - assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + # Verify final status event was published with completed state + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.state == TaskState.completed + assert final_event.task_id == "test-task-id" + assert final_event.context_id == "test-context-id" @pytest.mark.asyncio async def test_handle_request_with_non_working_state_publishes_status_only( @@ -946,76 +910,69 @@ class TestA2aAgentExecutor: test_message.parts = [Part(root=TextPart(text="test content"))] # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + self.mock_event_converter.return_value = [Mock()] + with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" - ) as mock_convert: - mock_convert.return_value = { - "user_id": "test-user", - "session_id": "test-session", - "new_message": Mock(), - "run_config": Mock(), - } + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with auth_required state - should publish only status event + mock_aggregator.task_state = TaskState.auth_required + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator - # Mock session service - mock_session = Mock() - mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( - return_value=mock_session + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue ) - # Mock invocation context - mock_invocation_context = Mock() - self.mock_runner._new_invocation_context.return_value = ( - mock_invocation_context - ) + # Verify no artifact update event was published + artifact_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "artifact") and call[0][0].last_chunk == True + ] + assert len(artifact_events) == 0 - # Mock agent run with multiple events using proper async generator - mock_events = [Mock(spec=Event), Mock(spec=Event)] - - # Configure run_async to return the async generator when awaited - async def mock_run_async(**kwargs): - async for item in self._create_async_generator(mock_events): - yield item - - self.mock_runner.run_async = mock_run_async - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" - ) as mock_convert_events: - mock_convert_events.return_value = [Mock()] - - with patch( - "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" - ) as mock_aggregator_class: - mock_aggregator = Mock() - # Test with auth_required state - should publish only status event - mock_aggregator.task_state = TaskState.auth_required - mock_aggregator.task_status_message = test_message - mock_aggregator_class.return_value = mock_aggregator - - # Execute - await self.executor._handle_request( - self.mock_context, self.mock_event_queue - ) - - # Verify no artifact update event was published - artifact_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "artifact") - and call[0][0].last_chunk == True - ] - assert len(artifact_events) == 0 - - # Verify final status event was published with the actual state and message - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.state == TaskState.auth_required - assert final_event.status.message == test_message - assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + # Verify final status event was published with the actual state and message + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.state == TaskState.auth_required + assert final_event.status.message == test_message + assert final_event.task_id == "test-task-id" + assert final_event.context_id == "test-context-id"