You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Add a2a task result aggregator
PiperOrigin-RevId: 775975982
This commit is contained in:
committed by
Copybara-Service
parent
77b869f5e3
commit
2f55de6ded
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2a.server.events import Event
|
||||
from a2a.types import Message
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
|
||||
from ...utils.feature_decorator import working_in_progress
|
||||
|
||||
|
||||
@working_in_progress
|
||||
class TaskResultAggregator:
|
||||
"""Aggregates the task status updates and provides the final task state."""
|
||||
|
||||
def __init__(self):
|
||||
self._task_state = TaskState.working
|
||||
self._task_status_message = None
|
||||
|
||||
def process_event(self, event: Event):
|
||||
"""Process an event from the agent run and detect signals about the task status.
|
||||
Priority of task state:
|
||||
- failed
|
||||
- auth_required
|
||||
- input_required
|
||||
- working
|
||||
"""
|
||||
if isinstance(event, TaskStatusUpdateEvent):
|
||||
if event.status.state == TaskState.failed:
|
||||
self._task_state = TaskState.failed
|
||||
self._task_status_message = event.status.message
|
||||
elif (
|
||||
event.status.state == TaskState.auth_required
|
||||
and self._task_state != TaskState.failed
|
||||
):
|
||||
self._task_state = TaskState.auth_required
|
||||
self._task_status_message = event.status.message
|
||||
elif (
|
||||
event.status.state == TaskState.input_required
|
||||
and self._task_state
|
||||
not in (TaskState.failed, TaskState.auth_required)
|
||||
):
|
||||
self._task_state = TaskState.input_required
|
||||
self._task_status_message = event.status.message
|
||||
# final state is already recorded and make sure the intermediate state is
|
||||
# always working because other state may terminate the event aggregation
|
||||
# in a2a request handler
|
||||
elif self._task_state == TaskState.working:
|
||||
self._task_status_message = event.status.message
|
||||
event.status.state = TaskState.working
|
||||
|
||||
@property
|
||||
def task_state(self) -> TaskState:
|
||||
return self._task_state
|
||||
|
||||
@property
|
||||
def task_status_message(self) -> Message | None:
|
||||
return self._task_status_message
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,337 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip all tests in this module if Python version is less than 3.10
|
||||
pytestmark = pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||
)
|
||||
|
||||
# Import dependencies with version checking
|
||||
try:
|
||||
from a2a.types import Message
|
||||
from a2a.types import Part
|
||||
from a2a.types import Role
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator
|
||||
except ImportError as e:
|
||||
if sys.version_info < (3, 10):
|
||||
# Create dummy classes to prevent NameError during test collection
|
||||
# Tests will be skipped anyway due to pytestmark
|
||||
class DummyTypes:
|
||||
pass
|
||||
|
||||
TaskState = DummyTypes()
|
||||
TaskStatus = DummyTypes()
|
||||
TaskStatusUpdateEvent = DummyTypes()
|
||||
TaskResultAggregator = DummyTypes()
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def create_test_message(text: str) -> Message:
|
||||
"""Helper function to create a test Message object."""
|
||||
return Message(
|
||||
messageId="test-msg",
|
||||
role=Role.agent,
|
||||
parts=[Part(root=TextPart(text=text))],
|
||||
)
|
||||
|
||||
|
||||
class TestTaskResultAggregator:
|
||||
"""Test suite for TaskResultAggregator class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.aggregator = TaskResultAggregator()
|
||||
|
||||
def test_initial_state(self):
|
||||
"""Test the initial state of the aggregator."""
|
||||
assert self.aggregator.task_state == TaskState.working
|
||||
assert self.aggregator.task_status_message is None
|
||||
|
||||
def test_process_failed_event(self):
|
||||
"""Test processing a failed event."""
|
||||
status_message = create_test_message("Failed to process")
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.failed, message=status_message),
|
||||
final=True,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
assert self.aggregator.task_status_message == status_message
|
||||
# Verify the event state was modified to working
|
||||
assert event.status.state == TaskState.working
|
||||
|
||||
def test_process_auth_required_event(self):
|
||||
"""Test processing an auth_required event."""
|
||||
status_message = create_test_message("Authentication needed")
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(
|
||||
state=TaskState.auth_required, message=status_message
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.auth_required
|
||||
assert self.aggregator.task_status_message == status_message
|
||||
# Verify the event state was modified to working
|
||||
assert event.status.state == TaskState.working
|
||||
|
||||
def test_process_input_required_event(self):
|
||||
"""Test processing an input_required event."""
|
||||
status_message = create_test_message("Input required")
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(
|
||||
state=TaskState.input_required, message=status_message
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.input_required
|
||||
assert self.aggregator.task_status_message == status_message
|
||||
# Verify the event state was modified to working
|
||||
assert event.status.state == TaskState.working
|
||||
|
||||
def test_status_message_with_none_message(self):
|
||||
"""Test that status message handles None message properly."""
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.failed, message=None),
|
||||
final=True,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
assert self.aggregator.task_status_message is None
|
||||
|
||||
def test_priority_order_failed_over_auth(self):
|
||||
"""Test that failed state takes priority over auth_required."""
|
||||
# First set auth_required
|
||||
auth_message = create_test_message("Auth required")
|
||||
auth_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(auth_event)
|
||||
assert self.aggregator.task_state == TaskState.auth_required
|
||||
assert self.aggregator.task_status_message == auth_message
|
||||
|
||||
# Then process failed - should override
|
||||
failed_message = create_test_message("Failed")
|
||||
failed_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.failed, message=failed_message),
|
||||
final=True,
|
||||
)
|
||||
self.aggregator.process_event(failed_event)
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
assert self.aggregator.task_status_message == failed_message
|
||||
|
||||
def test_priority_order_auth_over_input(self):
|
||||
"""Test that auth_required state takes priority over input_required."""
|
||||
# First set input_required
|
||||
input_message = create_test_message("Input needed")
|
||||
input_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(
|
||||
state=TaskState.input_required, message=input_message
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(input_event)
|
||||
assert self.aggregator.task_state == TaskState.input_required
|
||||
assert self.aggregator.task_status_message == input_message
|
||||
|
||||
# Then process auth_required - should override
|
||||
auth_message = create_test_message("Auth needed")
|
||||
auth_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(auth_event)
|
||||
assert self.aggregator.task_state == TaskState.auth_required
|
||||
assert self.aggregator.task_status_message == auth_message
|
||||
|
||||
def test_ignore_non_status_update_events(self):
|
||||
"""Test that non-TaskStatusUpdateEvent events are ignored."""
|
||||
mock_event = Mock()
|
||||
|
||||
initial_state = self.aggregator.task_state
|
||||
initial_message = self.aggregator.task_status_message
|
||||
self.aggregator.process_event(mock_event)
|
||||
|
||||
# State should remain unchanged
|
||||
assert self.aggregator.task_state == initial_state
|
||||
assert self.aggregator.task_status_message == initial_message
|
||||
|
||||
def test_working_state_does_not_override_higher_priority(self):
|
||||
"""Test that working state doesn't override higher priority states."""
|
||||
# First set failed state
|
||||
failed_message = create_test_message("Failure message")
|
||||
failed_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.failed, message=failed_message),
|
||||
final=True,
|
||||
)
|
||||
self.aggregator.process_event(failed_event)
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
assert self.aggregator.task_status_message == failed_message
|
||||
|
||||
# Then process working - should not override state and should not update message
|
||||
# because the current task state is not working
|
||||
working_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(working_event)
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
# Working events don't update the status message when task state is not working
|
||||
assert self.aggregator.task_status_message == failed_message
|
||||
|
||||
def test_status_message_priority_ordering(self):
|
||||
"""Test that status messages follow the same priority ordering as states."""
|
||||
# Start with input_required
|
||||
input_message = create_test_message("Input message")
|
||||
input_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(
|
||||
state=TaskState.input_required, message=input_message
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(input_event)
|
||||
assert self.aggregator.task_status_message == input_message
|
||||
|
||||
# Override with auth_required
|
||||
auth_message = create_test_message("Auth message")
|
||||
auth_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(auth_event)
|
||||
assert self.aggregator.task_status_message == auth_message
|
||||
|
||||
# Override with failed
|
||||
failed_message = create_test_message("Failed message")
|
||||
failed_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.failed, message=failed_message),
|
||||
final=True,
|
||||
)
|
||||
self.aggregator.process_event(failed_event)
|
||||
assert self.aggregator.task_status_message == failed_message
|
||||
|
||||
# Working should not override failed message because current task state is failed
|
||||
working_message = create_test_message("Working message")
|
||||
working_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.working, message=working_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(working_event)
|
||||
# State should still be failed, and message should remain the failed message
|
||||
# because working events only update message when task state is working
|
||||
assert self.aggregator.task_state == TaskState.failed
|
||||
assert self.aggregator.task_status_message == failed_message
|
||||
|
||||
def test_process_working_event_updates_message(self):
|
||||
"""Test that working state events update the status message."""
|
||||
working_message = create_test_message("Working on task")
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.working, message=working_message),
|
||||
final=False,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.working
|
||||
assert self.aggregator.task_status_message == working_message
|
||||
# Verify the event state was modified to working (should remain working)
|
||||
assert event.status.state == TaskState.working
|
||||
|
||||
def test_working_event_with_none_message(self):
|
||||
"""Test that working state events handle None message properly."""
|
||||
event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.working, message=None),
|
||||
final=False,
|
||||
)
|
||||
|
||||
self.aggregator.process_event(event)
|
||||
assert self.aggregator.task_state == TaskState.working
|
||||
assert self.aggregator.task_status_message is None
|
||||
|
||||
def test_working_event_updates_message_regardless_of_state(self):
|
||||
"""Test that working events update message only when current task state is working."""
|
||||
# First set auth_required state
|
||||
auth_message = create_test_message("Auth required")
|
||||
auth_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.auth_required, message=auth_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(auth_event)
|
||||
assert self.aggregator.task_state == TaskState.auth_required
|
||||
assert self.aggregator.task_status_message == auth_message
|
||||
|
||||
# Then process working - should not update message because task state is not working
|
||||
working_message = create_test_message("Working on auth")
|
||||
working_event = TaskStatusUpdateEvent(
|
||||
taskId="test-task",
|
||||
contextId="test-context",
|
||||
status=TaskStatus(state=TaskState.working, message=working_message),
|
||||
final=False,
|
||||
)
|
||||
self.aggregator.process_event(working_event)
|
||||
assert (
|
||||
self.aggregator.task_state == TaskState.auth_required
|
||||
) # State unchanged
|
||||
assert (
|
||||
self.aggregator.task_status_message == auth_message
|
||||
) # Message unchanged because task state is not working
|
||||
Reference in New Issue
Block a user