diff --git a/src/google/adk/a2a/converters/from_adk_event.py b/src/google/adk/a2a/converters/from_adk_event.py new file mode 100644 index 00000000..05bf16d1 --- /dev/null +++ b/src/google/adk/a2a/converters/from_adk_event.py @@ -0,0 +1,288 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime +from datetime import timezone +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import uuid + +from a2a.server.events import Event as A2AEvent +from a2a.types import Artifact +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from ..experimental import a2a_experimental +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_genai_part_to_a2a_part +from .part_converter import GenAIPartToA2APartConverter +from .utils import _get_adk_metadata_key + +# Constants +DEFAULT_ERROR_MESSAGE = "An error occurred during processing" + +# Logger +logger = logging.getLogger("google_adk." + __name__) + +A2AUpdateEvent = Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent] + +AdkEventToA2AEventsConverter = Callable[ + [ + Event, + Optional[Dict[str, str]], + Optional[str], + Optional[str], + GenAIPartToA2APartConverter, + ], + List[A2AUpdateEvent], +] +"""A callable that converts an ADK Event into a list of A2A events. + +This interface allows for custom logic to map ADK's event structure to the +event structure expected by the A2A server. + +Args: + event: The source ADK Event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. + task_id: The ID of the A2A task being processed. + context_id: The context ID from the A2A request. + part_converter: A function to convert GenAI content parts to A2A + parts. + +Returns: + A list of A2A events. +""" + + +def _convert_adk_parts_to_a2a_parts( + event: Event, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> Optional[List[A2APart]]: + """Converts an ADK event to an A2A parts list. + + Args: + event: The ADK event to convert. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A parts representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + + if not event.content or not event.content.parts: + return [] + + try: + output_parts = [] + for part in event.content.parts: + a2a_parts = part_converter(part) + if not isinstance(a2a_parts, list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + output_parts.append(a2a_part) + + return output_parts + + except Exception as e: + logger.error("Failed to convert event to status message: %s", e) + raise + + +def create_error_status_event( + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for error scenarios. + + Args: + event: The ADK event containing error information. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + Returns: + A TaskStatusUpdateEvent with FAILED state. + """ + error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + + error_event = TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.failed, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[A2APart(root=TextPart(text=error_message))], + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + return _add_event_metadata(event, [error_event])[0] + + +@a2a_experimental +def convert_event_to_a2a_events( + event: Event, + agents_artifacts: Dict[str, str], + task_id: Optional[str] = None, + context_id: Optional[str] = None, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> List[A2AUpdateEvent]: + """Converts a GenAI event to a list of A2A StatusUpdate and ArtifactUpdate events. + + Args: + event: The ADK event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A update events representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if agents_artifacts is None: + raise ValueError("Agents artifacts cannot be None") + + a2a_events = [] + try: + a2a_parts = _convert_adk_parts_to_a2a_parts( + event, part_converter=part_converter + ) + # Handle artifact updates for normal parts + if a2a_parts: + agent_name = event.author + partial = event.partial or False + + artifact_id = agents_artifacts.get(agent_name) + if artifact_id: + append = partial + if not partial: + del agents_artifacts[agent_name] + else: + artifact_id = str(uuid.uuid4()) + # TODO: Clarify if new artifact id must have append=False + append = False + if partial: + agents_artifacts[agent_name] = artifact_id + + a2a_events.append( + TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + last_chunk=not partial, + append=append, + artifact=Artifact( + artifact_id=artifact_id, + parts=a2a_parts, + ), + ) + ) + + a2a_events = _add_event_metadata(event, a2a_events) + return a2a_events + + except Exception as e: + logger.error("Failed to convert event to A2A events: %s", e) + raise + + +def _serialize_value(value: Any) -> Optional[Any]: + """Serializes a value and returns it if it contains meaningful content. + + Returns None if the value is empty or missing. + """ + if value is None: + return None + + # Handle Pydantic models + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump( + exclude_none=True, + exclude_unset=True, + exclude_defaults=True, + by_alias=True, + ) + return dumped if dumped else None + except Exception as e: + logger.warning("Failed to serialize Pydantic model, falling back: %s", e) + return str(value) + + return str(value) + + +# TODO: Clarify if this metadata needs to be translated back into the ADK event +def _add_event_metadata( + event: Event, a2a_events: List[A2AEvent] +) -> List[A2AEvent]: + """Gets the context metadata for the event and applies it to A2A events.""" + if not event: + raise ValueError("Event cannot be None") + + metadata_values = { + "invocation_id": event.invocation_id, + "author": event.author, + "event_id": event.id, + "branch": event.branch, + "citation_metadata": event.citation_metadata, + "grounding_metadata": event.grounding_metadata, + "custom_metadata": event.custom_metadata, + "usage_metadata": event.usage_metadata, + "error_code": event.error_code, + "actions": event.actions, + } + + metadata = {} + for field_name, field_value in metadata_values.items(): + value = _serialize_value(field_value) + if value is not None: + metadata[_get_adk_metadata_key(field_name)] = value + + for a2a_event in a2a_events: + if isinstance(a2a_event, TaskStatusUpdateEvent): + a2a_event.status.message.metadata = metadata.copy() + elif isinstance(a2a_event, TaskArtifactUpdateEvent): + a2a_event.artifact.metadata = metadata.copy() + + return a2a_events diff --git a/src/google/adk/a2a/converters/long_running_functions.py b/src/google/adk/a2a/converters/long_running_functions.py new file mode 100644 index 00000000..0bbb46da --- /dev/null +++ b/src/google/adk/a2a/converters/long_running_functions.py @@ -0,0 +1,215 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +from typing import List +from typing import Set +import uuid + +from a2a.server.agent_execution.context import RequestContext +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.genai import types as genai_types + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import A2APartToGenAIPartConverter +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _get_adk_metadata_key + + +class LongRunningFunctions: + """Keeps track of long running function calls and related responses.""" + + def __init__( + self, part_converter: A2APartToGenAIPartConverter | None = None + ) -> None: + self._parts: List[genai_types.Part] = [] + self._long_running_tool_ids: Set[str] = set() + self._part_converter = part_converter or convert_a2a_part_to_genai_part + self._task_state: TaskState = TaskState.input_required + + def has_long_running_function_calls(self) -> bool: + """Returns True if there are long running function calls.""" + return bool(self._long_running_tool_ids) + + def process_event(self, event: Event) -> Event: + """Processes parts to extract long running calls and responses. + + Returns a copy of the input event with processed parts removed from + event.content.parts. + + Args: + event: The ADK event containing long running tool IDs and content parts. + """ + event = event.model_copy(deep=True) + if not event.content or not event.content.parts: + return event + + kept_parts = [] + for part in event.content.parts: + should_remove = False + if part.function_call: + if part.function_call.id in event.long_running_tool_ids: + if not event.partial: + self._parts.append(part) + self._long_running_tool_ids.add(part.function_call.id) + should_remove = True + + elif part.function_response: + if part.function_response.id in self._long_running_tool_ids: + if not event.partial: + self._parts.append(part) + should_remove = True + + if not should_remove: + kept_parts.append(part) + + event.content.parts = kept_parts + return event + + def create_long_running_function_call_event( + self, + task_id: str, + context_id: str, + ) -> TaskStatusUpdateEvent: + """Creates a task status update event for the long running function calls.""" + if not self._long_running_tool_ids: + return None + + a2a_parts = self._return_long_running_parts() + if not a2a_parts: + return None + + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=self._task_state, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=a2a_parts, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + + def _return_long_running_parts(self) -> List[A2APart]: + """Converts long-running parts to A2A parts.""" + if not self._long_running_tool_ids: + return [] + + output_parts = [] + for part in self._parts: + a2a_parts = self._part_converter(part) + if not isinstance(a2a_parts, list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + self._mark_long_running_function_call(a2a_part) + output_parts.append(a2a_part) + + return output_parts + + def _mark_long_running_function_call(self, a2a_part: A2APart) -> None: + """Processes long-running tool metadata for an A2A part. + + Args: + a2a_part: The A2A part to potentially mark as long-running. + """ + + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True + # If the function is a request for EUC, set the task state to + # auth_required. Otherwise, set it to input_required. Save the state of + # the last function call, as it will be the state of the task. + if a2a_part.root.metadata.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME: + self._task_state = TaskState.auth_required + else: + self._task_state = TaskState.input_required + + +def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: + """Processes user input events, validating function responses.""" + + if ( + not context.current_task + or not context.current_task.status + or ( + context.current_task.status.state != TaskState.input_required + and context.current_task.status.state != TaskState.auth_required + ) + ): + return None + + # If the task is in input_required or auth_required state, we expect the user + # to provide a response for the function call. Check if the user input + # contains a function response. + for a2a_part in context.message.parts: + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return None + + return TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=context.current_task.status.state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[ + A2APart( + root=TextPart( + text=( + "It was not provided a function response for the" + " function call." + ) + ) + ) + ], + ), + ), + final=True, + ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 956b1233..da28955a 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -35,22 +35,14 @@ from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart from google.adk.runners import Runner -from pydantic import BaseModel from typing_extensions import override from ...utils.context_utils import Aclosing -from ..converters.event_converter import AdkEventToA2AEventsConverter -from ..converters.event_converter import convert_event_to_a2a_events -from ..converters.part_converter import A2APartToGenAIPartConverter -from ..converters.part_converter import convert_a2a_part_to_genai_part -from ..converters.part_converter import convert_genai_part_to_a2a_part -from ..converters.part_converter import GenAIPartToA2APartConverter -from ..converters.request_converter import A2ARequestToAgentRunRequestConverter from ..converters.request_converter import AgentRunRequest -from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental -from .config import ExecuteInterceptor +from .a2a_agent_executor_impl import _A2aAgentExecutor as ExecutorImpl +from .config import A2aAgentExecutorConfig from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator from .utils import execute_after_agent_interceptors @@ -60,29 +52,16 @@ from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' + __name__) -@a2a_experimental -class A2aAgentExecutorConfig(BaseModel): - """Configuration for the A2aAgentExecutor.""" - - a2a_part_converter: A2APartToGenAIPartConverter = ( - convert_a2a_part_to_genai_part - ) - gen_ai_part_converter: GenAIPartToA2APartConverter = ( - convert_genai_part_to_a2a_part - ) - request_converter: A2ARequestToAgentRunRequestConverter = ( - convert_a2a_request_to_agent_run_request - ) - event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events - - execute_interceptors: Optional[list[ExecuteInterceptor]] = None - - @a2a_experimental class A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and publishes updates to an event queue. + + Args: + runner: The runner to use for the agent. + config: The config to use for the executor. + use_legacy: Whether to use the legacy executor implementation. """ def __init__( @@ -90,10 +69,15 @@ class A2aAgentExecutor(AgentExecutor): *, runner: Runner | Callable[..., Runner | Awaitable[Runner]], config: Optional[A2aAgentExecutorConfig] = None, + use_legacy: bool = True, ): super().__init__() - self._runner = runner - self._config = config or A2aAgentExecutorConfig() + if not use_legacy: + self._executor_impl = ExecutorImpl(runner=runner, config=config) + else: + self._executor_impl = None + self._runner = runner + self._config = config or A2aAgentExecutorConfig() async def _resolve_runner(self) -> Runner: """Resolve the runner, handling cases where it's a callable that returns a Runner.""" @@ -122,6 +106,10 @@ class A2aAgentExecutor(AgentExecutor): @override async def cancel(self, context: RequestContext, event_queue: EventQueue): """Cancel the execution.""" + if self._executor_impl: + await self._executor_impl.cancel(context, event_queue) + return + # TODO: Implement proper cancellation logic if needed raise NotImplementedError('Cancellation is not supported') @@ -132,6 +120,7 @@ class A2aAgentExecutor(AgentExecutor): event_queue: EventQueue, ): """Executes an A2A request and publishes updates to the event queue + specified. It runs as following: * Takes the input from the A2A request * Convert the input to ADK input content, and runs the ADK agent @@ -139,6 +128,10 @@ class A2aAgentExecutor(AgentExecutor): * Converts the ADK output events into A2A task updates * Publishes the updates back to A2A server via event queue """ + if self._executor_impl: + await self._executor_impl.execute(context, event_queue) + return + if not context.message: raise ValueError('A2A request must have a message') @@ -213,7 +206,7 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) - self._executor_context = ExecutorContext( + executor_context = ExecutorContext( app_name=runner.app_name, user_id=run_request.user_id, session_id=run_request.session_id, @@ -250,7 +243,7 @@ class A2aAgentExecutor(AgentExecutor): ): a2a_event = await execute_after_event_interceptors( a2a_event, - self._executor_context, + executor_context, adk_event, self._config.execute_interceptors, ) @@ -302,7 +295,7 @@ class A2aAgentExecutor(AgentExecutor): ) final_event = await execute_after_agent_interceptors( - self._executor_context, + executor_context, final_event, self._config.execute_interceptors, ) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py new file mode 100644 index 00000000..cec68f36 --- /dev/null +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -0,0 +1,310 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import inspect +import logging +from typing import Awaitable +from typing import Callable +from typing import Optional +import uuid + +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from typing_extensions import override + +from ...runners import Runner +from ...utils.context_utils import Aclosing +from ..converters.from_adk_event import create_error_status_event +from ..converters.long_running_functions import handle_user_input +from ..converters.long_running_functions import LongRunningFunctions +from ..converters.request_converter import AgentRunRequest +from ..converters.utils import _get_adk_metadata_key +from ..experimental import a2a_experimental +from .config import A2aAgentExecutorConfig +from .executor_context import ExecutorContext +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors + +logger = logging.getLogger('google_adk.' + __name__) + + +@a2a_experimental +class _A2aAgentExecutor(AgentExecutor): + """An AgentExecutor that runs an ADK Agent against an A2A request and + + publishes updates to an event queue. + """ + + def __init__( + self, + *, + runner: Runner | Callable[..., Runner | Awaitable[Runner]], + config: Optional[A2aAgentExecutorConfig] = None, + ): + super().__init__() + self._runner = runner + self._config = config or A2aAgentExecutorConfig() + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cancel the execution.""" + # TODO: Implement proper cancellation logic if needed + raise NotImplementedError('Cancellation is not supported') + + @override + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ): + """Executes an A2A request and publishes updates to the event queue + + specified. It runs as following: + * Takes the input from the A2A request + * Convert the input to ADK input content, and runs the ADK agent + * Collects output events of the underlying ADK Agent + * Converts the ADK output events into A2A task updates + * Publishes the updates back to A2A server via event queue + """ + if not context.message: + raise ValueError('A2A request must have a message') + + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + + runner = await self._resolve_runner() + try: + run_request = self._config.request_converter( + context, + self._config.a2a_part_converter, + ) + await self._resolve_session(run_request, runner) + + executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + Task( + id=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + history=[context.message], + metadata=self._get_invocation_metadata(executor_context), + ) + ) + else: + # Check if the user input is responding to the agent's + # request for input. + missing_user_input_event = handle_user_input(context) + if missing_user_input_event: + missing_user_input_event.metadata = self._get_invocation_metadata( + executor_context + ) + await event_queue.enqueue_event(missing_user_input_event) + return + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + metadata=self._get_invocation_metadata(executor_context), + ) + ) + + # Handle the request and publish updates to the event queue + await self._handle_request( + context, + executor_context, + event_queue, + runner, + run_request, + ) + except Exception as e: + logger.error('Error handling A2A request: %s', e, exc_info=True) + # Publish failure event + try: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=str(e))], + ), + ), + context_id=context.context_id, + final=True, + ) + ) + except Exception as enqueue_error: + logger.error( + 'Failed to publish failure event: %s', enqueue_error, exc_info=True + ) + + async def _handle_request( + self, + context: RequestContext, + executor_context: ExecutorContext, + event_queue: EventQueue, + runner: Runner, + run_request: AgentRunRequest, + ): + agents_artifact: dict[str, str] = {} + error_event = None + long_running_functions = LongRunningFunctions( + self._config.gen_ai_part_converter + ) + async with Aclosing(runner.run_async(**vars(run_request))) as agen: + async for adk_event in agen: + # Handle error scenarios + if adk_event and (adk_event.error_code or adk_event.error_message): + error_event = create_error_status_event( + adk_event, + context.task_id, + context.context_id, + ) + + # Handle long running function calls + adk_event = long_running_functions.process_event(adk_event) + + for a2a_event in self._config.adk_event_converter( + adk_event, + agents_artifact, + context.task_id, + context.context_id, + self._config.gen_ai_part_converter, + ): + a2a_event.metadata = self._get_invocation_metadata(executor_context) + a2a_event = await execute_after_event_interceptors( + a2a_event, + executor_context, + adk_event, + self._config.execute_interceptors, + ) + if not a2a_event: + continue + await event_queue.enqueue_event(a2a_event) + + if error_event: + final_event = error_event + elif long_running_functions.has_long_running_function_calls(): + final_event = ( + long_running_functions.create_long_running_function_call_event( + context.task_id, context.context_id + ) + ) + else: + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, + ) + + final_event.metadata = self._get_invocation_metadata(executor_context) + final_event = await execute_after_agent_interceptors( + executor_context, final_event, self._config.execute_interceptors + ) + await event_queue.enqueue_event(final_event) + + async def _resolve_runner(self) -> Runner: + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + if isinstance(self._runner, Runner): + return self._runner + if callable(self._runner): + result = self._runner() + + if inspect.iscoroutine(result): + resolved_runner = await result + else: + resolved_runner = result + + self._runner = resolved_runner + return resolved_runner + + raise TypeError( + 'Runner must be a Runner instance or a callable that returns a' + f' Runner, got {type(self._runner)}' + ) + + async def _resolve_session( + self, + run_request: AgentRunRequest, + runner: Runner, + ): + session_id = run_request.session_id + # create a new session if not exists + user_id = run_request.user_id + session = await runner.session_service.get_session( + app_name=runner.app_name, + user_id=user_id, + session_id=session_id, + ) + if session is None: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + session_id=session_id, + ) + # Update run_request with the new session_id + run_request.session_id = session.id + + def _get_invocation_metadata( + self, executor_context: ExecutorContext + ) -> dict[str, str]: + return { + _get_adk_metadata_key('app_name'): executor_context.app_name, + _get_adk_metadata_key('user_id'): executor_context.user_id, + _get_adk_metadata_key('session_id'): executor_context.session_id, + # TODO: Remove this metadata once the new agent executor + # is fully adopted. + _get_adk_metadata_key('agent_executor_v2'): True, + } diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py index 79e88546..c083affd 100644 --- a/src/google/adk/a2a/executor/config.py +++ b/src/google/adk/a2a/executor/config.py @@ -23,9 +23,21 @@ from typing import Union from a2a.server.agent_execution.context import RequestContext from a2a.server.events import Event as A2AEvent from a2a.types import TaskStatusUpdateEvent +from pydantic import BaseModel from ...events.event import Event +from ..converters.event_converter import AdkEventToA2AEventsConverter +from ..converters.event_converter import convert_event_to_a2a_events as legacy_convert_event_to_a2a_events +from ..converters.from_adk_event import AdkEventToA2AEventsConverter as AdkEventToA2AEventsConverterImpl +from ..converters.from_adk_event import convert_event_to_a2a_events as convert_event_to_a2a_events_impl +from ..converters.part_converter import A2APartToGenAIPartConverter +from ..converters.part_converter import convert_a2a_part_to_genai_part +from ..converters.part_converter import convert_genai_part_to_a2a_part +from ..converters.part_converter import GenAIPartToA2APartConverter +from ..converters.request_converter import A2ARequestToAgentRunRequestConverter +from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key +from ..experimental import a2a_experimental from .executor_context import ExecutorContext @@ -67,3 +79,29 @@ class ExecuteInterceptor: completed or failed) before it is enqueued. Must return a valid `TaskStatusUpdateEvent`. """ + + +@a2a_experimental +class A2aAgentExecutorConfig(BaseModel): + """Configuration for the A2aAgentExecutor.""" + + a2a_part_converter: A2APartToGenAIPartConverter = ( + convert_a2a_part_to_genai_part + ) + gen_ai_part_converter: GenAIPartToA2APartConverter = ( + convert_genai_part_to_a2a_part + ) + request_converter: A2ARequestToAgentRunRequestConverter = ( + convert_a2a_request_to_agent_run_request + ) + event_converter: AdkEventToA2AEventsConverter = ( + legacy_convert_event_to_a2a_events + ) + """Set up the default event converter implementation to be used by the legacy agent executor implementation.""" + + adk_event_converter: AdkEventToA2AEventsConverterImpl = ( + convert_event_to_a2a_events_impl + ) + """Set up the imlp event converter implementation to be used by the new agent executor implementation.""" + + execute_interceptors: Optional[list[ExecuteInterceptor]] = None diff --git a/tests/unittests/a2a/converters/test_from_adk.py b/tests/unittests/a2a/converters/test_from_adk.py new file mode 100644 index 00000000..23546c58 --- /dev/null +++ b/tests/unittests/a2a/converters/test_from_adk.py @@ -0,0 +1,108 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock +from unittest.mock import patch +import uuid + +from a2a.types import Part as A2APart +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events +from google.adk.events.event import Event +from google.genai import types as genai_types +import pytest + + +class TestFromAdk: + """Test suite for from_adk functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_event = Mock(spec=Event) + self.mock_event.id = "test-event-id" + self.mock_event.invocation_id = "test-invocation-id" + self.mock_event.author = "test-author" + self.mock_event.branch = None + self.mock_event.content = None + self.mock_event.error_code = None + self.mock_event.error_message = None + self.mock_event.grounding_metadata = None + self.mock_event.citation_metadata = None + self.mock_event.custom_metadata = None + self.mock_event.usage_metadata = None + self.mock_event.actions = None + self.mock_event.partial = True + self.mock_event.long_running_tool_ids = None + + def test_convert_event_to_a2a_events_artifact_update(self): + """Test conversion of event to TaskArtifactUpdateEvent.""" + # Setup event with content + self.mock_event.content = genai_types.Content( + parts=[genai_types.Part(text="hello")], role="model" + ) + self.mock_event.author = "agent-1" + + agents_artifacts = {} + + # Mock part converter to return a standard text part + mock_a2a_part = A2APart(root=TextPart(text="hello")) + mock_a2a_part.root.metadata = {} + mock_convert_part = Mock(return_value=[mock_a2a_part]) + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + part_converter=mock_convert_part, + ) + + assert len(result) == 1 + assert isinstance(result[0], TaskArtifactUpdateEvent) + assert result[0].task_id == "task-123" + assert result[0].context_id == "context-456" + assert result[0].artifact.parts == [mock_a2a_part] + assert "agent-1" in agents_artifacts # Artifact ID should be stored + + def test_convert_event_to_a2a_events_error(self): + """Test conversion of event with error to TaskStatusUpdateEvent.""" + self.mock_event.error_code = "ERR001" + self.mock_event.error_message = "Something went wrong" + + agents_artifacts = {} + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + ) + + # Should not return any artifact events + assert len(result) == 0 + + def test_convert_event_to_a2a_events_none_event(self): + """Test convert_event_to_a2a_events with None event.""" + with pytest.raises(ValueError, match="Event cannot be None"): + convert_event_to_a2a_events(None, {}) + + def test_convert_event_to_a2a_events_none_artifacts(self): + """Test convert_event_to_a2a_events with None agents_artifacts.""" + with pytest.raises(ValueError, match="Agents artifacts cannot be None"): + convert_event_to_a2a_events(self.mock_event, None) diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py new file mode 100644 index 00000000..9acae2dc --- /dev/null +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -0,0 +1,808 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.executor.a2a_agent_executor_impl import _A2aAgentExecutor as A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor_impl import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content +import pytest + + +class TestA2aAgentExecutor: + """Test suite for A2aAgentExecutor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_runner = Mock(spec=Runner) + self.mock_runner.app_name = "test-app" + self.mock_runner.session_service = Mock() + self.mock_runner._new_invocation_context = Mock() + self.mock_runner.run_async = AsyncMock() + + self.mock_a2a_part_converter = Mock() + self.mock_gen_ai_part_converter = Mock() + self.mock_request_converter = Mock() + self.mock_event_converter = Mock() + self.mock_config = A2aAgentExecutorConfig( + a2a_part_converter=self.mock_a2a_part_converter, + gen_ai_part_converter=self.mock_gen_ai_part_converter, + request_converter=self.mock_request_converter, + adk_event_converter=self.mock_event_converter, + ) + self.executor = A2aAgentExecutor( + runner=self.mock_runner, config=self.mock_config + ) + + self.mock_context = Mock(spec=RequestContext) + self.mock_context.message = Mock(spec=Message) + self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.current_task = None + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_event_queue = Mock(spec=EventQueue) + + self.expected_metadata = { + _get_adk_metadata_key("app_name"): "test-app", + _get_adk_metadata_key("user_id"): "test-user", + _get_adk_metadata_key("session_id"): "test-session", + _get_adk_metadata_key("agent_executor_v2"): True, + } + + async def _create_async_generator(self, items): + """Helper to create async generator from items.""" + for item in items: + yield item + + @pytest.mark.asyncio + async def test_execute_success_new_task(self): + """Test successful execution of a new task.""" + # Setup + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return a working status update + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) + + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + {}, # agents_artifact (initially empty) + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) + + # Verify task submitted event was enqueued + # call 0: submitted + # call 1: working (from converter) + # call 2: completed (final) + assert self.mock_event_queue.enqueue_event.call_count >= 3 + + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert isinstance(submitted_event, Task) + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.metadata == self.expected_metadata + + # Verify working event was enqueued + enqueued_working_event = self.mock_event_queue.enqueue_event.call_args_list[ + 1 + ][0][0] + assert isinstance(enqueued_working_event, TaskStatusUpdateEvent) + assert enqueued_working_event.status.state == TaskState.working + assert enqueued_working_event.metadata == self.expected_metadata + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[2][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event was enqueued + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_execute_no_message_error(self): + """Test execution fails when no message is provided.""" + self.mock_context.message = None + + with pytest.raises(ValueError, match="A2A request must have a message"): + await self.executor.execute(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_existing_task(self): + """Test execution with existing task (no submitted event).""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "existing-task-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="existing-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify submitted event was NOT enqueued for existing task + # So we check first event is working state + first_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + assert first_event.metadata == self.expected_metadata + + # Verify manual working event is FIRST + assert isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + def test_constructor_with_callable_runner(self): + """Test constructor with callable runner.""" + callable_runner = Mock() + executor = A2aAgentExecutor(runner=callable_runner, config=self.mock_config) + + assert executor._runner == callable_runner + assert executor._config == self.mock_config + + @pytest.mark.asyncio + async def test_resolve_runner_direct_instance(self): + """Test _resolve_runner with direct Runner instance.""" + # Setup - already using direct runner instance in setup_method + runner = await self.executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_sync_callable(self): + """Test _resolve_runner with sync callable that returns Runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_callable(self): + """Test _resolve_runner with async callable that returns Runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_invalid_type(self): + """Test _resolve_runner with invalid runner type.""" + executor = A2aAgentExecutor(runner="invalid", config=self.mock_config) + + with pytest.raises( + TypeError, match="Runner must be a Runner instance or a callable" + ): + await executor._resolve_runner() + + @pytest.mark.asyncio + async def test_handle_request_integration(self): + """Test the complete request handling flow.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [ + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + ] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return events + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Initialize executor context attributes as they would be in execute() + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify events enqueued + # Should check for working events + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + # Each ADK event generates 1 working event in this mock setup + assert len(working_events) >= len(mock_events) + + # Verify final event is completed + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_cancel_with_task_id(self): + """Test cancellation with a task ID.""" + self.mock_context.task_id = "test-task-id" + + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_with_exception_handling(self): + """Test execution with exception handling.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.current_task = None + + self.mock_request_converter.side_effect = Exception("Test error") + + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True + assert "Test error" in failure_event.status.message.parts[0].root.text + + @pytest.mark.asyncio + async def test_handle_request_with_non_working_state(self): + """Test handle request when a non-working state is encountered.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Mock agent run event + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = "ERROR" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return a FAILED event + failed_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.failed, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [failed_event] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Initialize executor context attributes + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + # Verify final event is FAILED, not COMPLETED + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + # The last event should be the synthesized final event + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + + @pytest.mark.asyncio + async def test_handle_request_with_error_message(self): + """Test handle request when an error message is present without an error code.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Mock agent run event with only error_message + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = None + mock_event.error_message = "Test Error Message" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + executor_context = Mock() + executor_context.app_name = "test-app" + executor_context.user_id = "test-user" + executor_context.session_id = "test-session" + + await self.executor._handle_request( + self.mock_context, + executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_interceptors(self): + """Test interceptors execution.""" + # Setup interceptors + before_interceptor = AsyncMock(return_value=self.mock_context) + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: a2a + after_agent_interceptor = AsyncMock() + after_agent_interceptor.side_effect = lambda ctx, event: event + + interceptor = ExecuteInterceptor( + before_agent=before_interceptor, + after_event=after_event_interceptor, + after_agent=after_agent_interceptor, + ) + + self.mock_config.execute_interceptors = [interceptor] + + # Mock run + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Pre-setup request converter + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify interceptors called + before_interceptor.assert_called_once_with(self.mock_context) + # after_event called for each event + assert after_event_interceptor.call_count >= 1 + after_agent_interceptor.assert_called_once() + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.handle_user_input") + async def test_execute_missing_user_input(self, mock_handle_user_input): + """Test when handle_user_input returns a missing user input event.""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up handle_user_input to return an event + missing_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_handle_user_input.return_value = missing_event + + self.mock_runner.session_service.get_session = AsyncMock( + return_value=Mock(id="test-session") + ) + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify that the missing_event was enqueued + self.mock_event_queue.enqueue_event.assert_called_once_with(missing_event) + + # Verify that metadata was injected + enqueued_event = self.mock_event_queue.enqueue_event.call_args[0][0] + assert enqueued_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_resolve_session_creates_new_session(self): + """Test that _resolve_session creates a new session if it doesn't exist.""" + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) + + new_session = Mock() + new_session.id = "new-session-id" + self.mock_runner.session_service.create_session = AsyncMock( + return_value=new_session + ) + + run_request = AgentRunRequest( + user_id="test-user", + session_id="old-session-id", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + await self.executor._resolve_session(run_request, self.mock_runner) + + self.mock_runner.session_service.get_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + session_id="old-session-id", + ) + self.mock_runner.session_service.create_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + state={}, + session_id="old-session-id", + ) + assert run_request.session_id == "new-session-id" + + @pytest.mark.asyncio + async def test_execute_enqueue_error_in_exception_handler(self): + """Test failure event publishing handles exception during enqueue.""" + self.mock_context.task_id = "test-task-id" + self.mock_request_converter.side_effect = Exception("Test error") + + # Make enqueue_event raise an exception + self.mock_event_queue.enqueue_event.side_effect = Exception("Enqueue error") + + # This should not raise an exception itself + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify enqueue_event was called to publish the error event + assert self.mock_event_queue.enqueue_event.call_count == 1 + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.LongRunningFunctions") + async def test_long_running_functions_final_event(self, mock_lrf_class): + """Test _handle_request when there are long running function calls.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up mock LongRunningFunctions + mock_lrf = mock_lrf_class.return_value + mock_lrf.process_event.side_effect = lambda e: e + mock_lrf.has_long_running_function_calls.return_value = True + + lrf_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_lrf.create_long_running_function_call_event.return_value = lrf_event + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify final event is the long running function call event + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if call[0][0] == lrf_event + ] + assert len(final_events) >= 1 + + @pytest.mark.asyncio + async def test_after_event_interceptor_returns_none(self): + """Test after_event_interceptor returning None drops the event.""" + # Setup interceptor returning None + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: None + + interceptor = ExecuteInterceptor( + after_event=after_event_interceptor, + ) + self.mock_config.execute_interceptors = [interceptor] + + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Event converter returns one event + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + self.executor._executor_context = Mock() + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Since the interceptor returns None, working_event should NOT be enqueued + # The only event enqueued by _handle_request should be the final event + assert self.mock_event_queue.enqueue_event.call_count == 1 + final_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert final_event.status.state == TaskState.completed