You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Allow custom request and event converters in A2aAgentExecutor
This change introduces type aliases for request and event conversion functions: - `A2ARequestToADKRunArgsConverter`: For converting A2A `RequestContext` to an `ADKRunArgs` Pydantic model. - `AdkEventToA2AEventsConverter`: For converting ADK `Event` to a list of A2A `A2AEvent` objects. The `convert_a2a_request_to_adk_run_args` function now returns a structured `ADKRunArgs` model instead of a generic dictionary, improving type safety. These converter types can now be provided via the `A2aAgentExecutorConfig` to customize the conversion logic used by the `A2aAgentExecutor`. The executor defaults to the existing `convert_a2a_request_to_adk_run_args` and `convert_event_to_a2a_events` functions if no custom converters are specified. This allows users to inject their own logic for handling request and event conversions, for example, to add custom metadata or transform data types, without modifying the core executor. PiperOrigin-RevId: 819934960
This commit is contained in:
committed by
Copybara-Service
parent
6ab1498aa0
commit
a17f3b2e6d
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import logging
|
||||
@@ -57,6 +58,34 @@ DEFAULT_ERROR_MESSAGE = "An error occurred during processing"
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
AdkEventToA2AEventsConverter = Callable[
|
||||
[
|
||||
Event,
|
||||
InvocationContext,
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
GenAIPartToA2APartConverter,
|
||||
],
|
||||
List[A2AEvent],
|
||||
]
|
||||
"""A callable that converts an ADK Event into a list of A2A events.
|
||||
|
||||
This interface allows for custom logic to map ADK's event structure to the
|
||||
event structure expected by the A2A server.
|
||||
|
||||
Args:
|
||||
event: The source ADK Event to convert.
|
||||
invocation_context: The context of the ADK agent invocation.
|
||||
task_id: The ID of the A2A task being processed.
|
||||
context_id: The context ID from the A2A request.
|
||||
part_converter: A function to convert GenAI content parts to A2A
|
||||
parts.
|
||||
|
||||
Returns:
|
||||
A list of A2A events.
|
||||
"""
|
||||
|
||||
|
||||
def _serialize_metadata_value(value: Any) -> str:
|
||||
"""Safely serializes metadata values to string format.
|
||||
|
||||
|
||||
@@ -14,8 +14,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
@@ -35,6 +39,39 @@ from .part_converter import A2APartToGenAIPartConverter
|
||||
from .part_converter import convert_a2a_part_to_genai_part
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
class AgentRunRequest(BaseModel):
|
||||
"""Data model for arguments passed to the ADK runner."""
|
||||
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
invocation_id: Optional[str] = None
|
||||
new_message: Optional[genai_types.Content] = None
|
||||
state_delta: Optional[dict[str, Any]] = None
|
||||
run_config: Optional[RunConfig] = None
|
||||
|
||||
|
||||
A2ARequestToAgentRunRequestConverter = Callable[
|
||||
[
|
||||
RequestContext,
|
||||
A2APartToGenAIPartConverter,
|
||||
],
|
||||
AgentRunRequest,
|
||||
]
|
||||
"""A callable that converts an A2A RequestContext to RunnerRequest for ADK runner.
|
||||
|
||||
This interface allows for custom logic to map an incoming A2A RequestContext to the
|
||||
structured arguments expected by the ADK runner's `run_async` method.
|
||||
|
||||
Args:
|
||||
request: The incoming request context from the A2A server.
|
||||
part_converter: A function to convert A2A content parts to GenAI parts.
|
||||
|
||||
Returns:
|
||||
An RunnerRequest object containing the keyword arguments for ADK runner's run_async method.
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_id(request: RequestContext) -> str:
|
||||
# Get user from call context if available (auth is enabled on a2a server)
|
||||
if (
|
||||
@@ -49,20 +86,32 @@ def _get_user_id(request: RequestContext) -> str:
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
def convert_a2a_request_to_adk_run_args(
|
||||
def convert_a2a_request_to_agent_run_request(
|
||||
request: RequestContext,
|
||||
part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part,
|
||||
) -> dict[str, Any]:
|
||||
) -> AgentRunRequest:
|
||||
"""Converts an A2A RequestContext to an AgentRunRequest model.
|
||||
|
||||
Args:
|
||||
request: The incoming request context from the A2A server.
|
||||
part_converter: A function to convert A2A content parts to GenAI parts.
|
||||
|
||||
Returns:
|
||||
A AgentRunRequest object ready to be used as arguments for the ADK runner.
|
||||
|
||||
Raises:
|
||||
ValueError: If the request message is None.
|
||||
"""
|
||||
|
||||
if not request.message:
|
||||
raise ValueError('Request message cannot be None')
|
||||
|
||||
return {
|
||||
'user_id': _get_user_id(request),
|
||||
'session_id': request.context_id,
|
||||
'new_message': genai_types.Content(
|
||||
return AgentRunRequest(
|
||||
user_id=_get_user_id(request),
|
||||
session_id=request.context_id,
|
||||
new_message=genai_types.Content(
|
||||
role='user',
|
||||
parts=[part_converter(part) for part in request.message.parts],
|
||||
),
|
||||
'run_config': RunConfig(),
|
||||
}
|
||||
run_config=RunConfig(),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
@@ -52,12 +51,15 @@ from google.adk.runners import Runner
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from ..converters.event_converter import AdkEventToA2AEventsConverter
|
||||
from ..converters.event_converter import convert_event_to_a2a_events
|
||||
from ..converters.part_converter import A2APartToGenAIPartConverter
|
||||
from ..converters.part_converter import convert_a2a_part_to_genai_part
|
||||
from ..converters.part_converter import convert_genai_part_to_a2a_part
|
||||
from ..converters.part_converter import GenAIPartToA2APartConverter
|
||||
from ..converters.request_converter import convert_a2a_request_to_adk_run_args
|
||||
from ..converters.request_converter import A2ARequestToAgentRunRequestConverter
|
||||
from ..converters.request_converter import AgentRunRequest
|
||||
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
|
||||
from ..converters.utils import _get_adk_metadata_key
|
||||
from ..experimental import a2a_experimental
|
||||
from .task_result_aggregator import TaskResultAggregator
|
||||
@@ -75,6 +77,10 @@ class A2aAgentExecutorConfig(BaseModel):
|
||||
gen_ai_part_converter: GenAIPartToA2APartConverter = (
|
||||
convert_genai_part_to_a2a_part
|
||||
)
|
||||
request_converter: A2ARequestToAgentRunRequestConverter = (
|
||||
convert_a2a_request_to_agent_run_request
|
||||
)
|
||||
event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
@@ -192,19 +198,20 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
# Resolve the runner instance
|
||||
runner = await self._resolve_runner()
|
||||
|
||||
# Convert the a2a request to ADK run args
|
||||
run_args = convert_a2a_request_to_adk_run_args(
|
||||
context, self._config.a2a_part_converter
|
||||
# Convert the a2a request to AgentRunRequest
|
||||
run_request = self._config.request_converter(
|
||||
context,
|
||||
self._config.a2a_part_converter,
|
||||
)
|
||||
|
||||
# ensure the session exists
|
||||
session = await self._prepare_session(context, run_args, runner)
|
||||
session = await self._prepare_session(context, run_request, runner)
|
||||
|
||||
# create invocation context
|
||||
invocation_context = runner._new_invocation_context(
|
||||
session=session,
|
||||
new_message=run_args['new_message'],
|
||||
run_config=run_args['run_config'],
|
||||
new_message=run_request.new_message,
|
||||
run_config=run_request.run_config,
|
||||
)
|
||||
|
||||
# publish the task working event
|
||||
@@ -219,16 +226,16 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
final=False,
|
||||
metadata={
|
||||
_get_adk_metadata_key('app_name'): runner.app_name,
|
||||
_get_adk_metadata_key('user_id'): run_args['user_id'],
|
||||
_get_adk_metadata_key('session_id'): run_args['session_id'],
|
||||
_get_adk_metadata_key('user_id'): run_request.user_id,
|
||||
_get_adk_metadata_key('session_id'): run_request.session_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
task_result_aggregator = TaskResultAggregator()
|
||||
async with Aclosing(runner.run_async(**run_args)) as agen:
|
||||
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
|
||||
async for adk_event in agen:
|
||||
for a2a_event in convert_event_to_a2a_events(
|
||||
for a2a_event in self._config.event_converter(
|
||||
adk_event,
|
||||
invocation_context,
|
||||
context.task_id,
|
||||
@@ -284,12 +291,15 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
)
|
||||
|
||||
async def _prepare_session(
|
||||
self, context: RequestContext, run_args: dict[str, Any], runner: Runner
|
||||
self,
|
||||
context: RequestContext,
|
||||
run_request: AgentRunRequest,
|
||||
runner: Runner,
|
||||
):
|
||||
|
||||
session_id = run_args['session_id']
|
||||
session_id = run_request.session_id
|
||||
# create a new session if not exists
|
||||
user_id = run_args['user_id']
|
||||
user_id = run_request.user_id
|
||||
session = await runner.session_service.get_session(
|
||||
app_name=runner.app_name,
|
||||
user_id=user_id,
|
||||
@@ -302,7 +312,7 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
state={},
|
||||
session_id=session_id,
|
||||
)
|
||||
# Update run_args with the new session_id
|
||||
run_args['session_id'] = session.id
|
||||
# Update run_request with the new session_id
|
||||
run_request.session_id = session.id
|
||||
|
||||
return session
|
||||
|
||||
@@ -27,7 +27,7 @@ pytestmark = pytest.mark.skipif(
|
||||
try:
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from google.adk.a2a.converters.request_converter import _get_user_id
|
||||
from google.adk.a2a.converters.request_converter import convert_a2a_request_to_adk_run_args
|
||||
from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request
|
||||
from google.adk.runners import RunConfig
|
||||
from google.genai import types as genai_types
|
||||
except ImportError as e:
|
||||
@@ -143,11 +143,11 @@ class TestGetUserId:
|
||||
assert result == "A2A_USER_None"
|
||||
|
||||
|
||||
class TestConvertA2aRequestToAdkRunArgs:
|
||||
"""Test cases for convert_a2a_request_to_adk_run_args function."""
|
||||
class TestConvertA2aRequestToAgentRunRequest:
|
||||
"""Test cases for convert_a2a_request_to_agent_run_request function."""
|
||||
|
||||
def test_convert_a2a_request_basic(self):
|
||||
"""Test basic conversion of A2A request to ADK run args."""
|
||||
"""Test basic conversion of A2A request to ADK AgentRunRequest."""
|
||||
# Arrange
|
||||
mock_part1 = Mock()
|
||||
mock_part2 = Mock()
|
||||
@@ -173,16 +173,18 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
mock_convert_part.side_effect = [mock_genai_part1, mock_genai_part2]
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "test_user"
|
||||
assert result["session_id"] == "test_context_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.user_id == "test_user"
|
||||
assert result.session_id == "test_context_123"
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == [mock_genai_part1, mock_genai_part2]
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
# Verify calls
|
||||
assert mock_convert_part.call_count == 2
|
||||
@@ -197,7 +199,7 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Request message cannot be None"):
|
||||
convert_a2a_request_to_adk_run_args(request)
|
||||
convert_a2a_request_to_agent_run_request(request)
|
||||
|
||||
def test_convert_a2a_request_empty_parts(self):
|
||||
"""Test conversion with empty parts list."""
|
||||
@@ -212,16 +214,18 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
request.call_context = None
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "A2A_USER_test_context_123"
|
||||
assert result["session_id"] == "test_context_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == []
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.user_id == "A2A_USER_test_context_123"
|
||||
assert result.session_id == "test_context_123"
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == []
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
# Verify convert_part wasn't called
|
||||
mock_convert_part.assert_not_called()
|
||||
@@ -244,16 +248,18 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
mock_convert_part.return_value = mock_genai_part
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "A2A_USER_None"
|
||||
assert result["session_id"] is None
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.user_id == "A2A_USER_None"
|
||||
assert result.session_id is None
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == [mock_genai_part]
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
def test_convert_a2a_request_no_auth(self):
|
||||
"""Test conversion when no authentication is available."""
|
||||
@@ -273,16 +279,18 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
mock_convert_part.return_value = mock_genai_part
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "A2A_USER_session_123"
|
||||
assert result["session_id"] == "session_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.user_id == "A2A_USER_session_123"
|
||||
assert result.session_id == "session_123"
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == [mock_genai_part]
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
@@ -312,16 +320,18 @@ class TestIntegration:
|
||||
mock_convert_part.return_value = mock_genai_part
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "auth_user" # Should use authenticated user
|
||||
assert result["session_id"] == "mysession"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.user_id == "auth_user" # Should use authenticated user
|
||||
assert result.session_id == "mysession"
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == [mock_genai_part]
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
def test_end_to_end_conversion_with_fallback_user(self):
|
||||
"""Test end-to-end conversion with fallback user ID."""
|
||||
@@ -341,15 +351,17 @@ class TestIntegration:
|
||||
mock_convert_part.return_value = mock_genai_part
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request, mock_convert_part)
|
||||
result = convert_a2a_request_to_agent_run_request(
|
||||
request, mock_convert_part
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert (
|
||||
result["user_id"] == "A2A_USER_test_session_456"
|
||||
result.user_id == "A2A_USER_test_session_456"
|
||||
) # Should fallback to context ID
|
||||
assert result["session_id"] == "test_session_456"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
assert result.session_id == "test_session_456"
|
||||
assert isinstance(result.new_message, genai_types.Content)
|
||||
assert result.new_message.role == "user"
|
||||
assert result.new_message.parts == [mock_genai_part]
|
||||
assert isinstance(result.run_config, RunConfig)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user