You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: New implementation of A2aAgentExecutor and A2A-ADK conversion
This change introduces new implementation files for the A2aAgentExecutor and event converters. The existing A2aAgentExecutor now acts as a wrapper, allowing a switch between the legacy and new implementations. The new implementation includes support for execution interceptors and a dedicated executor context. Main Changes= `a2a_agent_executor_impl.py` = the new implementation of the AgentExecutor differs from the legacy one (`a2a_agent_executor.py`) for the removal of the TaskResultAggregator and the explicit `InvocationContext` creation. Instead, it uses `ExecutorContext` and delegates event conversion to the new logic that supports streaming. It maintains an `agents_artifact` state map to handle partial updates and emits TaskArtifactUpdateEvents for content. The `long_running_functions.py` is used to keep track of the LongRunning FunctionCalls and respective FunctionResponse, to emit them at the end of the generation loop in a `TaskStateUpdateEvent(input-required/auth-required)`. `from_adk_event.py` = this file replaces the conversion functions in the `event_converter.py` used to convert the adk events into a2a events, estrapolating them in a dedicated file. The main changes in the methods are the introduction of TaskArtifactUpdateEvent to handle content parts, allowing for true artifact streaming and chunking. It utilizes an `agents_artifacts` dictionary to track artifact IDs across partial events to correctly handle append operations. PiperOrigin-RevId: 878399140
This commit is contained in:
committed by
Copybara-Service
parent
2b8ccd4a00
commit
87ffc55640
@@ -0,0 +1,288 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import uuid
|
||||
|
||||
from a2a.server.events import Event as A2AEvent
|
||||
from a2a.types import Artifact
|
||||
from a2a.types import DataPart
|
||||
from a2a.types import Message
|
||||
from a2a.types import Part as A2APart
|
||||
from a2a.types import Role
|
||||
from a2a.types import TaskArtifactUpdateEvent
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
|
||||
from ...events.event import Event
|
||||
from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
from ..experimental import a2a_experimental
|
||||
from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
|
||||
from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
|
||||
from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY
|
||||
from .part_converter import convert_genai_part_to_a2a_part
|
||||
from .part_converter import GenAIPartToA2APartConverter
|
||||
from .utils import _get_adk_metadata_key
|
||||
|
||||
# Constants
|
||||
DEFAULT_ERROR_MESSAGE = "An error occurred during processing"
|
||||
|
||||
# Logger
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
A2AUpdateEvent = Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
|
||||
|
||||
AdkEventToA2AEventsConverter = Callable[
|
||||
[
|
||||
Event,
|
||||
Optional[Dict[str, str]],
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
GenAIPartToA2APartConverter,
|
||||
],
|
||||
List[A2AUpdateEvent],
|
||||
]
|
||||
"""A callable that converts an ADK Event into a list of A2A events.
|
||||
|
||||
This interface allows for custom logic to map ADK's event structure to the
|
||||
event structure expected by the A2A server.
|
||||
|
||||
Args:
|
||||
event: The source ADK Event to convert.
|
||||
agents_artifacts: State map for tracking active artifact IDs across chunks.
|
||||
task_id: The ID of the A2A task being processed.
|
||||
context_id: The context ID from the A2A request.
|
||||
part_converter: A function to convert GenAI content parts to A2A
|
||||
parts.
|
||||
|
||||
Returns:
|
||||
A list of A2A events.
|
||||
"""
|
||||
|
||||
|
||||
def _convert_adk_parts_to_a2a_parts(
|
||||
event: Event,
|
||||
part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
|
||||
) -> Optional[List[A2APart]]:
|
||||
"""Converts an ADK event to an A2A parts list.
|
||||
|
||||
Args:
|
||||
event: The ADK event to convert.
|
||||
part_converter: The function to convert GenAI part to A2A part.
|
||||
|
||||
Returns:
|
||||
A list of A2A parts representing the converted ADK event.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are invalid.
|
||||
"""
|
||||
if not event:
|
||||
raise ValueError("Event cannot be None")
|
||||
|
||||
if not event.content or not event.content.parts:
|
||||
return []
|
||||
|
||||
try:
|
||||
output_parts = []
|
||||
for part in event.content.parts:
|
||||
a2a_parts = part_converter(part)
|
||||
if not isinstance(a2a_parts, list):
|
||||
a2a_parts = [a2a_parts] if a2a_parts else []
|
||||
for a2a_part in a2a_parts:
|
||||
output_parts.append(a2a_part)
|
||||
|
||||
return output_parts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to convert event to status message: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def create_error_status_event(
|
||||
event: Event,
|
||||
task_id: Optional[str] = None,
|
||||
context_id: Optional[str] = None,
|
||||
) -> TaskStatusUpdateEvent:
|
||||
"""Creates a TaskStatusUpdateEvent for error scenarios.
|
||||
|
||||
Args:
|
||||
event: The ADK event containing error information.
|
||||
task_id: Optional task ID to use for generated events.
|
||||
context_id: Optional Context ID to use for generated events.
|
||||
|
||||
Returns:
|
||||
A TaskStatusUpdateEvent with FAILED state.
|
||||
"""
|
||||
error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE
|
||||
|
||||
error_event = TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.failed,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.agent,
|
||||
parts=[A2APart(root=TextPart(text=error_message))],
|
||||
),
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
return _add_event_metadata(event, [error_event])[0]
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
def convert_event_to_a2a_events(
|
||||
event: Event,
|
||||
agents_artifacts: Dict[str, str],
|
||||
task_id: Optional[str] = None,
|
||||
context_id: Optional[str] = None,
|
||||
part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
|
||||
) -> List[A2AUpdateEvent]:
|
||||
"""Converts a GenAI event to a list of A2A StatusUpdate and ArtifactUpdate events.
|
||||
|
||||
Args:
|
||||
event: The ADK event to convert.
|
||||
agents_artifacts: State map for tracking active artifact IDs across chunks.
|
||||
task_id: Optional task ID to use for generated events.
|
||||
context_id: Optional Context ID to use for generated events.
|
||||
part_converter: The function to convert GenAI part to A2A part.
|
||||
|
||||
Returns:
|
||||
A list of A2A update events representing the converted ADK event.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are invalid.
|
||||
"""
|
||||
if not event:
|
||||
raise ValueError("Event cannot be None")
|
||||
if agents_artifacts is None:
|
||||
raise ValueError("Agents artifacts cannot be None")
|
||||
|
||||
a2a_events = []
|
||||
try:
|
||||
a2a_parts = _convert_adk_parts_to_a2a_parts(
|
||||
event, part_converter=part_converter
|
||||
)
|
||||
# Handle artifact updates for normal parts
|
||||
if a2a_parts:
|
||||
agent_name = event.author
|
||||
partial = event.partial or False
|
||||
|
||||
artifact_id = agents_artifacts.get(agent_name)
|
||||
if artifact_id:
|
||||
append = partial
|
||||
if not partial:
|
||||
del agents_artifacts[agent_name]
|
||||
else:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
# TODO: Clarify if new artifact id must have append=False
|
||||
append = False
|
||||
if partial:
|
||||
agents_artifacts[agent_name] = artifact_id
|
||||
|
||||
a2a_events.append(
|
||||
TaskArtifactUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
last_chunk=not partial,
|
||||
append=append,
|
||||
artifact=Artifact(
|
||||
artifact_id=artifact_id,
|
||||
parts=a2a_parts,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
a2a_events = _add_event_metadata(event, a2a_events)
|
||||
return a2a_events
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to convert event to A2A events: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Optional[Any]:
|
||||
"""Serializes a value and returns it if it contains meaningful content.
|
||||
|
||||
Returns None if the value is empty or missing.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Handle Pydantic models
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
dumped = value.model_dump(
|
||||
exclude_none=True,
|
||||
exclude_unset=True,
|
||||
exclude_defaults=True,
|
||||
by_alias=True,
|
||||
)
|
||||
return dumped if dumped else None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to serialize Pydantic model, falling back: %s", e)
|
||||
return str(value)
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# TODO: Clarify if this metadata needs to be translated back into the ADK event
|
||||
def _add_event_metadata(
|
||||
event: Event, a2a_events: List[A2AEvent]
|
||||
) -> List[A2AEvent]:
|
||||
"""Gets the context metadata for the event and applies it to A2A events."""
|
||||
if not event:
|
||||
raise ValueError("Event cannot be None")
|
||||
|
||||
metadata_values = {
|
||||
"invocation_id": event.invocation_id,
|
||||
"author": event.author,
|
||||
"event_id": event.id,
|
||||
"branch": event.branch,
|
||||
"citation_metadata": event.citation_metadata,
|
||||
"grounding_metadata": event.grounding_metadata,
|
||||
"custom_metadata": event.custom_metadata,
|
||||
"usage_metadata": event.usage_metadata,
|
||||
"error_code": event.error_code,
|
||||
"actions": event.actions,
|
||||
}
|
||||
|
||||
metadata = {}
|
||||
for field_name, field_value in metadata_values.items():
|
||||
value = _serialize_value(field_value)
|
||||
if value is not None:
|
||||
metadata[_get_adk_metadata_key(field_name)] = value
|
||||
|
||||
for a2a_event in a2a_events:
|
||||
if isinstance(a2a_event, TaskStatusUpdateEvent):
|
||||
a2a_event.status.message.metadata = metadata.copy()
|
||||
elif isinstance(a2a_event, TaskArtifactUpdateEvent):
|
||||
a2a_event.artifact.metadata = metadata.copy()
|
||||
|
||||
return a2a_events
|
||||
@@ -0,0 +1,215 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
from typing import Set
|
||||
import uuid
|
||||
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
from a2a.types import DataPart
|
||||
from a2a.types import Message
|
||||
from a2a.types import Part as A2APart
|
||||
from a2a.types import Role
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from google.genai import types as genai_types
|
||||
|
||||
from ...events.event import Event
|
||||
from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
|
||||
from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
|
||||
from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
|
||||
from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY
|
||||
from .part_converter import A2APartToGenAIPartConverter
|
||||
from .part_converter import convert_a2a_part_to_genai_part
|
||||
from .utils import _get_adk_metadata_key
|
||||
|
||||
|
||||
class LongRunningFunctions:
|
||||
"""Keeps track of long running function calls and related responses."""
|
||||
|
||||
def __init__(
|
||||
self, part_converter: A2APartToGenAIPartConverter | None = None
|
||||
) -> None:
|
||||
self._parts: List[genai_types.Part] = []
|
||||
self._long_running_tool_ids: Set[str] = set()
|
||||
self._part_converter = part_converter or convert_a2a_part_to_genai_part
|
||||
self._task_state: TaskState = TaskState.input_required
|
||||
|
||||
def has_long_running_function_calls(self) -> bool:
|
||||
"""Returns True if there are long running function calls."""
|
||||
return bool(self._long_running_tool_ids)
|
||||
|
||||
def process_event(self, event: Event) -> Event:
|
||||
"""Processes parts to extract long running calls and responses.
|
||||
|
||||
Returns a copy of the input event with processed parts removed from
|
||||
event.content.parts.
|
||||
|
||||
Args:
|
||||
event: The ADK event containing long running tool IDs and content parts.
|
||||
"""
|
||||
event = event.model_copy(deep=True)
|
||||
if not event.content or not event.content.parts:
|
||||
return event
|
||||
|
||||
kept_parts = []
|
||||
for part in event.content.parts:
|
||||
should_remove = False
|
||||
if part.function_call:
|
||||
if part.function_call.id in event.long_running_tool_ids:
|
||||
if not event.partial:
|
||||
self._parts.append(part)
|
||||
self._long_running_tool_ids.add(part.function_call.id)
|
||||
should_remove = True
|
||||
|
||||
elif part.function_response:
|
||||
if part.function_response.id in self._long_running_tool_ids:
|
||||
if not event.partial:
|
||||
self._parts.append(part)
|
||||
should_remove = True
|
||||
|
||||
if not should_remove:
|
||||
kept_parts.append(part)
|
||||
|
||||
event.content.parts = kept_parts
|
||||
return event
|
||||
|
||||
def create_long_running_function_call_event(
|
||||
self,
|
||||
task_id: str,
|
||||
context_id: str,
|
||||
) -> TaskStatusUpdateEvent:
|
||||
"""Creates a task status update event for the long running function calls."""
|
||||
if not self._long_running_tool_ids:
|
||||
return None
|
||||
|
||||
a2a_parts = self._return_long_running_parts()
|
||||
if not a2a_parts:
|
||||
return None
|
||||
|
||||
return TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=self._task_state,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.agent,
|
||||
parts=a2a_parts,
|
||||
),
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
|
||||
def _return_long_running_parts(self) -> List[A2APart]:
|
||||
"""Converts long-running parts to A2A parts."""
|
||||
if not self._long_running_tool_ids:
|
||||
return []
|
||||
|
||||
output_parts = []
|
||||
for part in self._parts:
|
||||
a2a_parts = self._part_converter(part)
|
||||
if not isinstance(a2a_parts, list):
|
||||
a2a_parts = [a2a_parts] if a2a_parts else []
|
||||
for a2a_part in a2a_parts:
|
||||
self._mark_long_running_function_call(a2a_part)
|
||||
output_parts.append(a2a_part)
|
||||
|
||||
return output_parts
|
||||
|
||||
def _mark_long_running_function_call(self, a2a_part: A2APart) -> None:
|
||||
"""Processes long-running tool metadata for an A2A part.
|
||||
|
||||
Args:
|
||||
a2a_part: The A2A part to potentially mark as long-running.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(a2a_part.root, DataPart)
|
||||
and a2a_part.root.metadata
|
||||
and a2a_part.root.metadata.get(
|
||||
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
|
||||
)
|
||||
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
|
||||
):
|
||||
a2a_part.root.metadata[
|
||||
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
|
||||
] = True
|
||||
# If the function is a request for EUC, set the task state to
|
||||
# auth_required. Otherwise, set it to input_required. Save the state of
|
||||
# the last function call, as it will be the state of the task.
|
||||
if a2a_part.root.metadata.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME:
|
||||
self._task_state = TaskState.auth_required
|
||||
else:
|
||||
self._task_state = TaskState.input_required
|
||||
|
||||
|
||||
def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None:
|
||||
"""Processes user input events, validating function responses."""
|
||||
|
||||
if (
|
||||
not context.current_task
|
||||
or not context.current_task.status
|
||||
or (
|
||||
context.current_task.status.state != TaskState.input_required
|
||||
and context.current_task.status.state != TaskState.auth_required
|
||||
)
|
||||
):
|
||||
return None
|
||||
|
||||
# If the task is in input_required or auth_required state, we expect the user
|
||||
# to provide a response for the function call. Check if the user input
|
||||
# contains a function response.
|
||||
for a2a_part in context.message.parts:
|
||||
if (
|
||||
isinstance(a2a_part.root, DataPart)
|
||||
and a2a_part.root.metadata
|
||||
and a2a_part.root.metadata.get(
|
||||
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
|
||||
)
|
||||
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
|
||||
):
|
||||
return None
|
||||
|
||||
return TaskStatusUpdateEvent(
|
||||
task_id=context.task_id,
|
||||
context_id=context.context_id,
|
||||
status=TaskStatus(
|
||||
state=context.current_task.status.state,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.agent,
|
||||
parts=[
|
||||
A2APart(
|
||||
root=TextPart(
|
||||
text=(
|
||||
"It was not provided a function response for the"
|
||||
" function call."
|
||||
)
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
final=True,
|
||||
)
|
||||
@@ -35,22 +35,14 @@ from a2a.types import TaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from google.adk.runners import Runner
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from ...utils.context_utils import Aclosing
|
||||
from ..converters.event_converter import AdkEventToA2AEventsConverter
|
||||
from ..converters.event_converter import convert_event_to_a2a_events
|
||||
from ..converters.part_converter import A2APartToGenAIPartConverter
|
||||
from ..converters.part_converter import convert_a2a_part_to_genai_part
|
||||
from ..converters.part_converter import convert_genai_part_to_a2a_part
|
||||
from ..converters.part_converter import GenAIPartToA2APartConverter
|
||||
from ..converters.request_converter import A2ARequestToAgentRunRequestConverter
|
||||
from ..converters.request_converter import AgentRunRequest
|
||||
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
|
||||
from ..converters.utils import _get_adk_metadata_key
|
||||
from ..experimental import a2a_experimental
|
||||
from .config import ExecuteInterceptor
|
||||
from .a2a_agent_executor_impl import _A2aAgentExecutor as ExecutorImpl
|
||||
from .config import A2aAgentExecutorConfig
|
||||
from .executor_context import ExecutorContext
|
||||
from .task_result_aggregator import TaskResultAggregator
|
||||
from .utils import execute_after_agent_interceptors
|
||||
@@ -60,29 +52,16 @@ from .utils import execute_before_agent_interceptors
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
class A2aAgentExecutorConfig(BaseModel):
|
||||
"""Configuration for the A2aAgentExecutor."""
|
||||
|
||||
a2a_part_converter: A2APartToGenAIPartConverter = (
|
||||
convert_a2a_part_to_genai_part
|
||||
)
|
||||
gen_ai_part_converter: GenAIPartToA2APartConverter = (
|
||||
convert_genai_part_to_a2a_part
|
||||
)
|
||||
request_converter: A2ARequestToAgentRunRequestConverter = (
|
||||
convert_a2a_request_to_agent_run_request
|
||||
)
|
||||
event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events
|
||||
|
||||
execute_interceptors: Optional[list[ExecuteInterceptor]] = None
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
class A2aAgentExecutor(AgentExecutor):
|
||||
"""An AgentExecutor that runs an ADK Agent against an A2A request and
|
||||
|
||||
publishes updates to an event queue.
|
||||
|
||||
Args:
|
||||
runner: The runner to use for the agent.
|
||||
config: The config to use for the executor.
|
||||
use_legacy: Whether to use the legacy executor implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -90,10 +69,15 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
*,
|
||||
runner: Runner | Callable[..., Runner | Awaitable[Runner]],
|
||||
config: Optional[A2aAgentExecutorConfig] = None,
|
||||
use_legacy: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self._runner = runner
|
||||
self._config = config or A2aAgentExecutorConfig()
|
||||
if not use_legacy:
|
||||
self._executor_impl = ExecutorImpl(runner=runner, config=config)
|
||||
else:
|
||||
self._executor_impl = None
|
||||
self._runner = runner
|
||||
self._config = config or A2aAgentExecutorConfig()
|
||||
|
||||
async def _resolve_runner(self) -> Runner:
|
||||
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
|
||||
@@ -122,6 +106,10 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
@override
|
||||
async def cancel(self, context: RequestContext, event_queue: EventQueue):
|
||||
"""Cancel the execution."""
|
||||
if self._executor_impl:
|
||||
await self._executor_impl.cancel(context, event_queue)
|
||||
return
|
||||
|
||||
# TODO: Implement proper cancellation logic if needed
|
||||
raise NotImplementedError('Cancellation is not supported')
|
||||
|
||||
@@ -132,6 +120,7 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
event_queue: EventQueue,
|
||||
):
|
||||
"""Executes an A2A request and publishes updates to the event queue
|
||||
|
||||
specified. It runs as following:
|
||||
* Takes the input from the A2A request
|
||||
* Convert the input to ADK input content, and runs the ADK agent
|
||||
@@ -139,6 +128,10 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
* Converts the ADK output events into A2A task updates
|
||||
* Publishes the updates back to A2A server via event queue
|
||||
"""
|
||||
if self._executor_impl:
|
||||
await self._executor_impl.execute(context, event_queue)
|
||||
return
|
||||
|
||||
if not context.message:
|
||||
raise ValueError('A2A request must have a message')
|
||||
|
||||
@@ -213,7 +206,7 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
run_config=run_request.run_config,
|
||||
)
|
||||
|
||||
self._executor_context = ExecutorContext(
|
||||
executor_context = ExecutorContext(
|
||||
app_name=runner.app_name,
|
||||
user_id=run_request.user_id,
|
||||
session_id=run_request.session_id,
|
||||
@@ -250,7 +243,7 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
):
|
||||
a2a_event = await execute_after_event_interceptors(
|
||||
a2a_event,
|
||||
self._executor_context,
|
||||
executor_context,
|
||||
adk_event,
|
||||
self._config.execute_interceptors,
|
||||
)
|
||||
@@ -302,7 +295,7 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
)
|
||||
|
||||
final_event = await execute_after_agent_interceptors(
|
||||
self._executor_context,
|
||||
executor_context,
|
||||
final_event,
|
||||
self._config.execute_interceptors,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,310 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from a2a.server.agent_execution import AgentExecutor
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
from a2a.server.events.event_queue import EventQueue
|
||||
from a2a.types import Artifact
|
||||
from a2a.types import Message
|
||||
from a2a.types import Part
|
||||
from a2a.types import Role
|
||||
from a2a.types import Task
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatus
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from typing_extensions import override
|
||||
|
||||
from ...runners import Runner
|
||||
from ...utils.context_utils import Aclosing
|
||||
from ..converters.from_adk_event import create_error_status_event
|
||||
from ..converters.long_running_functions import handle_user_input
|
||||
from ..converters.long_running_functions import LongRunningFunctions
|
||||
from ..converters.request_converter import AgentRunRequest
|
||||
from ..converters.utils import _get_adk_metadata_key
|
||||
from ..experimental import a2a_experimental
|
||||
from .config import A2aAgentExecutorConfig
|
||||
from .executor_context import ExecutorContext
|
||||
from .utils import execute_after_agent_interceptors
|
||||
from .utils import execute_after_event_interceptors
|
||||
from .utils import execute_before_agent_interceptors
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
class _A2aAgentExecutor(AgentExecutor):
|
||||
"""An AgentExecutor that runs an ADK Agent against an A2A request and
|
||||
|
||||
publishes updates to an event queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
runner: Runner | Callable[..., Runner | Awaitable[Runner]],
|
||||
config: Optional[A2aAgentExecutorConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._runner = runner
|
||||
self._config = config or A2aAgentExecutorConfig()
|
||||
|
||||
@override
|
||||
async def cancel(self, context: RequestContext, event_queue: EventQueue):
|
||||
"""Cancel the execution."""
|
||||
# TODO: Implement proper cancellation logic if needed
|
||||
raise NotImplementedError('Cancellation is not supported')
|
||||
|
||||
@override
|
||||
async def execute(
|
||||
self,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
):
|
||||
"""Executes an A2A request and publishes updates to the event queue
|
||||
|
||||
specified. It runs as following:
|
||||
* Takes the input from the A2A request
|
||||
* Convert the input to ADK input content, and runs the ADK agent
|
||||
* Collects output events of the underlying ADK Agent
|
||||
* Converts the ADK output events into A2A task updates
|
||||
* Publishes the updates back to A2A server via event queue
|
||||
"""
|
||||
if not context.message:
|
||||
raise ValueError('A2A request must have a message')
|
||||
|
||||
context = await execute_before_agent_interceptors(
|
||||
context, self._config.execute_interceptors
|
||||
)
|
||||
|
||||
runner = await self._resolve_runner()
|
||||
try:
|
||||
run_request = self._config.request_converter(
|
||||
context,
|
||||
self._config.a2a_part_converter,
|
||||
)
|
||||
await self._resolve_session(run_request, runner)
|
||||
|
||||
executor_context = ExecutorContext(
|
||||
app_name=runner.app_name,
|
||||
user_id=run_request.user_id,
|
||||
session_id=run_request.session_id,
|
||||
runner=runner,
|
||||
)
|
||||
|
||||
# for new task, create a task submitted event
|
||||
if not context.current_task:
|
||||
await event_queue.enqueue_event(
|
||||
Task(
|
||||
id=context.task_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.submitted,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
context_id=context.context_id,
|
||||
history=[context.message],
|
||||
metadata=self._get_invocation_metadata(executor_context),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Check if the user input is responding to the agent's
|
||||
# request for input.
|
||||
missing_user_input_event = handle_user_input(context)
|
||||
if missing_user_input_event:
|
||||
missing_user_input_event.metadata = self._get_invocation_metadata(
|
||||
executor_context
|
||||
)
|
||||
await event_queue.enqueue_event(missing_user_input_event)
|
||||
return
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=context.task_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.working,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
context_id=context.context_id,
|
||||
final=False,
|
||||
metadata=self._get_invocation_metadata(executor_context),
|
||||
)
|
||||
)
|
||||
|
||||
# Handle the request and publish updates to the event queue
|
||||
await self._handle_request(
|
||||
context,
|
||||
executor_context,
|
||||
event_queue,
|
||||
runner,
|
||||
run_request,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('Error handling A2A request: %s', e, exc_info=True)
|
||||
# Publish failure event
|
||||
try:
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=context.task_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.failed,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.agent,
|
||||
parts=[TextPart(text=str(e))],
|
||||
),
|
||||
),
|
||||
context_id=context.context_id,
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
except Exception as enqueue_error:
|
||||
logger.error(
|
||||
'Failed to publish failure event: %s', enqueue_error, exc_info=True
|
||||
)
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
context: RequestContext,
|
||||
executor_context: ExecutorContext,
|
||||
event_queue: EventQueue,
|
||||
runner: Runner,
|
||||
run_request: AgentRunRequest,
|
||||
):
|
||||
agents_artifact: dict[str, str] = {}
|
||||
error_event = None
|
||||
long_running_functions = LongRunningFunctions(
|
||||
self._config.gen_ai_part_converter
|
||||
)
|
||||
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
|
||||
async for adk_event in agen:
|
||||
# Handle error scenarios
|
||||
if adk_event and (adk_event.error_code or adk_event.error_message):
|
||||
error_event = create_error_status_event(
|
||||
adk_event,
|
||||
context.task_id,
|
||||
context.context_id,
|
||||
)
|
||||
|
||||
# Handle long running function calls
|
||||
adk_event = long_running_functions.process_event(adk_event)
|
||||
|
||||
for a2a_event in self._config.adk_event_converter(
|
||||
adk_event,
|
||||
agents_artifact,
|
||||
context.task_id,
|
||||
context.context_id,
|
||||
self._config.gen_ai_part_converter,
|
||||
):
|
||||
a2a_event.metadata = self._get_invocation_metadata(executor_context)
|
||||
a2a_event = await execute_after_event_interceptors(
|
||||
a2a_event,
|
||||
executor_context,
|
||||
adk_event,
|
||||
self._config.execute_interceptors,
|
||||
)
|
||||
if not a2a_event:
|
||||
continue
|
||||
await event_queue.enqueue_event(a2a_event)
|
||||
|
||||
if error_event:
|
||||
final_event = error_event
|
||||
elif long_running_functions.has_long_running_function_calls():
|
||||
final_event = (
|
||||
long_running_functions.create_long_running_function_call_event(
|
||||
context.task_id, context.context_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
final_event = TaskStatusUpdateEvent(
|
||||
task_id=context.task_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.completed,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
context_id=context.context_id,
|
||||
final=True,
|
||||
)
|
||||
|
||||
final_event.metadata = self._get_invocation_metadata(executor_context)
|
||||
final_event = await execute_after_agent_interceptors(
|
||||
executor_context, final_event, self._config.execute_interceptors
|
||||
)
|
||||
await event_queue.enqueue_event(final_event)
|
||||
|
||||
async def _resolve_runner(self) -> Runner:
|
||||
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
|
||||
if isinstance(self._runner, Runner):
|
||||
return self._runner
|
||||
if callable(self._runner):
|
||||
result = self._runner()
|
||||
|
||||
if inspect.iscoroutine(result):
|
||||
resolved_runner = await result
|
||||
else:
|
||||
resolved_runner = result
|
||||
|
||||
self._runner = resolved_runner
|
||||
return resolved_runner
|
||||
|
||||
raise TypeError(
|
||||
'Runner must be a Runner instance or a callable that returns a'
|
||||
f' Runner, got {type(self._runner)}'
|
||||
)
|
||||
|
||||
async def _resolve_session(
|
||||
self,
|
||||
run_request: AgentRunRequest,
|
||||
runner: Runner,
|
||||
):
|
||||
session_id = run_request.session_id
|
||||
# create a new session if not exists
|
||||
user_id = run_request.user_id
|
||||
session = await runner.session_service.get_session(
|
||||
app_name=runner.app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
if session is None:
|
||||
session = await runner.session_service.create_session(
|
||||
app_name=runner.app_name,
|
||||
user_id=user_id,
|
||||
state={},
|
||||
session_id=session_id,
|
||||
)
|
||||
# Update run_request with the new session_id
|
||||
run_request.session_id = session.id
|
||||
|
||||
def _get_invocation_metadata(
|
||||
self, executor_context: ExecutorContext
|
||||
) -> dict[str, str]:
|
||||
return {
|
||||
_get_adk_metadata_key('app_name'): executor_context.app_name,
|
||||
_get_adk_metadata_key('user_id'): executor_context.user_id,
|
||||
_get_adk_metadata_key('session_id'): executor_context.session_id,
|
||||
# TODO: Remove this metadata once the new agent executor
|
||||
# is fully adopted.
|
||||
_get_adk_metadata_key('agent_executor_v2'): True,
|
||||
}
|
||||
@@ -23,9 +23,21 @@ from typing import Union
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
from a2a.server.events import Event as A2AEvent
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...events.event import Event
|
||||
from ..converters.event_converter import AdkEventToA2AEventsConverter
|
||||
from ..converters.event_converter import convert_event_to_a2a_events as legacy_convert_event_to_a2a_events
|
||||
from ..converters.from_adk_event import AdkEventToA2AEventsConverter as AdkEventToA2AEventsConverterImpl
|
||||
from ..converters.from_adk_event import convert_event_to_a2a_events as convert_event_to_a2a_events_impl
|
||||
from ..converters.part_converter import A2APartToGenAIPartConverter
|
||||
from ..converters.part_converter import convert_a2a_part_to_genai_part
|
||||
from ..converters.part_converter import convert_genai_part_to_a2a_part
|
||||
from ..converters.part_converter import GenAIPartToA2APartConverter
|
||||
from ..converters.request_converter import A2ARequestToAgentRunRequestConverter
|
||||
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
|
||||
from ..converters.utils import _get_adk_metadata_key
|
||||
from ..experimental import a2a_experimental
|
||||
from .executor_context import ExecutorContext
|
||||
|
||||
|
||||
@@ -67,3 +79,29 @@ class ExecuteInterceptor:
|
||||
completed or failed) before it is enqueued. Must return a valid
|
||||
`TaskStatusUpdateEvent`.
|
||||
"""
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
class A2aAgentExecutorConfig(BaseModel):
|
||||
"""Configuration for the A2aAgentExecutor."""
|
||||
|
||||
a2a_part_converter: A2APartToGenAIPartConverter = (
|
||||
convert_a2a_part_to_genai_part
|
||||
)
|
||||
gen_ai_part_converter: GenAIPartToA2APartConverter = (
|
||||
convert_genai_part_to_a2a_part
|
||||
)
|
||||
request_converter: A2ARequestToAgentRunRequestConverter = (
|
||||
convert_a2a_request_to_agent_run_request
|
||||
)
|
||||
event_converter: AdkEventToA2AEventsConverter = (
|
||||
legacy_convert_event_to_a2a_events
|
||||
)
|
||||
"""Set up the default event converter implementation to be used by the legacy agent executor implementation."""
|
||||
|
||||
adk_event_converter: AdkEventToA2AEventsConverterImpl = (
|
||||
convert_event_to_a2a_events_impl
|
||||
)
|
||||
"""Set up the imlp event converter implementation to be used by the new agent executor implementation."""
|
||||
|
||||
execute_interceptors: Optional[list[ExecuteInterceptor]] = None
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
from a2a.types import Part as A2APart
|
||||
from a2a.types import TaskArtifactUpdateEvent
|
||||
from a2a.types import TaskState
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
from a2a.types import TextPart
|
||||
from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events
|
||||
from google.adk.events.event import Event
|
||||
from google.genai import types as genai_types
|
||||
import pytest
|
||||
|
||||
|
||||
class TestFromAdk:
|
||||
"""Test suite for from_adk functions."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_event = Mock(spec=Event)
|
||||
self.mock_event.id = "test-event-id"
|
||||
self.mock_event.invocation_id = "test-invocation-id"
|
||||
self.mock_event.author = "test-author"
|
||||
self.mock_event.branch = None
|
||||
self.mock_event.content = None
|
||||
self.mock_event.error_code = None
|
||||
self.mock_event.error_message = None
|
||||
self.mock_event.grounding_metadata = None
|
||||
self.mock_event.citation_metadata = None
|
||||
self.mock_event.custom_metadata = None
|
||||
self.mock_event.usage_metadata = None
|
||||
self.mock_event.actions = None
|
||||
self.mock_event.partial = True
|
||||
self.mock_event.long_running_tool_ids = None
|
||||
|
||||
def test_convert_event_to_a2a_events_artifact_update(self):
|
||||
"""Test conversion of event to TaskArtifactUpdateEvent."""
|
||||
# Setup event with content
|
||||
self.mock_event.content = genai_types.Content(
|
||||
parts=[genai_types.Part(text="hello")], role="model"
|
||||
)
|
||||
self.mock_event.author = "agent-1"
|
||||
|
||||
agents_artifacts = {}
|
||||
|
||||
# Mock part converter to return a standard text part
|
||||
mock_a2a_part = A2APart(root=TextPart(text="hello"))
|
||||
mock_a2a_part.root.metadata = {}
|
||||
mock_convert_part = Mock(return_value=[mock_a2a_part])
|
||||
|
||||
result = convert_event_to_a2a_events(
|
||||
self.mock_event,
|
||||
agents_artifacts,
|
||||
task_id="task-123",
|
||||
context_id="context-456",
|
||||
part_converter=mock_convert_part,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], TaskArtifactUpdateEvent)
|
||||
assert result[0].task_id == "task-123"
|
||||
assert result[0].context_id == "context-456"
|
||||
assert result[0].artifact.parts == [mock_a2a_part]
|
||||
assert "agent-1" in agents_artifacts # Artifact ID should be stored
|
||||
|
||||
def test_convert_event_to_a2a_events_error(self):
|
||||
"""Test conversion of event with error to TaskStatusUpdateEvent."""
|
||||
self.mock_event.error_code = "ERR001"
|
||||
self.mock_event.error_message = "Something went wrong"
|
||||
|
||||
agents_artifacts = {}
|
||||
|
||||
result = convert_event_to_a2a_events(
|
||||
self.mock_event,
|
||||
agents_artifacts,
|
||||
task_id="task-123",
|
||||
context_id="context-456",
|
||||
)
|
||||
|
||||
# Should not return any artifact events
|
||||
assert len(result) == 0
|
||||
|
||||
def test_convert_event_to_a2a_events_none_event(self):
|
||||
"""Test convert_event_to_a2a_events with None event."""
|
||||
with pytest.raises(ValueError, match="Event cannot be None"):
|
||||
convert_event_to_a2a_events(None, {})
|
||||
|
||||
def test_convert_event_to_a2a_events_none_artifacts(self):
|
||||
"""Test convert_event_to_a2a_events with None agents_artifacts."""
|
||||
with pytest.raises(ValueError, match="Agents artifacts cannot be None"):
|
||||
convert_event_to_a2a_events(self.mock_event, None)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user