You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Introduce A2A request interceptors in RemoteA2aAgent
This change adds a new `a2a` subpackage with configuration and utility functions for intercepting requests and responses in `RemoteA2aAgent`. The `RemoteA2aAgent` now accepts an `A2aRemoteAgentConfig` to register `RequestInterceptor` instances, allowing custom logic to be executed before and after the A2A message send. PiperOrigin-RevId: 875559286
This commit is contained in:
committed by
Copybara-Service
parent
5f806ed73a
commit
6f772d2b08
@@ -0,0 +1,25 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A2A agents package."""
|
||||
|
||||
from .config import A2aRemoteAgentConfig
|
||||
from .config import ParametersConfig
|
||||
from .config import RequestInterceptor
|
||||
|
||||
__all__ = [
|
||||
"A2aRemoteAgentConfig",
|
||||
"ParametersConfig",
|
||||
"RequestInterceptor",
|
||||
]
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for A2A agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from a2a.client.middleware import ClientCallContext
|
||||
from a2a.server.events import Event as A2AEvent
|
||||
from a2a.types import Message as A2AMessage
|
||||
from a2a.types import MessageSendConfiguration
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
|
||||
|
||||
class ParametersConfig(BaseModel):
|
||||
"""Configuration for the parameters passed to the A2A send_message request."""
|
||||
|
||||
request_metadata: Optional[dict[str, Any]] = None
|
||||
client_call_context: Optional[ClientCallContext] = None
|
||||
# TODO: Add support for requested_extension and
|
||||
# message_send_configuration once they are supported by the A2A client.
|
||||
#
|
||||
# requested_extension: Optional[list[str]] = None
|
||||
# message_send_configuration: Optional[MessageSendConfiguration] = None
|
||||
|
||||
|
||||
class RequestInterceptor(BaseModel):
|
||||
"""Interceptor for A2A requests."""
|
||||
|
||||
before_request: Optional[
|
||||
Callable[
|
||||
[InvocationContext, A2AMessage, ParametersConfig],
|
||||
Awaitable[tuple[Union[A2AMessage, Event], ParametersConfig]],
|
||||
]
|
||||
] = None
|
||||
"""Hook executed before the agent starts processing the request.
|
||||
|
||||
Returns an Event if the request should be aborted and the Event
|
||||
returned to the caller.
|
||||
"""
|
||||
|
||||
after_request: Optional[
|
||||
Callable[
|
||||
[InvocationContext, A2AEvent, Event], Awaitable[Union[Event, None]]
|
||||
]
|
||||
] = None
|
||||
"""Hook executed after the agent has processed the request.
|
||||
|
||||
Returns None if the event should not be sent to the caller.
|
||||
"""
|
||||
|
||||
|
||||
class A2aRemoteAgentConfig(BaseModel):
|
||||
"""Configuration for the RemoteA2aAgent."""
|
||||
|
||||
request_interceptors: Optional[list[RequestInterceptor]] = None
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for A2A agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from a2a.client import ClientEvent as A2AClientEvent
|
||||
from a2a.client.middleware import ClientCallContext
|
||||
from a2a.types import Message as A2AMessage
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from .config import ParametersConfig
|
||||
from .config import RequestInterceptor
|
||||
|
||||
|
||||
async def execute_before_request_interceptors(
|
||||
request_interceptors: Optional[list[RequestInterceptor]],
|
||||
ctx: InvocationContext,
|
||||
a2a_request: A2AMessage,
|
||||
) -> tuple[Union[A2AMessage, Event], ParametersConfig]:
|
||||
"""Executes registered before_request interceptors."""
|
||||
|
||||
params = ParametersConfig(
|
||||
client_call_context=ClientCallContext(state=ctx.session.state)
|
||||
)
|
||||
if request_interceptors:
|
||||
for interceptor in request_interceptors:
|
||||
if not interceptor.before_request:
|
||||
continue
|
||||
|
||||
result, params = await interceptor.before_request(
|
||||
ctx, a2a_request, params
|
||||
)
|
||||
if isinstance(result, Event):
|
||||
return result, params
|
||||
a2a_request = result
|
||||
|
||||
return a2a_request, params
|
||||
|
||||
|
||||
async def execute_after_request_interceptors(
|
||||
request_interceptors: Optional[list[RequestInterceptor]],
|
||||
ctx: InvocationContext,
|
||||
a2a_response: A2AMessage | A2AClientEvent,
|
||||
event: Event,
|
||||
) -> Optional[Event]:
|
||||
"""Executes registered after_request interceptors."""
|
||||
if request_interceptors:
|
||||
for interceptor in reversed(request_interceptors):
|
||||
if interceptor.after_request:
|
||||
event = await interceptor.after_request(ctx, a2a_response, event)
|
||||
if not event:
|
||||
return None
|
||||
return event
|
||||
@@ -35,6 +35,7 @@ from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.middleware import ClientCallContext
|
||||
from a2a.types import AgentCard
|
||||
from a2a.types import Message as A2AMessage
|
||||
from a2a.types import MessageSendConfiguration
|
||||
from a2a.types import Part as A2APart
|
||||
from a2a.types import Role
|
||||
from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent
|
||||
@@ -43,6 +44,7 @@ from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent
|
||||
from a2a.types import TransportProtocol as A2ATransport
|
||||
from google.genai import types as genai_types
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
|
||||
@@ -50,6 +52,9 @@ except ImportError:
|
||||
# Fallback for older versions of a2a-sdk.
|
||||
AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json"
|
||||
|
||||
from ..a2a.agent.config import A2aRemoteAgentConfig
|
||||
from ..a2a.agent.utils import execute_after_request_interceptors
|
||||
from ..a2a.agent.utils import execute_before_request_interceptors
|
||||
from ..a2a.converters.event_converter import convert_a2a_message_to_event
|
||||
from ..a2a.converters.event_converter import convert_a2a_task_to_event
|
||||
from ..a2a.converters.event_converter import convert_event_to_a2a_message
|
||||
@@ -127,6 +132,7 @@ class RemoteA2aAgent(BaseAgent):
|
||||
Callable[[InvocationContext, A2AMessage], dict[str, Any]]
|
||||
] = None,
|
||||
full_history_when_stateless: bool = False,
|
||||
config: Optional[A2aRemoteAgentConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize RemoteA2aAgent.
|
||||
@@ -147,6 +153,7 @@ class RemoteA2aAgent(BaseAgent):
|
||||
return Tasks or context IDs) will receive all session events on every
|
||||
request. If False, the default behavior of sending only events since the
|
||||
last reply from the agent will be used.
|
||||
config: Optional configuration object.
|
||||
**kwargs: Additional arguments passed to BaseAgent
|
||||
|
||||
Raises:
|
||||
@@ -174,6 +181,7 @@ class RemoteA2aAgent(BaseAgent):
|
||||
self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory
|
||||
self._a2a_request_meta_provider = a2a_request_meta_provider
|
||||
self._full_history_when_stateless = full_history_when_stateless
|
||||
self._config = config or A2aRemoteAgentConfig()
|
||||
|
||||
# Validate and store agent card reference
|
||||
if isinstance(agent_card, AgentCard):
|
||||
@@ -558,14 +566,26 @@ class RemoteA2aAgent(BaseAgent):
|
||||
logger.debug(build_a2a_request_log(a2a_request))
|
||||
|
||||
try:
|
||||
request_metadata = None
|
||||
if self._a2a_request_meta_provider:
|
||||
request_metadata = self._a2a_request_meta_provider(ctx, a2a_request)
|
||||
a2a_request, parameters = await execute_before_request_interceptors(
|
||||
self._config.request_interceptors, ctx, a2a_request
|
||||
)
|
||||
|
||||
if isinstance(a2a_request, Event):
|
||||
yield a2a_request
|
||||
return
|
||||
|
||||
# Backward compatibility
|
||||
if self._a2a_request_meta_provider:
|
||||
parameters.request_metadata = self._a2a_request_meta_provider(
|
||||
ctx, a2a_request
|
||||
)
|
||||
|
||||
# TODO: Add support for requested_extension and
|
||||
# message_send_configuration once they are supported by the A2A client.
|
||||
async for a2a_response in self._a2a_client.send_message(
|
||||
request=a2a_request,
|
||||
request_metadata=request_metadata,
|
||||
context=ClientCallContext(state=ctx.session.state),
|
||||
request_metadata=parameters.request_metadata,
|
||||
context=parameters.client_call_context,
|
||||
):
|
||||
logger.debug(build_a2a_response_log(a2a_response))
|
||||
|
||||
@@ -573,6 +593,12 @@ class RemoteA2aAgent(BaseAgent):
|
||||
if not event:
|
||||
continue
|
||||
|
||||
event = await execute_after_request_interceptors(
|
||||
self._config.request_interceptors, ctx, a2a_response, event
|
||||
)
|
||||
if not event:
|
||||
continue
|
||||
|
||||
# Add metadata about the request and response
|
||||
event.custom_metadata = event.custom_metadata or {}
|
||||
event.custom_metadata[A2A_METADATA_PREFIX + "request"] = (
|
||||
|
||||
@@ -21,7 +21,6 @@ from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from a2a.client.client import ClientConfig
|
||||
from a2a.client.client import Consumer
|
||||
from a2a.client.client_factory import ClientFactory
|
||||
from a2a.client.middleware import ClientCallContext
|
||||
from a2a.types import AgentCapabilities
|
||||
@@ -29,13 +28,16 @@ from a2a.types import AgentCard
|
||||
from a2a.types import AgentSkill
|
||||
from a2a.types import Artifact
|
||||
from a2a.types import Message as A2AMessage
|
||||
from a2a.types import SendMessageSuccessResponse
|
||||
from a2a.types import Task as A2ATask
|
||||
from a2a.types import TaskArtifactUpdateEvent
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatus as A2ATaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from google.adk.a2a.agent import ParametersConfig
|
||||
from google.adk.a2a.agent import RequestInterceptor
|
||||
from google.adk.a2a.agent.utils import execute_after_request_interceptors
|
||||
from google.adk.a2a.agent.utils import execute_before_request_interceptors
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX
|
||||
from google.adk.agents.remote_a2a_agent import AgentCardResolutionError
|
||||
@@ -2432,3 +2434,203 @@ class TestRemoteA2aAgentIntegration:
|
||||
|
||||
# Verify A2A client was called
|
||||
mock_a2a_client.send_message.assert_called_once()
|
||||
|
||||
|
||||
class TestRemoteA2aAgentInterceptors:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(self):
|
||||
ctx = Mock(spec=InvocationContext)
|
||||
ctx.session = Mock()
|
||||
ctx.session.state = {"key": "value"}
|
||||
return ctx
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_before_request_interceptors_none(self, mock_context):
|
||||
request = Mock(spec=A2AMessage)
|
||||
result_req, params = await execute_before_request_interceptors(
|
||||
None, mock_context, request
|
||||
)
|
||||
assert result_req is request
|
||||
assert params.client_call_context.state == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_before_request_interceptors_empty(self, mock_context):
|
||||
request = Mock(spec=A2AMessage)
|
||||
result_req, params = await execute_before_request_interceptors(
|
||||
[], mock_context, request
|
||||
)
|
||||
assert result_req is request
|
||||
assert params.client_call_context.state == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_before_request_interceptors_success(
|
||||
self, mock_context
|
||||
):
|
||||
request = Mock(spec=A2AMessage)
|
||||
new_request = Mock(spec=A2AMessage)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.before_request = AsyncMock(
|
||||
return_value=(
|
||||
new_request,
|
||||
ParametersConfig(
|
||||
client_call_context=ClientCallContext(state={"updated": "true"})
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result_req, params = await execute_before_request_interceptors(
|
||||
[interceptor1], mock_context, request
|
||||
)
|
||||
|
||||
assert result_req is new_request
|
||||
assert params.client_call_context.state == {"updated": "true"}
|
||||
interceptor1.before_request.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_before_request_interceptors_returns_event(
|
||||
self, mock_context
|
||||
):
|
||||
request = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.before_request = AsyncMock(
|
||||
return_value=(
|
||||
event,
|
||||
ParametersConfig(
|
||||
client_call_context=ClientCallContext(state={"updated": "true"})
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
interceptor2 = Mock(spec=RequestInterceptor)
|
||||
interceptor2.before_request = AsyncMock()
|
||||
|
||||
result, params = await execute_before_request_interceptors(
|
||||
[interceptor1, interceptor2], mock_context, request
|
||||
)
|
||||
|
||||
assert result is event
|
||||
assert params.client_call_context.state == {"updated": "true"}
|
||||
interceptor1.before_request.assert_called_once()
|
||||
interceptor2.before_request.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_before_request_interceptors_no_before_request(
|
||||
self, mock_context
|
||||
):
|
||||
request = Mock(spec=A2AMessage)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.before_request = None
|
||||
|
||||
result_req, params = await execute_before_request_interceptors(
|
||||
[interceptor1], mock_context, request
|
||||
)
|
||||
|
||||
assert result_req is request
|
||||
assert params.client_call_context.state == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_none(self, mock_context):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
result = await execute_after_request_interceptors(
|
||||
None, mock_context, response, event
|
||||
)
|
||||
assert result is event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_empty(self, mock_context):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
result = await execute_after_request_interceptors(
|
||||
[], mock_context, response, event
|
||||
)
|
||||
assert result is event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_success(self, mock_context):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
new_event = Mock(spec=Event)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.after_request = AsyncMock(return_value=new_event)
|
||||
|
||||
result = await execute_after_request_interceptors(
|
||||
[interceptor1], mock_context, response, event
|
||||
)
|
||||
|
||||
assert result is new_event
|
||||
interceptor1.after_request.assert_called_once_with(
|
||||
mock_context, response, event
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_reverse_order(
|
||||
self, mock_context
|
||||
):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
event1 = Mock(spec=Event)
|
||||
event2 = Mock(spec=Event)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.after_request = AsyncMock(return_value=event1)
|
||||
|
||||
interceptor2 = Mock(spec=RequestInterceptor)
|
||||
interceptor2.after_request = AsyncMock(return_value=event2)
|
||||
|
||||
result = await execute_after_request_interceptors(
|
||||
[interceptor1, interceptor2], mock_context, response, event
|
||||
)
|
||||
|
||||
assert result is event1
|
||||
interceptor2.after_request.assert_called_once_with(
|
||||
mock_context, response, event
|
||||
)
|
||||
interceptor1.after_request.assert_called_once_with(
|
||||
mock_context, response, event2
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_returns_none(
|
||||
self, mock_context
|
||||
):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.after_request = AsyncMock()
|
||||
|
||||
interceptor2 = Mock(spec=RequestInterceptor)
|
||||
interceptor2.after_request = AsyncMock(return_value=None)
|
||||
|
||||
result = await execute_after_request_interceptors(
|
||||
[interceptor1, interceptor2], mock_context, response, event
|
||||
)
|
||||
|
||||
assert result is None
|
||||
interceptor2.after_request.assert_called_once_with(
|
||||
mock_context, response, event
|
||||
)
|
||||
interceptor1.after_request.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_after_request_interceptors_no_after_request(
|
||||
self, mock_context
|
||||
):
|
||||
response = Mock(spec=A2AMessage)
|
||||
event = Mock(spec=Event)
|
||||
|
||||
interceptor1 = Mock(spec=RequestInterceptor)
|
||||
interceptor1.after_request = None
|
||||
|
||||
result = await execute_after_request_interceptors(
|
||||
[interceptor1], mock_context, response, event
|
||||
)
|
||||
|
||||
assert result is event
|
||||
|
||||
Reference in New Issue
Block a user