feat: Add interceptor framework to A2aAgentExecutor

This change introduces an interceptor mechanism allowing custom logic to be executed before agent runs, after each event, and after the agent run completes. New dependencies are added to support these features.

PiperOrigin-RevId: 873952199
This commit is contained in:
Google Team Member
2026-02-23 02:15:18 -08:00
committed by Copybara-Service
parent 7557a92939
commit 87fcd77caa
5 changed files with 348 additions and 21 deletions
@@ -50,7 +50,12 @@ from ..converters.request_converter import AgentRunRequest
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
from ..converters.utils import _get_adk_metadata_key
from ..experimental import a2a_experimental
from .config import ExecuteInterceptor
from .executor_context import ExecutorContext
from .task_result_aggregator import TaskResultAggregator
from .utils import execute_after_agent_interceptors
from .utils import execute_after_event_interceptors
from .utils import execute_before_agent_interceptors
logger = logging.getLogger('google_adk.' + __name__)
@@ -70,6 +75,8 @@ class A2aAgentExecutorConfig(BaseModel):
)
event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events
execute_interceptors: Optional[list[ExecuteInterceptor]] = None
@a2a_experimental
class A2aAgentExecutor(AgentExecutor):
@@ -135,6 +142,10 @@ class A2aAgentExecutor(AgentExecutor):
if not context.message:
raise ValueError('A2A request must have a message')
context = await execute_before_agent_interceptors(
context, self._config.execute_interceptors
)
# for new task, create a task submitted event
if not context.current_task:
await event_queue.enqueue_event(
@@ -202,6 +213,13 @@ class A2aAgentExecutor(AgentExecutor):
run_config=run_request.run_config,
)
self._executor_context = ExecutorContext(
app_name=runner.app_name,
user_id=run_request.user_id,
session_id=run_request.session_id,
runner=runner,
)
# publish the task working event
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
@@ -230,6 +248,15 @@ class A2aAgentExecutor(AgentExecutor):
context.context_id,
self._config.gen_ai_part_converter,
):
a2a_event = await execute_after_event_interceptors(
a2a_event,
self._executor_context,
adk_event,
self._config.execute_interceptors,
)
if a2a_event is None:
continue
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)
@@ -253,31 +280,34 @@ class A2aAgentExecutor(AgentExecutor):
)
)
# public the final status update event
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.completed,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=context.context_id,
final=True,
)
final_event = TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.completed,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=context.context_id,
final=True,
)
else:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=task_result_aggregator.task_state,
timestamp=datetime.now(timezone.utc).isoformat(),
message=task_result_aggregator.task_status_message,
),
context_id=context.context_id,
final=True,
)
final_event = TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=task_result_aggregator.task_state,
timestamp=datetime.now(timezone.utc).isoformat(),
message=task_result_aggregator.task_status_message,
),
context_id=context.context_id,
final=True,
)
final_event = await execute_after_agent_interceptors(
self._executor_context,
final_event,
self._config.execute_interceptors,
)
await event_queue.enqueue_event(final_event)
async def _prepare_session(
self,
context: RequestContext,
+69
View File
@@ -0,0 +1,69 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
from typing import Awaitable
from typing import Callable
from typing import Optional
from typing import Union
from a2a.server.agent_execution.context import RequestContext
from a2a.server.events import Event as A2AEvent
from a2a.types import TaskStatusUpdateEvent
from ...events.event import Event
from ..converters.utils import _get_adk_metadata_key
from .executor_context import ExecutorContext
@dataclasses.dataclass
class ExecuteInterceptor:
"""Interceptor for the A2aAgentExecutor."""
before_agent: Optional[
Callable[[RequestContext], Awaitable[RequestContext]]
] = None
"""Hook executed before the agent starts processing the request.
Allows inspection or modification of the incoming request context.
Must return a valid `RequestContext` to continue execution.
"""
after_event: Optional[
Callable[
[ExecutorContext, A2AEvent, Event],
Awaitable[Union[A2AEvent, None]],
]
] = None
"""Hook executed after an ADK event is converted to an A2A event.
Allows mutating the outgoing event before it is enqueued.
Return `None` to filter out and drop the event entirely,
which also halts any subsequent interceptors in the chain.
"""
after_agent: Optional[
Callable[
[ExecutorContext, TaskStatusUpdateEvent],
Awaitable[TaskStatusUpdateEvent],
]
] = None
"""Hook executed after the agent finishes and the final event is prepared.
Allows inspection or modification of the terminal status event (e.g.,
completed or failed) before it is enqueued. Must return a valid
`TaskStatusUpdateEvent`.
"""
@@ -0,0 +1,49 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.runners import Runner
class ExecutorContext:
"""Context for the executor."""
def __init__(
self,
app_name: str,
user_id: str,
session_id: str,
runner: Runner,
):
self._app_name = app_name
self._user_id = user_id
self._session_id = session_id
self._runner = runner
@property
def app_name(self) -> str:
return self._app_name
@property
def user_id(self) -> str:
return self._user_id
@property
def session_id(self) -> str:
return self._session_id
@property
def runner(self) -> Runner:
return self._runner
+67
View File
@@ -0,0 +1,67 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from a2a.server.agent_execution.context import RequestContext
from a2a.server.events import Event as A2AEvent
from a2a.types import TaskStatusUpdateEvent
from ...events.event import Event
from ..converters.utils import _get_adk_metadata_key
from .config import ExecuteInterceptor
from .executor_context import ExecutorContext
async def execute_before_agent_interceptors(
context: RequestContext,
execute_interceptors: Optional[list[ExecuteInterceptor]],
) -> RequestContext:
if execute_interceptors:
for interceptor in execute_interceptors:
if interceptor.before_agent:
context = await interceptor.before_agent(context)
return context
async def execute_after_event_interceptors(
a2a_event: A2AEvent,
executor_context: ExecutorContext,
adk_event: Event,
execute_interceptors: Optional[list[ExecuteInterceptor]],
) -> Optional[A2AEvent]:
if execute_interceptors:
for interceptor in execute_interceptors:
if interceptor.after_event:
a2a_event = await interceptor.after_event(
executor_context, a2a_event, adk_event
)
if a2a_event is None:
return None
return a2a_event
async def execute_after_agent_interceptors(
executor_context: ExecutorContext,
final_event: TaskStatusUpdateEvent,
execute_interceptors: Optional[list[ExecuteInterceptor]],
) -> TaskStatusUpdateEvent:
if execute_interceptors:
for interceptor in reversed(execute_interceptors):
if interceptor.after_agent:
final_event = await interceptor.after_agent(
executor_context, final_event
)
return final_event
@@ -17,13 +17,17 @@ from unittest.mock import Mock
from unittest.mock import patch
from a2a.server.agent_execution.context import RequestContext
from a2a.server.events import Event as A2AEvent
from a2a.server.events.event_queue import EventQueue
from a2a.types import Message
from a2a.types import Part
from a2a.types import Role
from a2a.types import TaskState
from a2a.types import TextPart
from google.adk.a2a.converters.request_converter import AgentRunRequest
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig
from google.adk.a2a.executor.config import ExecuteInterceptor
from google.adk.events.event import Event
from google.adk.runners import RunConfig
from google.adk.runners import Runner
@@ -959,3 +963,111 @@ class TestA2aAgentExecutor:
assert final_event.status.message == test_message
assert final_event.task_id == "test-task-id"
assert final_event.context_id == "test-context-id"
@pytest.mark.asyncio
async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event(
self,
):
"""Test that after_event interceptors receive correct arguments and can modify the event."""
# Create distinct mock objects for ADK event and A2A event
adk_event = Mock(spec=Event, name="ADK_EVENT")
a2a_event = Mock(spec=A2AEvent, name="A2A_EVENT")
modified_a2a_event = Mock(spec=A2AEvent, name="MODIFIED_A2A_EVENT")
# Mocks for conversion
self.mock_event_converter.return_value = [a2a_event]
self.mock_request_converter.return_value = AgentRunRequest(
user_id="test-user",
session_id="test-session",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
# Setup Interceptor
mock_interceptor = Mock(spec=ExecuteInterceptor)
# after_event should return the modified event
async def side_effect_after_event(context, event, original_event):
return modified_a2a_event
mock_interceptor.after_event = AsyncMock(
side_effect=side_effect_after_event
)
mock_interceptor.before_agent = None
mock_interceptor.after_agent = None
# Update config with interceptor
self.mock_config.execute_interceptors = [mock_interceptor]
# Re-initialize executor with updated config - but we can just update
# the config in place if it's mutable
# The executor uses self._config which is this mock_config basically.
# self.executor was initialized in setup_method with self.mock_config.
# However, A2aAgentExecutor constructor does: self._config = config or ...
# So updating self.mock_config properties should work as
# it is the same object reference.
# Mock context
self.mock_context.task_id = "task-1"
self.mock_context.context_id = "ctx-1"
# Ensure current_task is set so we skip the initial
# submitted event creation logic
# which might complicate this specific test if we don't care about it.
self.mock_context.current_task = Mock()
# Mock runner.run_async to yield our ADK event
async def mock_run_async(**kwargs):
async for item in self._create_async_generator([adk_event]):
yield item
self.mock_runner.run_async = mock_run_async
# Configure session service
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
self.mock_runner._new_invocation_context.return_value = Mock()
# We patch TaskResultAggregator just to avoid other errors and simplfy
with patch(
"google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator"
) as mock_agg_class:
mock_agg = Mock()
mock_agg.task_status_message = None
mock_agg.task_state = TaskState.working
mock_agg_class.return_value = mock_agg
await self.executor.execute(self.mock_context, self.mock_event_queue)
# Verify aggregator processed the MODIFIED event
mock_agg.process_event.assert_called_with(modified_a2a_event)
# Verification of arguments passed to interceptor
assert mock_interceptor.after_event.called
call_args = mock_interceptor.after_event.call_args
# call_args.args should be (executor_context, a2a_event, adk_event)
passed_a2a_event = call_args.args[1]
passed_adk_event = call_args.args[2]
# These assertions verify the bug fix
assert (
passed_a2a_event is a2a_event
), f"Expected A2A event to be passed as 2nd arg, but got {passed_a2a_event}"
assert (
passed_adk_event is adk_event
), f"Expected ADK event to be passed as 3rd arg, but got {passed_adk_event}"
# Verify that the modified event was enqueued
# We check if enqueue_event was called with modified_a2a_event
# Note: enqueue_event is called multiple times.
enqueued_events = [
call[0][0]
for call in self.mock_event_queue.enqueue_event.call_args_list
]
assert (
modified_a2a_event in enqueued_events
), "The modified event should have been enqueued"