From 3bd2f29f3a56d2dc723dcc0422afdedaa42a4370 Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Tue, 16 Sep 2025 21:53:50 -0700 Subject: [PATCH] feat(conformance): Adds a replay plugin to replay the previously recorded llm/tool recordings for conformance tests PiperOrigin-RevId: 807979314 --- src/google/adk/cli/plugins/replay_plugin.py | 400 ++++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 src/google/adk/cli/plugins/replay_plugin.py diff --git a/src/google/adk/cli/plugins/replay_plugin.py b/src/google/adk/cli/plugins/replay_plugin.py new file mode 100644 index 00000000..d2e12ae0 --- /dev/null +++ b/src/google/adk/cli/plugins/replay_plugin.py @@ -0,0 +1,400 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replay plugin for ADK conformance testing.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override +import yaml + +from ...agents.callback_context import CallbackContext +from ...models.llm_request import LlmRequest +from ...models.llm_response import LlmResponse +from ...plugins.base_plugin import BasePlugin +from .recordings_schema import LlmRecording +from .recordings_schema import Recording +from .recordings_schema import Recordings +from .recordings_schema import ToolRecording + +if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext + from ...tools.base_tool import BaseTool + from ...tools.tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +class ReplayVerificationError(Exception): + """Exception raised when replay verification fails.""" + + pass + + +class _InvocationReplayState(BaseModel): + """Per-invocation replay state to isolate concurrent runs.""" + + test_case_path: str + user_message_index: int + recordings: Recordings + + # Per-agent replay indices for parallel execution + # key: agent_name -> current replay index for that agent + agent_replay_indices: dict[str, int] = Field(default_factory=dict) + + +class ReplayPlugin(BasePlugin): + """Plugin for replaying ADK agent interactions from recordings.""" + + def __init__(self, *, name: str = "adk_replay") -> None: + super().__init__(name=name) + + # Track replay state per invocation to support concurrent runs + # key: invocation_id -> _InvocationReplayState + self._invocation_states: dict[str, _InvocationReplayState] = {} + + @override + async def before_run_callback( + self, *, invocation_context: InvocationContext + ) -> Optional[types.Content]: + """Load replay recordings when enabled.""" + ctx = CallbackContext(invocation_context) + if self._is_replay_mode_on(ctx): + # Load the replay state for this invocation + self._load_invocation_state(ctx) + return None + + @override + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Replay LLM response from recordings instead of making real call.""" + if not self._is_replay_mode_on(callback_context): + return None + + if (state := self._get_invocation_state(callback_context)) is None: + raise ValueError( + "Replay state not initialized. Ensure before_run created it." + ) + + agent_name = callback_context.agent_name + + # Verify and get the next LLM recording for this specific agent + recording = self._verify_and_get_next_llm_recording_for_agent( + state, agent_name, llm_request + ) + + logger.debug("Verified and replaying LLM response for agent %s", agent_name) + + # Return the recorded response + return recording.llm_response + + @override + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> Optional[dict]: + """Replay tool response from recordings instead of executing tool.""" + if not self._is_replay_mode_on(tool_context): + return None + + if (state := self._get_invocation_state(tool_context)) is None: + raise ValueError( + "Replay state not initialized. Ensure before_run created it." + ) + + agent_name = tool_context.agent_name + + # Verify and get the next tool recording for this specific agent + recording = self._verify_and_get_next_tool_recording_for_agent( + state, agent_name, tool.name, tool_args + ) + # Execute the actual tool to get state updates. + await tool.run_async(args=tool_args, tool_context=tool_context) + + logger.debug( + "Verified and replaying tool response for agent %s: tool=%s", + agent_name, + tool.name, + ) + + # Return the recorded response + return recording.tool_response.response + + @override + async def after_run_callback( + self, *, invocation_context: InvocationContext + ) -> None: + """Clean up replay state after invocation completes.""" + ctx = CallbackContext(invocation_context) + if not self._is_replay_mode_on(ctx): + return None + + # Clean up per-invocation replay state + self._invocation_states.pop(ctx.invocation_id, None) + logger.debug("Cleaned up replay state for invocation %s", ctx.invocation_id) + + # Private helpers + def _is_replay_mode_on(self, callback_context: CallbackContext) -> bool: + """Check if replay mode is enabled for this invocation.""" + session_state = callback_context.state + if not (config := session_state.get("_adk_replay_config")): + return False + + case_dir = config.get("dir") + msg_index = config.get("user_message_index") + + return case_dir and msg_index is not None + + def _get_invocation_state( + self, callback_context: CallbackContext + ) -> Optional[_InvocationReplayState]: + """Get existing replay state for this invocation.""" + invocation_id = callback_context.invocation_id + return self._invocation_states.get(invocation_id) + + def _load_invocation_state( + self, callback_context: CallbackContext + ) -> _InvocationReplayState: + """Load and store replay state for this invocation.""" + invocation_id = callback_context.invocation_id + session_state = callback_context.state + + config = session_state.get("_adk_replay_config", {}) + case_dir = config.get("dir") + msg_index = config.get("user_message_index") + + if not case_dir or msg_index is None: + raise ValueError("Replay parameters are missing from session state") + + # Load recordings + recordings_file = Path(case_dir) / "generated-recordings.yaml" + + if not recordings_file.exists(): + raise ValueError(f"Recordings file not found: {recordings_file}") + + try: + with recordings_file.open("r", encoding="utf-8") as f: + recordings_data = yaml.safe_load(f) + recordings = Recordings.model_validate(recordings_data) + except Exception as e: + raise ValueError(f"Failed to load recordings from {recordings_file}: {e}") + + # Load and store invocation state + state = _InvocationReplayState( + test_case_path=case_dir, + user_message_index=msg_index, + recordings=recordings, + ) + self._invocation_states[invocation_id] = state + logger.debug( + "Loaded replay state for invocation %s: case_dir=%s, msg_index=%s, " + "recordings=%d", + invocation_id, + case_dir, + msg_index, + len(recordings.recordings), + ) + return state + + def _get_next_recording_for_agent( + self, + state: _InvocationReplayState, + agent_name: str, + ) -> Recording: + """Get the next recording for the specific agent in strict order.""" + # Get current agent index + current_agent_index = state.agent_replay_indices.get(agent_name, 0) + + # Filter ALL recordings for this agent and user message index (strict order) + agent_recordings = [ + recording + for recording in state.recordings.recordings + if ( + recording.agent_name == agent_name + and recording.user_message_index == state.user_message_index + ) + ] + + # Check if we have enough recordings for this agent + if current_agent_index >= len(agent_recordings): + raise ReplayVerificationError( + f"Runtime sent more requests than expected for agent '{agent_name}'" + f" at user_message_index {state.user_message_index}. Expected" + f" {len(agent_recordings)}, but got request at index" + f" {current_agent_index}" + ) + + # Get the expected recording + expected_recording = agent_recordings[current_agent_index] + + # Advance agent index + state.agent_replay_indices[agent_name] = current_agent_index + 1 + + return expected_recording + + def _verify_and_get_next_llm_recording_for_agent( + self, + state: _InvocationReplayState, + agent_name: str, + llm_request: LlmRequest, + ) -> LlmRecording: + """Verify and get the next LLM recording for the specific agent.""" + current_agent_index = state.agent_replay_indices.get(agent_name, 0) + expected_recording = self._get_next_recording_for_agent(state, agent_name) + + # Verify this is an LLM recording + if not expected_recording.llm_recording: + raise ReplayVerificationError( + f"Expected LLM recording for agent '{agent_name}' at index " + f"{current_agent_index}, but found tool recording" + ) + + # Strict verification of LLM request + self._verify_llm_request_match( + expected_recording.llm_recording.llm_request, + llm_request, + agent_name, + current_agent_index, + ) + + return expected_recording.llm_recording + + def _verify_and_get_next_tool_recording_for_agent( + self, + state: _InvocationReplayState, + agent_name: str, + tool_name: str, + tool_args: dict[str, Any], + ) -> ToolRecording: + """Verify and get the next tool recording for the specific agent.""" + current_agent_index = state.agent_replay_indices.get(agent_name, 0) + expected_recording = self._get_next_recording_for_agent(state, agent_name) + + # Verify this is a tool recording + if not expected_recording.tool_recording: + raise ReplayVerificationError( + f"Expected tool recording for agent '{agent_name}' at index " + f"{current_agent_index}, but found LLM recording" + ) + + # Strict verification of tool call + self._verify_tool_call_match( + expected_recording.tool_recording.tool_call, + tool_name, + tool_args, + agent_name, + current_agent_index, + ) + + return expected_recording.tool_recording + + def _verify_llm_request_match( + self, + recorded_request: LlmRequest, + current_request: LlmRequest, + agent_name: str, + agent_index: int, + ) -> None: + """Verify that the current LLM request exactly matches the recorded one.""" + self._verify_config_match( + recorded_request, current_request, agent_name, agent_index + ) + handled_fields: set[str] = {"config"} + ignored_fields: set[str] = {"live_connect_config"} + exclude_fields = handled_fields | ignored_fields + if not self._compare_fields( + recorded_request, current_request, exclude_fields=exclude_fields + ): + raise ValueError( + f"LLM request mismatch for agent '{agent_name}' (index" + f" {agent_index}): " + "recorded:" + f" {recorded_request.model_dump(exclude_none=True, exclude=exclude_fields)}," + " current:" + f" {current_request.model_dump(exclude_none=True, exclude=exclude_fields)}" + ) + + def _compare_fields( + self, + obj1: BaseModel, + obj2: BaseModel, + *, + exclude_fields: Optional[set[str]] = None, + ) -> bool: + """Compare two Pydantic models excluding specified fields.""" + exclude_fields = exclude_fields or set() + dict1 = obj1.model_dump(exclude_none=True, exclude=exclude_fields) + dict2 = obj2.model_dump(exclude_none=True, exclude=exclude_fields) + return dict1 == dict2 + + def _verify_config_match( + self, + recorded_request: LlmRequest, + current_request: LlmRequest, + agent_name: str, + agent_index: int, + ) -> None: + """Verify that the config matches between recorded and current requests.""" + # Fields to ignore when comparing GenerateContentConfig (denylist approach) + ignored_fields: set[str] = { + "http_options", + "labels", + } + + if not self._compare_fields( + recorded_request.config, + current_request.config, + exclude_fields=ignored_fields, + ): + raise ValueError( + f"Config mismatch for agent '{agent_name}' (index {agent_index}): " + "recorded:" + f" {recorded_request.config.model_dump(exclude_none=True, exclude=ignored_fields)}," + " current:" + f" {current_request.config.model_dump(exclude_none=True, exclude=ignored_fields)}" + ) + + def _verify_tool_call_match( + self, + recorded_call: types.FunctionCall, + tool_name: str, + tool_args: dict[str, Any], + agent_name: str, + agent_index: int, + ) -> None: + """Verify that the current tool call exactly matches the recorded one.""" + if recorded_call.name != tool_name: + raise ReplayVerificationError( + f"Tool name mismatch for agent '{agent_name}' at index {agent_index}:" + f" recorded='{recorded_call.name}', current='{tool_name}'" + ) + + if recorded_call.args != tool_args: + raise ReplayVerificationError( + f"Tool args mismatch for agent '{agent_name}' at index {agent_index}:" + f" recorded={recorded_call.args}, current={tool_args}" + )