From 2f55de6ded26c0b55715b78fa081dbe126d968c2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 22:01:50 -0700 Subject: [PATCH] chore: Add a2a task result aggregator PiperOrigin-RevId: 775975982 --- src/google/adk/a2a/executor/__init__.py | 13 + .../a2a/executor/task_result_aggregator.py | 71 ++++ tests/unittests/a2a/executor/__init__.py | 13 + .../executor/test_task_result_aggregator.py | 337 ++++++++++++++++++ 4 files changed, 434 insertions(+) create mode 100644 src/google/adk/a2a/executor/__init__.py create mode 100644 src/google/adk/a2a/executor/task_result_aggregator.py create mode 100644 tests/unittests/a2a/executor/__init__.py create mode 100644 tests/unittests/a2a/executor/test_task_result_aggregator.py diff --git a/src/google/adk/a2a/executor/__init__.py b/src/google/adk/a2a/executor/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/src/google/adk/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/executor/task_result_aggregator.py b/src/google/adk/a2a/executor/task_result_aggregator.py new file mode 100644 index 00000000..609de40e --- /dev/null +++ b/src/google/adk/a2a/executor/task_result_aggregator.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from a2a.server.events import Event +from a2a.types import Message +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent + +from ...utils.feature_decorator import working_in_progress + + +@working_in_progress +class TaskResultAggregator: + """Aggregates the task status updates and provides the final task state.""" + + def __init__(self): + self._task_state = TaskState.working + self._task_status_message = None + + def process_event(self, event: Event): + """Process an event from the agent run and detect signals about the task status. + Priority of task state: + - failed + - auth_required + - input_required + - working + """ + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.failed: + self._task_state = TaskState.failed + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.auth_required + and self._task_state != TaskState.failed + ): + self._task_state = TaskState.auth_required + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.input_required + and self._task_state + not in (TaskState.failed, TaskState.auth_required) + ): + self._task_state = TaskState.input_required + self._task_status_message = event.status.message + # final state is already recorded and make sure the intermediate state is + # always working because other state may terminate the event aggregation + # in a2a request handler + elif self._task_state == TaskState.working: + self._task_status_message = event.status.message + event.status.state = TaskState.working + + @property + def task_state(self) -> TaskState: + return self._task_state + + @property + def task_status_message(self) -> Message | None: + return self._task_status_message diff --git a/tests/unittests/a2a/executor/__init__.py b/tests/unittests/a2a/executor/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py new file mode 100644 index 00000000..b808cf0c --- /dev/null +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -0,0 +1,337 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import Mock + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.types import Message + from a2a.types import Part + from a2a.types import Role + from a2a.types import TaskState + from a2a.types import TaskStatus + from a2a.types import TaskStatusUpdateEvent + from a2a.types import TextPart + from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + TaskState = DummyTypes() + TaskStatus = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + TaskResultAggregator = DummyTypes() + else: + raise e + + +def create_test_message(text: str) -> Message: + """Helper function to create a test Message object.""" + return Message( + messageId="test-msg", + role=Role.agent, + parts=[Part(root=TextPart(text=text))], + ) + + +class TestTaskResultAggregator: + """Test suite for TaskResultAggregator class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.aggregator = TaskResultAggregator() + + def test_initial_state(self): + """Test the initial state of the aggregator.""" + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message is None + + def test_process_failed_event(self): + """Test processing a failed event.""" + status_message = create_test_message("Failed to process") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=status_message), + final=True, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to working + assert event.status.state == TaskState.working + + def test_process_auth_required_event(self): + """Test processing an auth_required event.""" + status_message = create_test_message("Authentication needed") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.auth_required, message=status_message + ), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to working + assert event.status.state == TaskState.working + + def test_process_input_required_event(self): + """Test processing an input_required event.""" + status_message = create_test_message("Input required") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=status_message + ), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.input_required + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to working + assert event.status.state == TaskState.working + + def test_status_message_with_none_message(self): + """Test that status message handles None message properly.""" + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=None), + final=True, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message is None + + def test_priority_order_failed_over_auth(self): + """Test that failed state takes priority over auth_required.""" + # First set auth_required + auth_message = create_test_message("Auth required") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + # Then process failed - should override + failed_message = create_test_message("Failed") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + def test_priority_order_auth_over_input(self): + """Test that auth_required state takes priority over input_required.""" + # First set input_required + input_message = create_test_message("Input needed") + input_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=input_message + ), + final=False, + ) + self.aggregator.process_event(input_event) + assert self.aggregator.task_state == TaskState.input_required + assert self.aggregator.task_status_message == input_message + + # Then process auth_required - should override + auth_message = create_test_message("Auth needed") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + def test_ignore_non_status_update_events(self): + """Test that non-TaskStatusUpdateEvent events are ignored.""" + mock_event = Mock() + + initial_state = self.aggregator.task_state + initial_message = self.aggregator.task_status_message + self.aggregator.process_event(mock_event) + + # State should remain unchanged + assert self.aggregator.task_state == initial_state + assert self.aggregator.task_status_message == initial_message + + def test_working_state_does_not_override_higher_priority(self): + """Test that working state doesn't override higher priority states.""" + # First set failed state + failed_message = create_test_message("Failure message") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + # Then process working - should not override state and should not update message + # because the current task state is not working + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working), + final=False, + ) + self.aggregator.process_event(working_event) + assert self.aggregator.task_state == TaskState.failed + # Working events don't update the status message when task state is not working + assert self.aggregator.task_status_message == failed_message + + def test_status_message_priority_ordering(self): + """Test that status messages follow the same priority ordering as states.""" + # Start with input_required + input_message = create_test_message("Input message") + input_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=input_message + ), + final=False, + ) + self.aggregator.process_event(input_event) + assert self.aggregator.task_status_message == input_message + + # Override with auth_required + auth_message = create_test_message("Auth message") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_status_message == auth_message + + # Override with failed + failed_message = create_test_message("Failed message") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_status_message == failed_message + + # Working should not override failed message because current task state is failed + working_message = create_test_message("Working message") + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + self.aggregator.process_event(working_event) + # State should still be failed, and message should remain the failed message + # because working events only update message when task state is working + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + def test_process_working_event_updates_message(self): + """Test that working state events update the status message.""" + working_message = create_test_message("Working on task") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message == working_message + # Verify the event state was modified to working (should remain working) + assert event.status.state == TaskState.working + + def test_working_event_with_none_message(self): + """Test that working state events handle None message properly.""" + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=None), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message is None + + def test_working_event_updates_message_regardless_of_state(self): + """Test that working events update message only when current task state is working.""" + # First set auth_required state + auth_message = create_test_message("Auth required") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + # Then process working - should not update message because task state is not working + working_message = create_test_message("Working on auth") + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + self.aggregator.process_event(working_event) + assert ( + self.aggregator.task_state == TaskState.auth_required + ) # State unchanged + assert ( + self.aggregator.task_status_message == auth_message + ) # Message unchanged because task state is not working