diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index cca728db..956b1233 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -50,7 +50,12 @@ from ..converters.request_converter import AgentRunRequest from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' + __name__) @@ -70,6 +75,8 @@ class A2aAgentExecutorConfig(BaseModel): ) event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events + execute_interceptors: Optional[list[ExecuteInterceptor]] = None + @a2a_experimental class A2aAgentExecutor(AgentExecutor): @@ -135,6 +142,10 @@ class A2aAgentExecutor(AgentExecutor): if not context.message: raise ValueError('A2A request must have a message') + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + # for new task, create a task submitted event if not context.current_task: await event_queue.enqueue_event( @@ -202,6 +213,13 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) + self._executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + # publish the task working event await event_queue.enqueue_event( TaskStatusUpdateEvent( @@ -230,6 +248,15 @@ class A2aAgentExecutor(AgentExecutor): context.context_id, self._config.gen_ai_part_converter, ): + a2a_event = await execute_after_event_interceptors( + a2a_event, + self._executor_context, + adk_event, + self._config.execute_interceptors, + ) + if a2a_event is None: + continue + task_result_aggregator.process_event(a2a_event) await event_queue.enqueue_event(a2a_event) @@ -253,31 +280,34 @@ class A2aAgentExecutor(AgentExecutor): ) ) # public the final status update event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, ) else: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.now(timezone.utc).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=context.context_id, + final=True, ) + final_event = await execute_after_agent_interceptors( + self._executor_context, + final_event, + self._config.execute_interceptors, + ) + await event_queue.enqueue_event(final_event) + async def _prepare_session( self, context: RequestContext, diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py new file mode 100644 index 00000000..79e88546 --- /dev/null +++ b/src/google/adk/a2a/executor/config.py @@ -0,0 +1,69 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .executor_context import ExecutorContext + + +@dataclasses.dataclass +class ExecuteInterceptor: + """Interceptor for the A2aAgentExecutor.""" + + before_agent: Optional[ + Callable[[RequestContext], Awaitable[RequestContext]] + ] = None + """Hook executed before the agent starts processing the request. + + Allows inspection or modification of the incoming request context. + Must return a valid `RequestContext` to continue execution. + """ + + after_event: Optional[ + Callable[ + [ExecutorContext, A2AEvent, Event], + Awaitable[Union[A2AEvent, None]], + ] + ] = None + """Hook executed after an ADK event is converted to an A2A event. + + Allows mutating the outgoing event before it is enqueued. + Return `None` to filter out and drop the event entirely, + which also halts any subsequent interceptors in the chain. + """ + + after_agent: Optional[ + Callable[ + [ExecutorContext, TaskStatusUpdateEvent], + Awaitable[TaskStatusUpdateEvent], + ] + ] = None + """Hook executed after the agent finishes and the final event is prepared. + + Allows inspection or modification of the terminal status event (e.g., + completed or failed) before it is enqueued. Must return a valid + `TaskStatusUpdateEvent`. + """ diff --git a/src/google/adk/a2a/executor/executor_context.py b/src/google/adk/a2a/executor/executor_context.py new file mode 100644 index 00000000..313afee6 --- /dev/null +++ b/src/google/adk/a2a/executor/executor_context.py @@ -0,0 +1,49 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.runners import Runner + + +class ExecutorContext: + """Context for the executor.""" + + def __init__( + self, + app_name: str, + user_id: str, + session_id: str, + runner: Runner, + ): + self._app_name = app_name + self._user_id = user_id + self._session_id = session_id + self._runner = runner + + @property + def app_name(self) -> str: + return self._app_name + + @property + def user_id(self) -> str: + return self._user_id + + @property + def session_id(self) -> str: + return self._session_id + + @property + def runner(self) -> Runner: + return self._runner diff --git a/src/google/adk/a2a/executor/utils.py b/src/google/adk/a2a/executor/utils.py new file mode 100644 index 00000000..d01066ea --- /dev/null +++ b/src/google/adk/a2a/executor/utils.py @@ -0,0 +1,67 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Optional + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext + + +async def execute_before_agent_interceptors( + context: RequestContext, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> RequestContext: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.before_agent: + context = await interceptor.before_agent(context) + return context + + +async def execute_after_event_interceptors( + a2a_event: A2AEvent, + executor_context: ExecutorContext, + adk_event: Event, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> Optional[A2AEvent]: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.after_event: + a2a_event = await interceptor.after_event( + executor_context, a2a_event, adk_event + ) + if a2a_event is None: + return None + return a2a_event + + +async def execute_after_agent_interceptors( + executor_context: ExecutorContext, + final_event: TaskStatusUpdateEvent, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> TaskStatusUpdateEvent: + if execute_interceptors: + for interceptor in reversed(execute_interceptors): + if interceptor.after_agent: + final_event = await interceptor.after_agent( + executor_context, final_event + ) + return final_event diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 40736d95..787b260f 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -17,13 +17,17 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent from a2a.server.events.event_queue import EventQueue from a2a.types import Message +from a2a.types import Part +from a2a.types import Role from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor from google.adk.events.event import Event from google.adk.runners import RunConfig from google.adk.runners import Runner @@ -959,3 +963,111 @@ class TestA2aAgentExecutor: assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" + + @pytest.mark.asyncio + async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event( + self, + ): + """Test that after_event interceptors receive correct arguments and can modify the event.""" + # Create distinct mock objects for ADK event and A2A event + adk_event = Mock(spec=Event, name="ADK_EVENT") + a2a_event = Mock(spec=A2AEvent, name="A2A_EVENT") + modified_a2a_event = Mock(spec=A2AEvent, name="MODIFIED_A2A_EVENT") + + # Mocks for conversion + self.mock_event_converter.return_value = [a2a_event] + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Setup Interceptor + mock_interceptor = Mock(spec=ExecuteInterceptor) + + # after_event should return the modified event + async def side_effect_after_event(context, event, original_event): + return modified_a2a_event + + mock_interceptor.after_event = AsyncMock( + side_effect=side_effect_after_event + ) + mock_interceptor.before_agent = None + mock_interceptor.after_agent = None + + # Update config with interceptor + self.mock_config.execute_interceptors = [mock_interceptor] + # Re-initialize executor with updated config - but we can just update + # the config in place if it's mutable + # The executor uses self._config which is this mock_config basically. + # self.executor was initialized in setup_method with self.mock_config. + + # However, A2aAgentExecutor constructor does: self._config = config or ... + # So updating self.mock_config properties should work as + # it is the same object reference. + + # Mock context + self.mock_context.task_id = "task-1" + self.mock_context.context_id = "ctx-1" + # Ensure current_task is set so we skip the initial + # submitted event creation logic + # which might complicate this specific test if we don't care about it. + self.mock_context.current_task = Mock() + + # Mock runner.run_async to yield our ADK event + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([adk_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Configure session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + self.mock_runner._new_invocation_context.return_value = Mock() + + # We patch TaskResultAggregator just to avoid other errors and simplfy + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_agg_class: + mock_agg = Mock() + mock_agg.task_status_message = None + mock_agg.task_state = TaskState.working + mock_agg_class.return_value = mock_agg + + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify aggregator processed the MODIFIED event + mock_agg.process_event.assert_called_with(modified_a2a_event) + + # Verification of arguments passed to interceptor + assert mock_interceptor.after_event.called + call_args = mock_interceptor.after_event.call_args + # call_args.args should be (executor_context, a2a_event, adk_event) + + passed_a2a_event = call_args.args[1] + passed_adk_event = call_args.args[2] + + # These assertions verify the bug fix + assert ( + passed_a2a_event is a2a_event + ), f"Expected A2A event to be passed as 2nd arg, but got {passed_a2a_event}" + assert ( + passed_adk_event is adk_event + ), f"Expected ADK event to be passed as 3rd arg, but got {passed_adk_event}" + + # Verify that the modified event was enqueued + # We check if enqueue_event was called with modified_a2a_event + # Note: enqueue_event is called multiple times. + + enqueued_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + ] + assert ( + modified_a2a_event in enqueued_events + ), "The modified event should have been enqueued"