feat: Introduce A2A request interceptors in RemoteA2aAgent

This change adds a new `a2a` subpackage with configuration and utility functions for intercepting requests and responses in `RemoteA2aAgent`. The `RemoteA2aAgent` now accepts an `A2aRemoteAgentConfig` to register `RequestInterceptor` instances, allowing custom logic to be executed before and after the A2A message send.

PiperOrigin-RevId: 875559286
This commit is contained in:
Google Team Member
2026-02-26 00:22:26 -08:00
committed by Copybara-Service
parent 5f806ed73a
commit 6f772d2b08
5 changed files with 406 additions and 7 deletions
+25
View File
@@ -0,0 +1,25 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A2A agents package."""
from .config import A2aRemoteAgentConfig
from .config import ParametersConfig
from .config import RequestInterceptor
__all__ = [
"A2aRemoteAgentConfig",
"ParametersConfig",
"RequestInterceptor",
]
+76
View File
@@ -0,0 +1,76 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for A2A agents."""
from __future__ import annotations
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Optional
from typing import Union
from a2a.client.middleware import ClientCallContext
from a2a.server.events import Event as A2AEvent
from a2a.types import Message as A2AMessage
from a2a.types import MessageSendConfiguration
from pydantic import BaseModel
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
class ParametersConfig(BaseModel):
"""Configuration for the parameters passed to the A2A send_message request."""
request_metadata: Optional[dict[str, Any]] = None
client_call_context: Optional[ClientCallContext] = None
# TODO: Add support for requested_extension and
# message_send_configuration once they are supported by the A2A client.
#
# requested_extension: Optional[list[str]] = None
# message_send_configuration: Optional[MessageSendConfiguration] = None
class RequestInterceptor(BaseModel):
"""Interceptor for A2A requests."""
before_request: Optional[
Callable[
[InvocationContext, A2AMessage, ParametersConfig],
Awaitable[tuple[Union[A2AMessage, Event], ParametersConfig]],
]
] = None
"""Hook executed before the agent starts processing the request.
Returns an Event if the request should be aborted and the Event
returned to the caller.
"""
after_request: Optional[
Callable[
[InvocationContext, A2AEvent, Event], Awaitable[Union[Event, None]]
]
] = None
"""Hook executed after the agent has processed the request.
Returns None if the event should not be sent to the caller.
"""
class A2aRemoteAgentConfig(BaseModel):
"""Configuration for the RemoteA2aAgent."""
request_interceptors: Optional[list[RequestInterceptor]] = None
+70
View File
@@ -0,0 +1,70 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for A2A agents."""
from __future__ import annotations
from typing import Optional
from typing import Union
from a2a.client import ClientEvent as A2AClientEvent
from a2a.client.middleware import ClientCallContext
from a2a.types import Message as A2AMessage
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from .config import ParametersConfig
from .config import RequestInterceptor
async def execute_before_request_interceptors(
request_interceptors: Optional[list[RequestInterceptor]],
ctx: InvocationContext,
a2a_request: A2AMessage,
) -> tuple[Union[A2AMessage, Event], ParametersConfig]:
"""Executes registered before_request interceptors."""
params = ParametersConfig(
client_call_context=ClientCallContext(state=ctx.session.state)
)
if request_interceptors:
for interceptor in request_interceptors:
if not interceptor.before_request:
continue
result, params = await interceptor.before_request(
ctx, a2a_request, params
)
if isinstance(result, Event):
return result, params
a2a_request = result
return a2a_request, params
async def execute_after_request_interceptors(
request_interceptors: Optional[list[RequestInterceptor]],
ctx: InvocationContext,
a2a_response: A2AMessage | A2AClientEvent,
event: Event,
) -> Optional[Event]:
"""Executes registered after_request interceptors."""
if request_interceptors:
for interceptor in reversed(request_interceptors):
if interceptor.after_request:
event = await interceptor.after_request(ctx, a2a_response, event)
if not event:
return None
return event
+31 -5
View File
@@ -35,6 +35,7 @@ from a2a.client.errors import A2AClientHTTPError
from a2a.client.middleware import ClientCallContext
from a2a.types import AgentCard
from a2a.types import Message as A2AMessage
from a2a.types import MessageSendConfiguration
from a2a.types import Part as A2APart
from a2a.types import Role
from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent
@@ -43,6 +44,7 @@ from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent
from a2a.types import TransportProtocol as A2ATransport
from google.genai import types as genai_types
import httpx
from pydantic import BaseModel
try:
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
@@ -50,6 +52,9 @@ except ImportError:
# Fallback for older versions of a2a-sdk.
AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json"
from ..a2a.agent.config import A2aRemoteAgentConfig
from ..a2a.agent.utils import execute_after_request_interceptors
from ..a2a.agent.utils import execute_before_request_interceptors
from ..a2a.converters.event_converter import convert_a2a_message_to_event
from ..a2a.converters.event_converter import convert_a2a_task_to_event
from ..a2a.converters.event_converter import convert_event_to_a2a_message
@@ -127,6 +132,7 @@ class RemoteA2aAgent(BaseAgent):
Callable[[InvocationContext, A2AMessage], dict[str, Any]]
] = None,
full_history_when_stateless: bool = False,
config: Optional[A2aRemoteAgentConfig] = None,
**kwargs: Any,
) -> None:
"""Initialize RemoteA2aAgent.
@@ -147,6 +153,7 @@ class RemoteA2aAgent(BaseAgent):
return Tasks or context IDs) will receive all session events on every
request. If False, the default behavior of sending only events since the
last reply from the agent will be used.
config: Optional configuration object.
**kwargs: Additional arguments passed to BaseAgent
Raises:
@@ -174,6 +181,7 @@ class RemoteA2aAgent(BaseAgent):
self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory
self._a2a_request_meta_provider = a2a_request_meta_provider
self._full_history_when_stateless = full_history_when_stateless
self._config = config or A2aRemoteAgentConfig()
# Validate and store agent card reference
if isinstance(agent_card, AgentCard):
@@ -558,14 +566,26 @@ class RemoteA2aAgent(BaseAgent):
logger.debug(build_a2a_request_log(a2a_request))
try:
request_metadata = None
if self._a2a_request_meta_provider:
request_metadata = self._a2a_request_meta_provider(ctx, a2a_request)
a2a_request, parameters = await execute_before_request_interceptors(
self._config.request_interceptors, ctx, a2a_request
)
if isinstance(a2a_request, Event):
yield a2a_request
return
# Backward compatibility
if self._a2a_request_meta_provider:
parameters.request_metadata = self._a2a_request_meta_provider(
ctx, a2a_request
)
# TODO: Add support for requested_extension and
# message_send_configuration once they are supported by the A2A client.
async for a2a_response in self._a2a_client.send_message(
request=a2a_request,
request_metadata=request_metadata,
context=ClientCallContext(state=ctx.session.state),
request_metadata=parameters.request_metadata,
context=parameters.client_call_context,
):
logger.debug(build_a2a_response_log(a2a_response))
@@ -573,6 +593,12 @@ class RemoteA2aAgent(BaseAgent):
if not event:
continue
event = await execute_after_request_interceptors(
self._config.request_interceptors, ctx, a2a_response, event
)
if not event:
continue
# Add metadata about the request and response
event.custom_metadata = event.custom_metadata or {}
event.custom_metadata[A2A_METADATA_PREFIX + "request"] = (
+204 -2
View File
@@ -21,7 +21,6 @@ from unittest.mock import Mock
from unittest.mock import patch
from a2a.client.client import ClientConfig
from a2a.client.client import Consumer
from a2a.client.client_factory import ClientFactory
from a2a.client.middleware import ClientCallContext
from a2a.types import AgentCapabilities
@@ -29,13 +28,16 @@ from a2a.types import AgentCard
from a2a.types import AgentSkill
from a2a.types import Artifact
from a2a.types import Message as A2AMessage
from a2a.types import SendMessageSuccessResponse
from a2a.types import Task as A2ATask
from a2a.types import TaskArtifactUpdateEvent
from a2a.types import TaskState
from a2a.types import TaskStatus as A2ATaskStatus
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.a2a.agent import ParametersConfig
from google.adk.a2a.agent import RequestInterceptor
from google.adk.a2a.agent.utils import execute_after_request_interceptors
from google.adk.a2a.agent.utils import execute_before_request_interceptors
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX
from google.adk.agents.remote_a2a_agent import AgentCardResolutionError
@@ -2432,3 +2434,203 @@ class TestRemoteA2aAgentIntegration:
# Verify A2A client was called
mock_a2a_client.send_message.assert_called_once()
class TestRemoteA2aAgentInterceptors:
@pytest.fixture
def mock_context(self):
ctx = Mock(spec=InvocationContext)
ctx.session = Mock()
ctx.session.state = {"key": "value"}
return ctx
@pytest.mark.asyncio
async def test_execute_before_request_interceptors_none(self, mock_context):
request = Mock(spec=A2AMessage)
result_req, params = await execute_before_request_interceptors(
None, mock_context, request
)
assert result_req is request
assert params.client_call_context.state == {"key": "value"}
@pytest.mark.asyncio
async def test_execute_before_request_interceptors_empty(self, mock_context):
request = Mock(spec=A2AMessage)
result_req, params = await execute_before_request_interceptors(
[], mock_context, request
)
assert result_req is request
assert params.client_call_context.state == {"key": "value"}
@pytest.mark.asyncio
async def test_execute_before_request_interceptors_success(
self, mock_context
):
request = Mock(spec=A2AMessage)
new_request = Mock(spec=A2AMessage)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.before_request = AsyncMock(
return_value=(
new_request,
ParametersConfig(
client_call_context=ClientCallContext(state={"updated": "true"})
),
)
)
result_req, params = await execute_before_request_interceptors(
[interceptor1], mock_context, request
)
assert result_req is new_request
assert params.client_call_context.state == {"updated": "true"}
interceptor1.before_request.assert_called_once()
@pytest.mark.asyncio
async def test_execute_before_request_interceptors_returns_event(
self, mock_context
):
request = Mock(spec=A2AMessage)
event = Mock(spec=Event)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.before_request = AsyncMock(
return_value=(
event,
ParametersConfig(
client_call_context=ClientCallContext(state={"updated": "true"})
),
)
)
interceptor2 = Mock(spec=RequestInterceptor)
interceptor2.before_request = AsyncMock()
result, params = await execute_before_request_interceptors(
[interceptor1, interceptor2], mock_context, request
)
assert result is event
assert params.client_call_context.state == {"updated": "true"}
interceptor1.before_request.assert_called_once()
interceptor2.before_request.assert_not_called()
@pytest.mark.asyncio
async def test_execute_before_request_interceptors_no_before_request(
self, mock_context
):
request = Mock(spec=A2AMessage)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.before_request = None
result_req, params = await execute_before_request_interceptors(
[interceptor1], mock_context, request
)
assert result_req is request
assert params.client_call_context.state == {"key": "value"}
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_none(self, mock_context):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
result = await execute_after_request_interceptors(
None, mock_context, response, event
)
assert result is event
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_empty(self, mock_context):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
result = await execute_after_request_interceptors(
[], mock_context, response, event
)
assert result is event
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_success(self, mock_context):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
new_event = Mock(spec=Event)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.after_request = AsyncMock(return_value=new_event)
result = await execute_after_request_interceptors(
[interceptor1], mock_context, response, event
)
assert result is new_event
interceptor1.after_request.assert_called_once_with(
mock_context, response, event
)
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_reverse_order(
self, mock_context
):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
event1 = Mock(spec=Event)
event2 = Mock(spec=Event)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.after_request = AsyncMock(return_value=event1)
interceptor2 = Mock(spec=RequestInterceptor)
interceptor2.after_request = AsyncMock(return_value=event2)
result = await execute_after_request_interceptors(
[interceptor1, interceptor2], mock_context, response, event
)
assert result is event1
interceptor2.after_request.assert_called_once_with(
mock_context, response, event
)
interceptor1.after_request.assert_called_once_with(
mock_context, response, event2
)
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_returns_none(
self, mock_context
):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.after_request = AsyncMock()
interceptor2 = Mock(spec=RequestInterceptor)
interceptor2.after_request = AsyncMock(return_value=None)
result = await execute_after_request_interceptors(
[interceptor1, interceptor2], mock_context, response, event
)
assert result is None
interceptor2.after_request.assert_called_once_with(
mock_context, response, event
)
interceptor1.after_request.assert_not_called()
@pytest.mark.asyncio
async def test_execute_after_request_interceptors_no_after_request(
self, mock_context
):
response = Mock(spec=A2AMessage)
event = Mock(spec=Event)
interceptor1 = Mock(spec=RequestInterceptor)
interceptor1.after_request = None
result = await execute_after_request_interceptors(
[interceptor1], mock_context, response, event
)
assert result is event