diff --git a/src/google/adk/a2a/agent/__init__.py b/src/google/adk/a2a/agent/__init__.py new file mode 100644 index 00000000..8026986e --- /dev/null +++ b/src/google/adk/a2a/agent/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A2A agents package.""" + +from .config import A2aRemoteAgentConfig +from .config import ParametersConfig +from .config import RequestInterceptor + +__all__ = [ + "A2aRemoteAgentConfig", + "ParametersConfig", + "RequestInterceptor", +] diff --git a/src/google/adk/a2a/agent/config.py b/src/google/adk/a2a/agent/config.py new file mode 100644 index 00000000..e8f012cf --- /dev/null +++ b/src/google/adk/a2a/agent/config.py @@ -0,0 +1,76 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for A2A agents.""" + +from __future__ import annotations + +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.client.middleware import ClientCallContext +from a2a.server.events import Event as A2AEvent +from a2a.types import Message as A2AMessage +from a2a.types import MessageSendConfiguration +from pydantic import BaseModel + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event + + +class ParametersConfig(BaseModel): + """Configuration for the parameters passed to the A2A send_message request.""" + + request_metadata: Optional[dict[str, Any]] = None + client_call_context: Optional[ClientCallContext] = None + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. + # + # requested_extension: Optional[list[str]] = None + # message_send_configuration: Optional[MessageSendConfiguration] = None + + +class RequestInterceptor(BaseModel): + """Interceptor for A2A requests.""" + + before_request: Optional[ + Callable[ + [InvocationContext, A2AMessage, ParametersConfig], + Awaitable[tuple[Union[A2AMessage, Event], ParametersConfig]], + ] + ] = None + """Hook executed before the agent starts processing the request. + + Returns an Event if the request should be aborted and the Event + returned to the caller. + """ + + after_request: Optional[ + Callable[ + [InvocationContext, A2AEvent, Event], Awaitable[Union[Event, None]] + ] + ] = None + """Hook executed after the agent has processed the request. + + Returns None if the event should not be sent to the caller. + """ + + +class A2aRemoteAgentConfig(BaseModel): + """Configuration for the RemoteA2aAgent.""" + + request_interceptors: Optional[list[RequestInterceptor]] = None diff --git a/src/google/adk/a2a/agent/utils.py b/src/google/adk/a2a/agent/utils.py new file mode 100644 index 00000000..7cbb25eb --- /dev/null +++ b/src/google/adk/a2a/agent/utils.py @@ -0,0 +1,70 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for A2A agents.""" + +from __future__ import annotations + +from typing import Optional +from typing import Union + +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.middleware import ClientCallContext +from a2a.types import Message as A2AMessage + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from .config import ParametersConfig +from .config import RequestInterceptor + + +async def execute_before_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_request: A2AMessage, +) -> tuple[Union[A2AMessage, Event], ParametersConfig]: + """Executes registered before_request interceptors.""" + + params = ParametersConfig( + client_call_context=ClientCallContext(state=ctx.session.state) + ) + if request_interceptors: + for interceptor in request_interceptors: + if not interceptor.before_request: + continue + + result, params = await interceptor.before_request( + ctx, a2a_request, params + ) + if isinstance(result, Event): + return result, params + a2a_request = result + + return a2a_request, params + + +async def execute_after_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_response: A2AMessage | A2AClientEvent, + event: Event, +) -> Optional[Event]: + """Executes registered after_request interceptors.""" + if request_interceptors: + for interceptor in reversed(request_interceptors): + if interceptor.after_request: + event = await interceptor.after_request(ctx, a2a_response, event) + if not event: + return None + return event diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 2da7a4fa..5ffd123f 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -35,6 +35,7 @@ from a2a.client.errors import A2AClientHTTPError from a2a.client.middleware import ClientCallContext from a2a.types import AgentCard from a2a.types import Message as A2AMessage +from a2a.types import MessageSendConfiguration from a2a.types import Part as A2APart from a2a.types import Role from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent @@ -43,6 +44,7 @@ from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent from a2a.types import TransportProtocol as A2ATransport from google.genai import types as genai_types import httpx +from pydantic import BaseModel try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -50,6 +52,9 @@ except ImportError: # Fallback for older versions of a2a-sdk. AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" +from ..a2a.agent.config import A2aRemoteAgentConfig +from ..a2a.agent.utils import execute_after_request_interceptors +from ..a2a.agent.utils import execute_before_request_interceptors from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message @@ -127,6 +132,7 @@ class RemoteA2aAgent(BaseAgent): Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, full_history_when_stateless: bool = False, + config: Optional[A2aRemoteAgentConfig] = None, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -147,6 +153,7 @@ class RemoteA2aAgent(BaseAgent): return Tasks or context IDs) will receive all session events on every request. If False, the default behavior of sending only events since the last reply from the agent will be used. + config: Optional configuration object. **kwargs: Additional arguments passed to BaseAgent Raises: @@ -174,6 +181,7 @@ class RemoteA2aAgent(BaseAgent): self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider self._full_history_when_stateless = full_history_when_stateless + self._config = config or A2aRemoteAgentConfig() # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -558,14 +566,26 @@ class RemoteA2aAgent(BaseAgent): logger.debug(build_a2a_request_log(a2a_request)) try: - request_metadata = None - if self._a2a_request_meta_provider: - request_metadata = self._a2a_request_meta_provider(ctx, a2a_request) + a2a_request, parameters = await execute_before_request_interceptors( + self._config.request_interceptors, ctx, a2a_request + ) + if isinstance(a2a_request, Event): + yield a2a_request + return + + # Backward compatibility + if self._a2a_request_meta_provider: + parameters.request_metadata = self._a2a_request_meta_provider( + ctx, a2a_request + ) + + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. async for a2a_response in self._a2a_client.send_message( request=a2a_request, - request_metadata=request_metadata, - context=ClientCallContext(state=ctx.session.state), + request_metadata=parameters.request_metadata, + context=parameters.client_call_context, ): logger.debug(build_a2a_response_log(a2a_response)) @@ -573,6 +593,12 @@ class RemoteA2aAgent(BaseAgent): if not event: continue + event = await execute_after_request_interceptors( + self._config.request_interceptors, ctx, a2a_response, event + ) + if not event: + continue + # Add metadata about the request and response event.custom_metadata = event.custom_metadata or {} event.custom_metadata[A2A_METADATA_PREFIX + "request"] = ( diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 7643125d..fe155d30 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -21,7 +21,6 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.client.client import ClientConfig -from a2a.client.client import Consumer from a2a.client.client_factory import ClientFactory from a2a.client.middleware import ClientCallContext from a2a.types import AgentCapabilities @@ -29,13 +28,16 @@ from a2a.types import AgentCard from a2a.types import AgentSkill from a2a.types import Artifact from a2a.types import Message as A2AMessage -from a2a.types import SendMessageSuccessResponse from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus as A2ATaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.adk.a2a.agent import ParametersConfig +from google.adk.a2a.agent import RequestInterceptor +from google.adk.a2a.agent.utils import execute_after_request_interceptors +from google.adk.a2a.agent.utils import execute_before_request_interceptors from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX from google.adk.agents.remote_a2a_agent import AgentCardResolutionError @@ -2432,3 +2434,203 @@ class TestRemoteA2aAgentIntegration: # Verify A2A client was called mock_a2a_client.send_message.assert_called_once() + + +class TestRemoteA2aAgentInterceptors: + + @pytest.fixture + def mock_context(self): + ctx = Mock(spec=InvocationContext) + ctx.session = Mock() + ctx.session.state = {"key": "value"} + return ctx + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_none(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + None, mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_empty(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + [], mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_success( + self, mock_context + ): + request = Mock(spec=A2AMessage) + new_request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + new_request, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is new_request + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_returns_event( + self, mock_context + ): + request = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + event, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.before_request = AsyncMock() + + result, params = await execute_before_request_interceptors( + [interceptor1, interceptor2], mock_context, request + ) + + assert result is event + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + interceptor2.before_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_no_before_request( + self, mock_context + ): + request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = None + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_none(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + None, mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_empty(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + [], mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_success(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + new_event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock(return_value=new_event) + + result = await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is new_event + interceptor1.after_request.assert_called_once_with( + mock_context, response, event + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_reverse_order( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + event1 = Mock(spec=Event) + event2 = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock(return_value=event1) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=event2) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is event1 + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_called_once_with( + mock_context, response, event2 + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_returns_none( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock() + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=None) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is None + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_no_after_request( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = None + + result = await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is event