chore: Add a2a task result aggregator

PiperOrigin-RevId: 775975982
This commit is contained in:
Xiang (Sean) Zhou
2025-06-25 22:01:50 -07:00
committed by Copybara-Service
parent 77b869f5e3
commit 2f55de6ded
4 changed files with 434 additions and 0 deletions
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,71 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from a2a.server.events import Event
from a2a.types import Message
from a2a.types import TaskState
from a2a.types import TaskStatusUpdateEvent
from ...utils.feature_decorator import working_in_progress
@working_in_progress
class TaskResultAggregator:
"""Aggregates the task status updates and provides the final task state."""
def __init__(self):
self._task_state = TaskState.working
self._task_status_message = None
def process_event(self, event: Event):
"""Process an event from the agent run and detect signals about the task status.
Priority of task state:
- failed
- auth_required
- input_required
- working
"""
if isinstance(event, TaskStatusUpdateEvent):
if event.status.state == TaskState.failed:
self._task_state = TaskState.failed
self._task_status_message = event.status.message
elif (
event.status.state == TaskState.auth_required
and self._task_state != TaskState.failed
):
self._task_state = TaskState.auth_required
self._task_status_message = event.status.message
elif (
event.status.state == TaskState.input_required
and self._task_state
not in (TaskState.failed, TaskState.auth_required)
):
self._task_state = TaskState.input_required
self._task_status_message = event.status.message
# final state is already recorded and make sure the intermediate state is
# always working because other state may terminate the event aggregation
# in a2a request handler
elif self._task_state == TaskState.working:
self._task_status_message = event.status.message
event.status.state = TaskState.working
@property
def task_state(self) -> TaskState:
return self._task_state
@property
def task_status_message(self) -> Message | None:
return self._task_status_message
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,337 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import Mock
import pytest
# Skip all tests in this module if Python version is less than 3.10
pytestmark = pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
# Import dependencies with version checking
try:
from a2a.types import Message
from a2a.types import Part
from a2a.types import Role
from a2a.types import TaskState
from a2a.types import TaskStatus
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator
except ImportError as e:
if sys.version_info < (3, 10):
# Create dummy classes to prevent NameError during test collection
# Tests will be skipped anyway due to pytestmark
class DummyTypes:
pass
TaskState = DummyTypes()
TaskStatus = DummyTypes()
TaskStatusUpdateEvent = DummyTypes()
TaskResultAggregator = DummyTypes()
else:
raise e
def create_test_message(text: str) -> Message:
"""Helper function to create a test Message object."""
return Message(
messageId="test-msg",
role=Role.agent,
parts=[Part(root=TextPart(text=text))],
)
class TestTaskResultAggregator:
"""Test suite for TaskResultAggregator class."""
def setup_method(self):
"""Set up test fixtures."""
self.aggregator = TaskResultAggregator()
def test_initial_state(self):
"""Test the initial state of the aggregator."""
assert self.aggregator.task_state == TaskState.working
assert self.aggregator.task_status_message is None
def test_process_failed_event(self):
"""Test processing a failed event."""
status_message = create_test_message("Failed to process")
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.failed, message=status_message),
final=True,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.failed
assert self.aggregator.task_status_message == status_message
# Verify the event state was modified to working
assert event.status.state == TaskState.working
def test_process_auth_required_event(self):
"""Test processing an auth_required event."""
status_message = create_test_message("Authentication needed")
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(
state=TaskState.auth_required, message=status_message
),
final=False,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.auth_required
assert self.aggregator.task_status_message == status_message
# Verify the event state was modified to working
assert event.status.state == TaskState.working
def test_process_input_required_event(self):
"""Test processing an input_required event."""
status_message = create_test_message("Input required")
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(
state=TaskState.input_required, message=status_message
),
final=False,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.input_required
assert self.aggregator.task_status_message == status_message
# Verify the event state was modified to working
assert event.status.state == TaskState.working
def test_status_message_with_none_message(self):
"""Test that status message handles None message properly."""
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.failed, message=None),
final=True,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.failed
assert self.aggregator.task_status_message is None
def test_priority_order_failed_over_auth(self):
"""Test that failed state takes priority over auth_required."""
# First set auth_required
auth_message = create_test_message("Auth required")
auth_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
final=False,
)
self.aggregator.process_event(auth_event)
assert self.aggregator.task_state == TaskState.auth_required
assert self.aggregator.task_status_message == auth_message
# Then process failed - should override
failed_message = create_test_message("Failed")
failed_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.failed, message=failed_message),
final=True,
)
self.aggregator.process_event(failed_event)
assert self.aggregator.task_state == TaskState.failed
assert self.aggregator.task_status_message == failed_message
def test_priority_order_auth_over_input(self):
"""Test that auth_required state takes priority over input_required."""
# First set input_required
input_message = create_test_message("Input needed")
input_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(
state=TaskState.input_required, message=input_message
),
final=False,
)
self.aggregator.process_event(input_event)
assert self.aggregator.task_state == TaskState.input_required
assert self.aggregator.task_status_message == input_message
# Then process auth_required - should override
auth_message = create_test_message("Auth needed")
auth_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
final=False,
)
self.aggregator.process_event(auth_event)
assert self.aggregator.task_state == TaskState.auth_required
assert self.aggregator.task_status_message == auth_message
def test_ignore_non_status_update_events(self):
"""Test that non-TaskStatusUpdateEvent events are ignored."""
mock_event = Mock()
initial_state = self.aggregator.task_state
initial_message = self.aggregator.task_status_message
self.aggregator.process_event(mock_event)
# State should remain unchanged
assert self.aggregator.task_state == initial_state
assert self.aggregator.task_status_message == initial_message
def test_working_state_does_not_override_higher_priority(self):
"""Test that working state doesn't override higher priority states."""
# First set failed state
failed_message = create_test_message("Failure message")
failed_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.failed, message=failed_message),
final=True,
)
self.aggregator.process_event(failed_event)
assert self.aggregator.task_state == TaskState.failed
assert self.aggregator.task_status_message == failed_message
# Then process working - should not override state and should not update message
# because the current task state is not working
working_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.working),
final=False,
)
self.aggregator.process_event(working_event)
assert self.aggregator.task_state == TaskState.failed
# Working events don't update the status message when task state is not working
assert self.aggregator.task_status_message == failed_message
def test_status_message_priority_ordering(self):
"""Test that status messages follow the same priority ordering as states."""
# Start with input_required
input_message = create_test_message("Input message")
input_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(
state=TaskState.input_required, message=input_message
),
final=False,
)
self.aggregator.process_event(input_event)
assert self.aggregator.task_status_message == input_message
# Override with auth_required
auth_message = create_test_message("Auth message")
auth_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
final=False,
)
self.aggregator.process_event(auth_event)
assert self.aggregator.task_status_message == auth_message
# Override with failed
failed_message = create_test_message("Failed message")
failed_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.failed, message=failed_message),
final=True,
)
self.aggregator.process_event(failed_event)
assert self.aggregator.task_status_message == failed_message
# Working should not override failed message because current task state is failed
working_message = create_test_message("Working message")
working_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.working, message=working_message),
final=False,
)
self.aggregator.process_event(working_event)
# State should still be failed, and message should remain the failed message
# because working events only update message when task state is working
assert self.aggregator.task_state == TaskState.failed
assert self.aggregator.task_status_message == failed_message
def test_process_working_event_updates_message(self):
"""Test that working state events update the status message."""
working_message = create_test_message("Working on task")
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.working, message=working_message),
final=False,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.working
assert self.aggregator.task_status_message == working_message
# Verify the event state was modified to working (should remain working)
assert event.status.state == TaskState.working
def test_working_event_with_none_message(self):
"""Test that working state events handle None message properly."""
event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.working, message=None),
final=False,
)
self.aggregator.process_event(event)
assert self.aggregator.task_state == TaskState.working
assert self.aggregator.task_status_message is None
def test_working_event_updates_message_regardless_of_state(self):
"""Test that working events update message only when current task state is working."""
# First set auth_required state
auth_message = create_test_message("Auth required")
auth_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
final=False,
)
self.aggregator.process_event(auth_event)
assert self.aggregator.task_state == TaskState.auth_required
assert self.aggregator.task_status_message == auth_message
# Then process working - should not update message because task state is not working
working_message = create_test_message("Working on auth")
working_event = TaskStatusUpdateEvent(
taskId="test-task",
contextId="test-context",
status=TaskStatus(state=TaskState.working, message=working_message),
final=False,
)
self.aggregator.process_event(working_event)
assert (
self.aggregator.task_state == TaskState.auth_required
) # State unchanged
assert (
self.aggregator.task_status_message == auth_message
) # Message unchanged because task state is not working