From 87fcd77caa9672f219c12e5a0e2ff65cbbaaf6f3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 23 Feb 2026 02:15:18 -0800 Subject: [PATCH] feat: Add interceptor framework to A2aAgentExecutor This change introduces an interceptor mechanism allowing custom logic to be executed before agent runs, after each event, and after the agent run completes. New dependencies are added to support these features. PiperOrigin-RevId: 873952199 --- .../adk/a2a/executor/a2a_agent_executor.py | 72 +++++++---- src/google/adk/a2a/executor/config.py | 69 +++++++++++ .../adk/a2a/executor/executor_context.py | 49 ++++++++ src/google/adk/a2a/executor/utils.py | 67 +++++++++++ .../a2a/executor/test_a2a_agent_executor.py | 112 ++++++++++++++++++ 5 files changed, 348 insertions(+), 21 deletions(-) create mode 100644 src/google/adk/a2a/executor/config.py create mode 100644 src/google/adk/a2a/executor/executor_context.py create mode 100644 src/google/adk/a2a/executor/utils.py diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index cca728db..956b1233 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -50,7 +50,12 @@ from ..converters.request_converter import AgentRunRequest from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' + __name__) @@ -70,6 +75,8 @@ class A2aAgentExecutorConfig(BaseModel): ) event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events + execute_interceptors: Optional[list[ExecuteInterceptor]] = None + @a2a_experimental class A2aAgentExecutor(AgentExecutor): @@ -135,6 +142,10 @@ class A2aAgentExecutor(AgentExecutor): if not context.message: raise ValueError('A2A request must have a message') + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + # for new task, create a task submitted event if not context.current_task: await event_queue.enqueue_event( @@ -202,6 +213,13 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) + self._executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + # publish the task working event await event_queue.enqueue_event( TaskStatusUpdateEvent( @@ -230,6 +248,15 @@ class A2aAgentExecutor(AgentExecutor): context.context_id, self._config.gen_ai_part_converter, ): + a2a_event = await execute_after_event_interceptors( + a2a_event, + self._executor_context, + adk_event, + self._config.execute_interceptors, + ) + if a2a_event is None: + continue + task_result_aggregator.process_event(a2a_event) await event_queue.enqueue_event(a2a_event) @@ -253,31 +280,34 @@ class A2aAgentExecutor(AgentExecutor): ) ) # public the final status update event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, ) else: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.now(timezone.utc).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=context.context_id, + final=True, ) + final_event = await execute_after_agent_interceptors( + self._executor_context, + final_event, + self._config.execute_interceptors, + ) + await event_queue.enqueue_event(final_event) + async def _prepare_session( self, context: RequestContext, diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py new file mode 100644 index 00000000..79e88546 --- /dev/null +++ b/src/google/adk/a2a/executor/config.py @@ -0,0 +1,69 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .executor_context import ExecutorContext + + +@dataclasses.dataclass +class ExecuteInterceptor: + """Interceptor for the A2aAgentExecutor.""" + + before_agent: Optional[ + Callable[[RequestContext], Awaitable[RequestContext]] + ] = None + """Hook executed before the agent starts processing the request. + + Allows inspection or modification of the incoming request context. + Must return a valid `RequestContext` to continue execution. + """ + + after_event: Optional[ + Callable[ + [ExecutorContext, A2AEvent, Event], + Awaitable[Union[A2AEvent, None]], + ] + ] = None + """Hook executed after an ADK event is converted to an A2A event. + + Allows mutating the outgoing event before it is enqueued. + Return `None` to filter out and drop the event entirely, + which also halts any subsequent interceptors in the chain. + """ + + after_agent: Optional[ + Callable[ + [ExecutorContext, TaskStatusUpdateEvent], + Awaitable[TaskStatusUpdateEvent], + ] + ] = None + """Hook executed after the agent finishes and the final event is prepared. + + Allows inspection or modification of the terminal status event (e.g., + completed or failed) before it is enqueued. Must return a valid + `TaskStatusUpdateEvent`. + """ diff --git a/src/google/adk/a2a/executor/executor_context.py b/src/google/adk/a2a/executor/executor_context.py new file mode 100644 index 00000000..313afee6 --- /dev/null +++ b/src/google/adk/a2a/executor/executor_context.py @@ -0,0 +1,49 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.runners import Runner + + +class ExecutorContext: + """Context for the executor.""" + + def __init__( + self, + app_name: str, + user_id: str, + session_id: str, + runner: Runner, + ): + self._app_name = app_name + self._user_id = user_id + self._session_id = session_id + self._runner = runner + + @property + def app_name(self) -> str: + return self._app_name + + @property + def user_id(self) -> str: + return self._user_id + + @property + def session_id(self) -> str: + return self._session_id + + @property + def runner(self) -> Runner: + return self._runner diff --git a/src/google/adk/a2a/executor/utils.py b/src/google/adk/a2a/executor/utils.py new file mode 100644 index 00000000..d01066ea --- /dev/null +++ b/src/google/adk/a2a/executor/utils.py @@ -0,0 +1,67 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Optional + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext + + +async def execute_before_agent_interceptors( + context: RequestContext, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> RequestContext: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.before_agent: + context = await interceptor.before_agent(context) + return context + + +async def execute_after_event_interceptors( + a2a_event: A2AEvent, + executor_context: ExecutorContext, + adk_event: Event, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> Optional[A2AEvent]: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.after_event: + a2a_event = await interceptor.after_event( + executor_context, a2a_event, adk_event + ) + if a2a_event is None: + return None + return a2a_event + + +async def execute_after_agent_interceptors( + executor_context: ExecutorContext, + final_event: TaskStatusUpdateEvent, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> TaskStatusUpdateEvent: + if execute_interceptors: + for interceptor in reversed(execute_interceptors): + if interceptor.after_agent: + final_event = await interceptor.after_agent( + executor_context, final_event + ) + return final_event diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 40736d95..787b260f 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -17,13 +17,17 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent from a2a.server.events.event_queue import EventQueue from a2a.types import Message +from a2a.types import Part +from a2a.types import Role from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor from google.adk.events.event import Event from google.adk.runners import RunConfig from google.adk.runners import Runner @@ -959,3 +963,111 @@ class TestA2aAgentExecutor: assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" + + @pytest.mark.asyncio + async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event( + self, + ): + """Test that after_event interceptors receive correct arguments and can modify the event.""" + # Create distinct mock objects for ADK event and A2A event + adk_event = Mock(spec=Event, name="ADK_EVENT") + a2a_event = Mock(spec=A2AEvent, name="A2A_EVENT") + modified_a2a_event = Mock(spec=A2AEvent, name="MODIFIED_A2A_EVENT") + + # Mocks for conversion + self.mock_event_converter.return_value = [a2a_event] + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Setup Interceptor + mock_interceptor = Mock(spec=ExecuteInterceptor) + + # after_event should return the modified event + async def side_effect_after_event(context, event, original_event): + return modified_a2a_event + + mock_interceptor.after_event = AsyncMock( + side_effect=side_effect_after_event + ) + mock_interceptor.before_agent = None + mock_interceptor.after_agent = None + + # Update config with interceptor + self.mock_config.execute_interceptors = [mock_interceptor] + # Re-initialize executor with updated config - but we can just update + # the config in place if it's mutable + # The executor uses self._config which is this mock_config basically. + # self.executor was initialized in setup_method with self.mock_config. + + # However, A2aAgentExecutor constructor does: self._config = config or ... + # So updating self.mock_config properties should work as + # it is the same object reference. + + # Mock context + self.mock_context.task_id = "task-1" + self.mock_context.context_id = "ctx-1" + # Ensure current_task is set so we skip the initial + # submitted event creation logic + # which might complicate this specific test if we don't care about it. + self.mock_context.current_task = Mock() + + # Mock runner.run_async to yield our ADK event + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([adk_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Configure session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + self.mock_runner._new_invocation_context.return_value = Mock() + + # We patch TaskResultAggregator just to avoid other errors and simplfy + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_agg_class: + mock_agg = Mock() + mock_agg.task_status_message = None + mock_agg.task_state = TaskState.working + mock_agg_class.return_value = mock_agg + + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify aggregator processed the MODIFIED event + mock_agg.process_event.assert_called_with(modified_a2a_event) + + # Verification of arguments passed to interceptor + assert mock_interceptor.after_event.called + call_args = mock_interceptor.after_event.call_args + # call_args.args should be (executor_context, a2a_event, adk_event) + + passed_a2a_event = call_args.args[1] + passed_adk_event = call_args.args[2] + + # These assertions verify the bug fix + assert ( + passed_a2a_event is a2a_event + ), f"Expected A2A event to be passed as 2nd arg, but got {passed_a2a_event}" + assert ( + passed_adk_event is adk_event + ), f"Expected ADK event to be passed as 3rd arg, but got {passed_adk_event}" + + # Verify that the modified event was enqueued + # We check if enqueue_event was called with modified_a2a_event + # Note: enqueue_event is called multiple times. + + enqueued_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + ] + assert ( + modified_a2a_event in enqueued_events + ), "The modified event should have been enqueued"