feat(conformance): Adds a replay plugin to replay the previously recorded llm/tool recordings for conformance tests

PiperOrigin-RevId: 807979314
This commit is contained in:
Wei Sun (Jack)
2025-09-16 21:53:50 -07:00
committed by Copybara-Service
parent 14f118899d
commit 3bd2f29f3a
+400
View File
@@ -0,0 +1,400 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Replay plugin for ADK conformance testing."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
import yaml
from ...agents.callback_context import CallbackContext
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
from ...plugins.base_plugin import BasePlugin
from .recordings_schema import LlmRecording
from .recordings_schema import Recording
from .recordings_schema import Recordings
from .recordings_schema import ToolRecording
if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
logger = logging.getLogger("google_adk." + __name__)
class ReplayVerificationError(Exception):
"""Exception raised when replay verification fails."""
pass
class _InvocationReplayState(BaseModel):
"""Per-invocation replay state to isolate concurrent runs."""
test_case_path: str
user_message_index: int
recordings: Recordings
# Per-agent replay indices for parallel execution
# key: agent_name -> current replay index for that agent
agent_replay_indices: dict[str, int] = Field(default_factory=dict)
class ReplayPlugin(BasePlugin):
"""Plugin for replaying ADK agent interactions from recordings."""
def __init__(self, *, name: str = "adk_replay") -> None:
super().__init__(name=name)
# Track replay state per invocation to support concurrent runs
# key: invocation_id -> _InvocationReplayState
self._invocation_states: dict[str, _InvocationReplayState] = {}
@override
async def before_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Load replay recordings when enabled."""
ctx = CallbackContext(invocation_context)
if self._is_replay_mode_on(ctx):
# Load the replay state for this invocation
self._load_invocation_state(ctx)
return None
@override
async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
"""Replay LLM response from recordings instead of making real call."""
if not self._is_replay_mode_on(callback_context):
return None
if (state := self._get_invocation_state(callback_context)) is None:
raise ValueError(
"Replay state not initialized. Ensure before_run created it."
)
agent_name = callback_context.agent_name
# Verify and get the next LLM recording for this specific agent
recording = self._verify_and_get_next_llm_recording_for_agent(
state, agent_name, llm_request
)
logger.debug("Verified and replaying LLM response for agent %s", agent_name)
# Return the recorded response
return recording.llm_response
@override
async def before_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
"""Replay tool response from recordings instead of executing tool."""
if not self._is_replay_mode_on(tool_context):
return None
if (state := self._get_invocation_state(tool_context)) is None:
raise ValueError(
"Replay state not initialized. Ensure before_run created it."
)
agent_name = tool_context.agent_name
# Verify and get the next tool recording for this specific agent
recording = self._verify_and_get_next_tool_recording_for_agent(
state, agent_name, tool.name, tool_args
)
# Execute the actual tool to get state updates.
await tool.run_async(args=tool_args, tool_context=tool_context)
logger.debug(
"Verified and replaying tool response for agent %s: tool=%s",
agent_name,
tool.name,
)
# Return the recorded response
return recording.tool_response.response
@override
async def after_run_callback(
self, *, invocation_context: InvocationContext
) -> None:
"""Clean up replay state after invocation completes."""
ctx = CallbackContext(invocation_context)
if not self._is_replay_mode_on(ctx):
return None
# Clean up per-invocation replay state
self._invocation_states.pop(ctx.invocation_id, None)
logger.debug("Cleaned up replay state for invocation %s", ctx.invocation_id)
# Private helpers
def _is_replay_mode_on(self, callback_context: CallbackContext) -> bool:
"""Check if replay mode is enabled for this invocation."""
session_state = callback_context.state
if not (config := session_state.get("_adk_replay_config")):
return False
case_dir = config.get("dir")
msg_index = config.get("user_message_index")
return case_dir and msg_index is not None
def _get_invocation_state(
self, callback_context: CallbackContext
) -> Optional[_InvocationReplayState]:
"""Get existing replay state for this invocation."""
invocation_id = callback_context.invocation_id
return self._invocation_states.get(invocation_id)
def _load_invocation_state(
self, callback_context: CallbackContext
) -> _InvocationReplayState:
"""Load and store replay state for this invocation."""
invocation_id = callback_context.invocation_id
session_state = callback_context.state
config = session_state.get("_adk_replay_config", {})
case_dir = config.get("dir")
msg_index = config.get("user_message_index")
if not case_dir or msg_index is None:
raise ValueError("Replay parameters are missing from session state")
# Load recordings
recordings_file = Path(case_dir) / "generated-recordings.yaml"
if not recordings_file.exists():
raise ValueError(f"Recordings file not found: {recordings_file}")
try:
with recordings_file.open("r", encoding="utf-8") as f:
recordings_data = yaml.safe_load(f)
recordings = Recordings.model_validate(recordings_data)
except Exception as e:
raise ValueError(f"Failed to load recordings from {recordings_file}: {e}")
# Load and store invocation state
state = _InvocationReplayState(
test_case_path=case_dir,
user_message_index=msg_index,
recordings=recordings,
)
self._invocation_states[invocation_id] = state
logger.debug(
"Loaded replay state for invocation %s: case_dir=%s, msg_index=%s, "
"recordings=%d",
invocation_id,
case_dir,
msg_index,
len(recordings.recordings),
)
return state
def _get_next_recording_for_agent(
self,
state: _InvocationReplayState,
agent_name: str,
) -> Recording:
"""Get the next recording for the specific agent in strict order."""
# Get current agent index
current_agent_index = state.agent_replay_indices.get(agent_name, 0)
# Filter ALL recordings for this agent and user message index (strict order)
agent_recordings = [
recording
for recording in state.recordings.recordings
if (
recording.agent_name == agent_name
and recording.user_message_index == state.user_message_index
)
]
# Check if we have enough recordings for this agent
if current_agent_index >= len(agent_recordings):
raise ReplayVerificationError(
f"Runtime sent more requests than expected for agent '{agent_name}'"
f" at user_message_index {state.user_message_index}. Expected"
f" {len(agent_recordings)}, but got request at index"
f" {current_agent_index}"
)
# Get the expected recording
expected_recording = agent_recordings[current_agent_index]
# Advance agent index
state.agent_replay_indices[agent_name] = current_agent_index + 1
return expected_recording
def _verify_and_get_next_llm_recording_for_agent(
self,
state: _InvocationReplayState,
agent_name: str,
llm_request: LlmRequest,
) -> LlmRecording:
"""Verify and get the next LLM recording for the specific agent."""
current_agent_index = state.agent_replay_indices.get(agent_name, 0)
expected_recording = self._get_next_recording_for_agent(state, agent_name)
# Verify this is an LLM recording
if not expected_recording.llm_recording:
raise ReplayVerificationError(
f"Expected LLM recording for agent '{agent_name}' at index "
f"{current_agent_index}, but found tool recording"
)
# Strict verification of LLM request
self._verify_llm_request_match(
expected_recording.llm_recording.llm_request,
llm_request,
agent_name,
current_agent_index,
)
return expected_recording.llm_recording
def _verify_and_get_next_tool_recording_for_agent(
self,
state: _InvocationReplayState,
agent_name: str,
tool_name: str,
tool_args: dict[str, Any],
) -> ToolRecording:
"""Verify and get the next tool recording for the specific agent."""
current_agent_index = state.agent_replay_indices.get(agent_name, 0)
expected_recording = self._get_next_recording_for_agent(state, agent_name)
# Verify this is a tool recording
if not expected_recording.tool_recording:
raise ReplayVerificationError(
f"Expected tool recording for agent '{agent_name}' at index "
f"{current_agent_index}, but found LLM recording"
)
# Strict verification of tool call
self._verify_tool_call_match(
expected_recording.tool_recording.tool_call,
tool_name,
tool_args,
agent_name,
current_agent_index,
)
return expected_recording.tool_recording
def _verify_llm_request_match(
self,
recorded_request: LlmRequest,
current_request: LlmRequest,
agent_name: str,
agent_index: int,
) -> None:
"""Verify that the current LLM request exactly matches the recorded one."""
self._verify_config_match(
recorded_request, current_request, agent_name, agent_index
)
handled_fields: set[str] = {"config"}
ignored_fields: set[str] = {"live_connect_config"}
exclude_fields = handled_fields | ignored_fields
if not self._compare_fields(
recorded_request, current_request, exclude_fields=exclude_fields
):
raise ValueError(
f"LLM request mismatch for agent '{agent_name}' (index"
f" {agent_index}): "
"recorded:"
f" {recorded_request.model_dump(exclude_none=True, exclude=exclude_fields)},"
" current:"
f" {current_request.model_dump(exclude_none=True, exclude=exclude_fields)}"
)
def _compare_fields(
self,
obj1: BaseModel,
obj2: BaseModel,
*,
exclude_fields: Optional[set[str]] = None,
) -> bool:
"""Compare two Pydantic models excluding specified fields."""
exclude_fields = exclude_fields or set()
dict1 = obj1.model_dump(exclude_none=True, exclude=exclude_fields)
dict2 = obj2.model_dump(exclude_none=True, exclude=exclude_fields)
return dict1 == dict2
def _verify_config_match(
self,
recorded_request: LlmRequest,
current_request: LlmRequest,
agent_name: str,
agent_index: int,
) -> None:
"""Verify that the config matches between recorded and current requests."""
# Fields to ignore when comparing GenerateContentConfig (denylist approach)
ignored_fields: set[str] = {
"http_options",
"labels",
}
if not self._compare_fields(
recorded_request.config,
current_request.config,
exclude_fields=ignored_fields,
):
raise ValueError(
f"Config mismatch for agent '{agent_name}' (index {agent_index}): "
"recorded:"
f" {recorded_request.config.model_dump(exclude_none=True, exclude=ignored_fields)},"
" current:"
f" {current_request.config.model_dump(exclude_none=True, exclude=ignored_fields)}"
)
def _verify_tool_call_match(
self,
recorded_call: types.FunctionCall,
tool_name: str,
tool_args: dict[str, Any],
agent_name: str,
agent_index: int,
) -> None:
"""Verify that the current tool call exactly matches the recorded one."""
if recorded_call.name != tool_name:
raise ReplayVerificationError(
f"Tool name mismatch for agent '{agent_name}' at index {agent_index}:"
f" recorded='{recorded_call.name}', current='{tool_name}'"
)
if recorded_call.args != tool_args:
raise ReplayVerificationError(
f"Tool args mismatch for agent '{agent_name}' at index {agent_index}:"
f" recorded={recorded_call.args}, current={tool_args}"
)