Files
adk-python/contributing/samples/cache_analysis/run_cache_experiments.py
T
George Weale 2367901ec5 chore: Upgrade to headers to 2026
Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 858763407
2026-01-20 14:50:09 -08:00

716 lines
23 KiB
Python

#!/usr/bin/env python3
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Cache Performance Experiments for ADK Context Caching
This script runs two experiments to compare caching performance:
A. Gemini 2.0 Flash: Cache enabled vs disabled (explicit caching test)
B. Gemini 2.5 Flash: Implicit vs explicit caching comparison
"""
import argparse
import asyncio
import copy
import json
import logging
import sys
import time
from typing import Any
from typing import Dict
from typing import List
try:
# Try relative imports first (when run as module)
from .agent import app
from .utils import get_test_prompts
from .utils import run_experiment_batch
except ImportError:
# Fallback to direct imports (when run as script)
from agent import app
from utils import get_test_prompts
from utils import run_experiment_batch
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.utils.cache_performance_analyzer import CachePerformanceAnalyzer
APP_NAME = "cache_analysis_experiments"
USER_ID = "cache_researcher"
def create_agent_variant(base_app, model_name: str, cache_enabled: bool):
    """Build an App variant that uses *model_name* with caching on or off."""
    import datetime

    from google.adk.agents.context_cache_config import ContextCacheConfig
    from google.adk.apps.app import App

    # Work on a deep copy so the shared base app is never mutated.
    variant_agent = copy.deepcopy(base_app.root_agent)
    variant_agent.model = model_name

    # A fresh timestamp in the instruction defeats implicit cache reuse
    # across separate experiment runs.
    started_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    variant_agent.instruction = (
        f"Current session started at: {started_at}\n\n" + variant_agent.instruction
    )

    # Encode the configuration into the agent/app names for later lookup.
    suffix = "cached" if cache_enabled else "no_cache"
    sanitized_model = model_name.replace(".", "_").replace("-", "_")
    variant_agent.name = f"cache_analysis_{sanitized_model}_{suffix}"

    if cache_enabled:
        # Standardized cache config used by all cached experiment variants.
        cache_config = ContextCacheConfig(
            min_tokens=4096,
            ttl_seconds=600,  # 10 mins for research sessions
            cache_intervals=3,  # Maximum invocations before cache refresh
        )
    else:
        # A None config disables ADK explicit caching entirely.
        cache_config = None

    return App(
        name=f"{base_app.name}_{suffix}",
        root_agent=variant_agent,
        context_cache_config=cache_config,
    )
async def run_cache_comparison_experiment(
    model_name: str,
    description: str,
    cached_label: str,
    uncached_label: str,
    experiment_title: str,
    reverse_order: bool = False,
    request_delay: float = 2.0,
) -> Dict[str, Any]:
    """Run a cache performance comparison experiment for a specific model.

    Args:
        model_name: Model to test (e.g., "gemini-2.0-flash", "gemini-2.5-flash")
        description: Description of what the experiment tests
        cached_label: Label for the cached experiment variant
        uncached_label: Label for the uncached experiment variant
        experiment_title: Title to display for the experiment
        reverse_order: If True, run the cached variant before the uncached one.
        request_delay: Seconds to sleep between individual API requests.

    Returns:
        Dictionary containing experiment results and performance comparison
    """
    print("=" * 80)
    print(f"EXPERIMENT {model_name}: {experiment_title}")
    print("=" * 80)
    print(f"Testing: {description}")
    print(f"Model: {model_name}")
    print()

    # One app / runner / session per variant so the two runs cannot
    # cross-contaminate each other's cache state.
    app_cached = create_agent_variant(app, model_name, cache_enabled=True)
    app_uncached = create_agent_variant(app, model_name, cache_enabled=False)
    prompts = get_test_prompts()
    runner_cached = InMemoryRunner(app=app_cached, app_name=None)
    runner_uncached = InMemoryRunner(app=app_uncached, app_name=None)
    session_cached = await runner_cached.session_service.create_session(
        app_name=runner_cached.app_name, user_id=USER_ID
    )
    session_uncached = await runner_uncached.session_service.create_session(
        app_name=runner_uncached.app_name, user_id=USER_ID
    )

    async def _run_batch(variant_app, runner, session, label):
        # Single batch-invocation path shared by both orderings (the original
        # duplicated this call four times across the order branches).
        return await run_experiment_batch(
            variant_app.root_agent.name,
            runner,
            USER_ID,
            session.id,
            prompts,
            f"Experiment {model_name} - {label}",
            request_delay=request_delay,
        )

    if not reverse_order:  # Default: uncached first
        print("▢️ Running experiments in DEFAULT ORDER (uncached first)")
        print()
        results_uncached = await _run_batch(
            app_uncached, runner_uncached, session_uncached, uncached_label
        )
        # Brief pause between experiments
        await asyncio.sleep(5)
        results_cached = await _run_batch(
            app_cached, runner_cached, session_cached, cached_label
        )
    else:
        print("πŸ”„ Running experiments in ALTERNATE ORDER (cached first)")
        print()
        results_cached = await _run_batch(
            app_cached, runner_cached, session_cached, cached_label
        )
        # Brief pause between experiments
        await asyncio.sleep(5)
        results_uncached = await _run_batch(
            app_uncached, runner_uncached, session_uncached, uncached_label
        )

    # Analyze cache performance using CachePerformanceAnalyzer
    performance_analysis = await analyze_cache_performance_from_sessions(
        runner_cached,
        session_cached,
        runner_uncached,
        session_uncached,
        model_name,
    )

    def _extract_metrics(analysis: Dict[str, Any]) -> Dict[str, Any]:
        # Normalize analyzer output (missing keys default to zero) into the
        # per-variant schema consumed downstream by calculate_averaged_results
        # and main(); replaces 14 near-identical .get() extractions.
        return {
            "cache_hit_ratio_percent": analysis.get(
                "cache_hit_ratio_percent", 0.0
            ),
            "cache_utilization_ratio_percent": analysis.get(
                "cache_utilization_ratio_percent", 0.0
            ),
            "total_prompt_tokens": analysis.get("total_prompt_tokens", 0),
            "total_cached_tokens": analysis.get("total_cached_tokens", 0),
            "avg_cached_tokens_per_request": analysis.get(
                "avg_cached_tokens_per_request", 0.0
            ),
            "requests_with_cache_hits": analysis.get(
                "requests_with_cache_hits", 0
            ),
            "total_requests": analysis.get("total_requests", 0),
        }

    cached_metrics = _extract_metrics(
        performance_analysis.get("cached_analysis", {})
    )
    uncached_metrics = _extract_metrics(
        performance_analysis.get("uncached_analysis", {})
    )

    summary = {
        "experiment": model_name,
        "description": description,
        "model": model_name,
        "cached_results": results_cached,
        "uncached_results": results_uncached,
        "cache_analysis": {
            "cached_experiment": cached_metrics,
            "uncached_experiment": uncached_metrics,
        },
    }

    def _print_metrics(icon: str, label: str, m: Dict[str, Any]) -> None:
        # Shared report format for both experiment variants.
        print(f" {icon} {label}:")
        print(
            f" Cache Hit Ratio: {m['cache_hit_ratio_percent']:.1f}%"
            f" ({m['total_cached_tokens']:,} /"
            f" {m['total_prompt_tokens']:,} tokens)"
        )
        print(
            f" Cache Utilization: {m['cache_utilization_ratio_percent']:.1f}%"
            f" ({m['requests_with_cache_hits']}/{m['total_requests']} requests)"
        )
        print(
            " Avg Cached Tokens/Request:"
            f" {m['avg_cached_tokens_per_request']:.0f}"
        )

    print(f"πŸ“Š EXPERIMENT {model_name} CACHE ANALYSIS:")
    _print_metrics("πŸ”₯", cached_label, cached_metrics)
    _print_metrics("❄️", uncached_label, uncached_metrics)
    print()

    # Add performance analysis to summary
    summary["performance_analysis"] = performance_analysis
    return summary
async def analyze_cache_performance_from_sessions(
    runner_cached,
    session_cached,
    runner_uncached,
    session_uncached,
    model_name: str,
) -> Dict[str, Any]:
    """Analyze cache performance using CachePerformanceAnalyzer.

    Runs the analyzer once per experiment variant and prints a per-variant
    report (the original duplicated ~40 lines between the two variants).

    Returns:
        Dict with "cached_analysis" and "uncached_analysis" entries; a variant
        whose analysis failed is reported as {"status": "error", "error": ...}.
    """
    print("πŸ“Š ANALYZING CACHE PERFORMANCE WITH CachePerformanceAnalyzer...")
    sanitized_model = model_name.replace(".", "_").replace("-", "_")

    async def _analyze_variant(runner, session, agent_suffix, icon, variant):
        # Shared analyze-and-report path. Exceptions (including missing keys
        # while printing) are captured, not raised, so one failed variant
        # cannot abort the other -- matching the original error handling.
        analyzer = CachePerformanceAnalyzer(runner.session_service)
        try:
            analysis = await analyzer.analyze_agent_cache_performance(
                session.id,
                USER_ID,
                runner.app_name,
                f"cache_analysis_{sanitized_model}_{agent_suffix}",
            )
            print(f" {icon} {variant.capitalize()} Experiment Analysis:")
            print(f" Status: {analysis['status']}")
            if analysis["status"] == "active":
                print(
                    " Cache Hit Ratio:"
                    f" {analysis['cache_hit_ratio_percent']:.1f}%"
                    f" ({analysis['total_cached_tokens']:,} /"
                    f" {analysis['total_prompt_tokens']:,} tokens)"
                )
                print(
                    " Cache Utilization:"
                    f" {analysis['cache_utilization_ratio_percent']:.1f}%"
                    f" ({analysis['requests_with_cache_hits']}/{analysis['total_requests']} requests)"
                )
                print(
                    " Avg Cached Tokens/Request:"
                    f" {analysis['avg_cached_tokens_per_request']:.0f}"
                )
                print(
                    f" Requests with cache: {analysis['requests_with_cache']}"
                )
                print(
                    " Avg invocations used:"
                    f" {analysis['avg_invocations_used']:.1f}"
                )
                print(f" Cache refreshes: {analysis['cache_refreshes']}")
                print(f" Total invocations: {analysis['total_invocations']}")
            return analysis
        except Exception as e:
            print(f" ❌ Error analyzing {variant} experiment: {e}")
            return {"status": "error", "error": str(e)}

    cached_analysis = await _analyze_variant(
        runner_cached, session_cached, "cached", "πŸ”₯", "cached"
    )
    uncached_analysis = await _analyze_variant(
        runner_uncached, session_uncached, "no_cache", "❄️", "uncached"
    )
    print()
    return {
        "cached_analysis": cached_analysis,
        "uncached_analysis": uncached_analysis,
    }
def get_experiment_labels(model_name: str) -> Dict[str, str]:
    """Return experiment labels and titles appropriate for *model_name*."""
    # Gemini 2.5 models ship with Google implicit caching, so the comparison
    # there is implicit vs. explicit rather than cache vs. no cache.
    if "2.5" not in model_name:
        # Other models (2.0, etc.) test explicit caching vs no caching.
        return {
            "description": "ADK explicit caching enabled vs disabled",
            "cached_label": "Cached",
            "uncached_label": "Uncached",
            "experiment_title": "Cache Performance Comparison",
        }
    return {
        "description": "Google implicit caching vs ADK explicit caching",
        "cached_label": "Explicit Caching",
        "uncached_label": "Implicit Caching",
        "experiment_title": "Implicit vs Explicit Caching",
    }
def calculate_averaged_results(
    all_results: List[Dict[str, Any]], model_name: str
) -> Dict[str, Any]:
    """Average cache metrics across multiple experiment runs.

    Each entry of *all_results* is a summary from
    run_cache_comparison_experiment(), whose "cache_analysis" holds nested
    "cached_experiment" / "uncached_experiment" metric dicts.  The previous
    implementation indexed flat keys on "cache_analysis" (KeyError) and
    produced a flat averaged dict that main() could not consume; both sides
    now use the nested per-variant layout.

    Args:
        all_results: Per-run experiment summaries.
        model_name: Model the runs were executed against.

    Returns:
        Dict with per-variant averages under "averaged_cache_analysis" and
        run-to-run standard deviations (cached variant) under "statistics".

    Raises:
        ValueError: If *all_results* is empty.
    """
    if not all_results:
        raise ValueError("No results to average")

    metric_keys = (
        "cache_hit_ratio_percent",
        "cache_utilization_ratio_percent",
        "total_prompt_tokens",
        "total_cached_tokens",
        "avg_cached_tokens_per_request",
        "requests_with_cache_hits",
    )

    def _mean(values):
        """Calculate average, handling empty lists."""
        return sum(values) / len(values) if values else 0.0

    def _collect(variant: str, key: str) -> List[float]:
        """Gather metric *key* of experiment *variant* across all runs."""
        return [r["cache_analysis"][variant][key] for r in all_results]

    averaged_analysis = {
        variant: {key: _mean(_collect(variant, key)) for key in metric_keys}
        for variant in ("cached_experiment", "uncached_experiment")
    }

    # Run-to-run variability of the cached variant (the variant under test).
    cache_hit_ratios = _collect("cached_experiment", "cache_hit_ratio_percent")
    cache_utilization_ratios = _collect(
        "cached_experiment", "cache_utilization_ratio_percent"
    )
    cached_tokens_per_request = _collect(
        "cached_experiment", "avg_cached_tokens_per_request"
    )

    averaged_result = {
        "experiment": model_name,
        "description": all_results[0]["description"],
        "model": model_name,
        # Keep all individual results for reference.
        "individual_runs": all_results,
        "averaged_cache_analysis": averaged_analysis,
        "statistics": {
            "runs_completed": len(all_results),
            "cache_hit_ratio_std": _calculate_std(cache_hit_ratios),
            "cache_utilization_std": _calculate_std(cache_utilization_ratios),
            "cached_tokens_per_request_std": _calculate_std(
                cached_tokens_per_request
            ),
        },
    }

    # Print averaged results (cached variant, which the stats describe).
    avg_cache = averaged_analysis["cached_experiment"]
    stats = averaged_result["statistics"]
    print("\nπŸ“Š AVERAGED CACHE ANALYSIS RESULTS:")
    print("=" * 80)
    print(f" Runs completed: {stats['runs_completed']}")
    print(
        f" Average Cache Hit Ratio: {avg_cache['cache_hit_ratio_percent']:.1f}%"
        f" (Β±{stats['cache_hit_ratio_std']:.1f}%)"
    )
    print(
        " Average Cache Utilization:"
        f" {avg_cache['cache_utilization_ratio_percent']:.1f}%"
        f" (Β±{stats['cache_utilization_std']:.1f}%)"
    )
    print(
        " Average Cached Tokens/Request:"
        f" {avg_cache['avg_cached_tokens_per_request']:.0f}"
        f" (Β±{stats['cached_tokens_per_request_std']:.0f})"
    )
    print()
    return averaged_result


def _calculate_std(values):
    """Population standard deviation of *values* (0.0 for n <= 1)."""
    if len(values) <= 1:
        return 0.0
    mean = sum(values) / len(values)
    variance = sum((x - mean) ** 2 for x in values) / len(values)
    return variance**0.5
def save_results(results: Dict[str, Any], filename: str):
    """Serialize experiment *results* to *filename* as pretty-printed JSON.

    Args:
        results: JSON-serializable experiment summary.
        filename: Destination path (overwritten if it exists).
    """
    # Context manager guarantees the handle is closed even on a dump error.
    with open(filename, "w") as f:
        json.dump(results, f, indent=2)
    # Bug fix: this previously printed a literal "(unknown)" placeholder
    # instead of interpolating the actual output path.
    print(f"πŸ’Ύ Results saved to: {filename}")
async def main():
    """Run cache performance experiment for a specific model.

    Parses CLI arguments, runs the comparison experiment (optionally repeated
    with averaging), saves the result JSON, and prints a final summary.
    Exits with status 1 on interrupt or failure.
    """
    # ---- CLI definition ----
    parser = argparse.ArgumentParser(
        description="ADK Cache Performance Experiment"
    )
    # Positional: which model to benchmark.
    parser.add_argument(
        "model",
        help="Model to test (e.g., gemini-2.5-flash, gemini-2.0-flash-001)",
    )
    parser.add_argument(
        "--output",
        help="Output filename for results (default: cache_{model}_results.json)",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help=(
            "Number of times to repeat each experiment for averaged results"
            " (default: 1)"
        ),
    )
    parser.add_argument(
        "--cached-first",
        action="store_true",
        help="Run cached experiment first (default: uncached first)",
    )
    parser.add_argument(
        "--request-delay",
        type=float,
        default=2.0,
        help=(
            "Delay in seconds between API requests to avoid overloading (default:"
            " 2.0)"
        ),
    )
    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: INFO)",
    )
    args = parser.parse_args()
    # Setup logger with specified level
    log_level = getattr(logging, args.log_level.upper())
    logs.setup_adk_logger(log_level)
    # Set default output filename based on model ("." and "-" are normalized
    # to "_" so the model name is filename-safe).
    if not args.output:
        args.output = (
            f"cache_{args.model.replace('.', '_').replace('-', '_')}_results.json"
        )
    print("πŸ§ͺ ADK CONTEXT CACHE PERFORMANCE EXPERIMENT")
    print("=" * 80)
    print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Model: {args.model}")
    print(f"Repetitions: {args.repeat}")
    print()
    start_time = time.time()
    try:
        # Get experiment labels for the model (implicit-vs-explicit wording
        # for 2.5-family models, cached-vs-uncached otherwise).
        labels = get_experiment_labels(args.model)
        # Run the experiment multiple times if repeat > 1
        if args.repeat == 1:
            # Single run
            result = await run_cache_comparison_experiment(
                model_name=args.model,
                reverse_order=args.cached_first,
                request_delay=args.request_delay,
                **labels,
            )
        else:
            # Multiple runs with averaging
            print(f"πŸ”„ Running experiment {args.repeat} times for averaged results")
            print("=" * 80)
            all_results = []
            for run_num in range(args.repeat):
                print(f"\nπŸƒ RUN {run_num + 1}/{args.repeat}")
                print("-" * 40)
                run_result = await run_cache_comparison_experiment(
                    model_name=args.model,
                    reverse_order=args.cached_first,
                    request_delay=args.request_delay,
                    **labels,
                )
                all_results.append(run_result)
                # Brief pause between runs (skipped after the last run).
                if run_num < args.repeat - 1:
                    print("⏸️ Pausing 10 seconds between runs...")
                    await asyncio.sleep(10)
            # Calculate averaged results
            result = calculate_averaged_results(all_results, args.model)
        # Add completion metadata
        result["end_time"] = time.strftime("%Y-%m-%d %H:%M:%S")
        result["total_duration"] = time.time() - start_time
        result["repetitions"] = args.repeat
    except KeyboardInterrupt:
        print("\n⚠️ Experiment interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\n❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    # Save results
    save_results(result, args.output)
    # Print final summary
    print("=" * 80)
    print("πŸŽ‰ EXPERIMENT COMPLETED SUCCESSFULLY!")
    print("=" * 80)
    # Handle both single and averaged results
    if args.repeat == 1:
        cached_exp = result["cache_analysis"]["cached_experiment"]
        uncached_exp = result["cache_analysis"]["uncached_experiment"]
        labels = get_experiment_labels(args.model)
        print(f"{args.model}:")
        print(f" πŸ”₯ {labels['cached_label']}:")
        print(f" Cache Hit Ratio: {cached_exp['cache_hit_ratio_percent']:.1f}%")
        print(
            " Cache Utilization:"
            f" {cached_exp['cache_utilization_ratio_percent']:.1f}%"
        )
        print(
            " Cached Tokens/Request:"
            f" {cached_exp['avg_cached_tokens_per_request']:.0f}"
        )
        print(f" ❄️ {labels['uncached_label']}:")
        print(
            f" Cache Hit Ratio: {uncached_exp['cache_hit_ratio_percent']:.1f}%"
        )
        print(
            " Cache Utilization:"
            f" {uncached_exp['cache_utilization_ratio_percent']:.1f}%"
        )
        print(
            " Cached Tokens/Request:"
            f" {uncached_exp['avg_cached_tokens_per_request']:.0f}"
        )
    else:
        # For averaged results, show summary comparison
        # NOTE(review): this expects "cached_experiment"/"uncached_experiment"
        # sub-dicts under "averaged_cache_analysis" -- confirm that
        # calculate_averaged_results() emits that nested layout.
        cached_exp = result["averaged_cache_analysis"]["cached_experiment"]
        uncached_exp = result["averaged_cache_analysis"]["uncached_experiment"]
        labels = get_experiment_labels(args.model)
        print(f"{args.model} (averaged over {args.repeat} runs):")
        print(f" πŸ”₯ {labels['cached_label']} vs ❄️ {labels['uncached_label']}:")
        print(
            f" Cache Hit Ratio: {cached_exp['cache_hit_ratio_percent']:.1f}% vs"
            f" {uncached_exp['cache_hit_ratio_percent']:.1f}%"
        )
        print(
            " Cache Utilization:"
            f" {cached_exp['cache_utilization_ratio_percent']:.1f}% vs"
            f" {uncached_exp['cache_utilization_ratio_percent']:.1f}%"
        )
    print(f"\nTotal execution time: {result['total_duration']:.2f} seconds")
    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    asyncio.run(main())