chore: Add sample agent for content cache and basic profiling

PiperOrigin-RevId: 809166922
Xiang (Sean) Zhou
2025-09-19 13:37:18 -07:00
committed by Copybara-Service
parent c66245a3b8
commit f4e1fd962e
5 changed files with 1960 additions and 0 deletions
@@ -0,0 +1,114 @@
# Cache Analysis Research Assistant
This sample demonstrates ADK context caching with a comprehensive research assistant agent, designed to exercise context caching on both Gemini 2.0 Flash and Gemini 2.5 Flash. It showcases the difference between explicit ADK caching and Google's built-in implicit caching.
## Key Features
- **App-Level Cache Configuration**: Context cache settings applied at the App level
- **Large Context Instructions**: Over 4200 tokens in system instructions to trigger context caching thresholds
- **Comprehensive Tool Suite**: 7 specialized research and analysis tools
- **Multi-Model Support**: Compatible with any Gemini model, automatically adapts experiment type
- **Performance Metrics**: Detailed token usage tracking including `cached_content_token_count`
## Cache Configuration
```python
ContextCacheConfig(
    min_tokens=4096,  # Minimum prompt tokens before caching applies
    ttl_seconds=600,  # 10 minutes for research sessions
    cache_intervals=3,  # Maximum invocations before cache invalidation
)
```
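The three knobs above can be illustrated with a small stand-alone sketch. Note this `CacheGate` class is a hypothetical helper for illustration only, not ADK's actual `ContextCacheConfig`:

```python
import time
from dataclasses import dataclass


@dataclass
class CacheGate:
  """Illustrative stand-in for the three ContextCacheConfig knobs."""

  min_tokens: int = 4096
  ttl_seconds: int = 600
  cache_intervals: int = 3

  def should_use_cache(
      self, prompt_tokens: int, created_at: float, invocations: int, now=None
  ) -> bool:
    now = time.time() if now is None else now
    return (
        prompt_tokens >= self.min_tokens  # prompt large enough to cache
        and (now - created_at) < self.ttl_seconds  # cache not expired
        and invocations < self.cache_intervals  # refresh limit not reached
    )
```

For example, a 5,000-token prompt passes the gate, while a 1,000-token prompt falls below the `min_tokens` threshold and would be served without caching.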
## Usage
### Run Cache Experiments
The `run_cache_experiments.py` script compares caching performance between models:
```bash
# Test any Gemini model - script automatically determines experiment type
python run_cache_experiments.py <model_name> --output results.json
# Examples:
python run_cache_experiments.py gemini-2.0-flash-001 --output gemini_2_0_results.json
python run_cache_experiments.py gemini-2.5-flash --output gemini_2_5_results.json
python run_cache_experiments.py gemini-1.5-flash --output gemini_1_5_results.json
# Run multiple iterations for averaged results
python run_cache_experiments.py <model_name> --repeat 3 --output averaged_results.json
```
### Direct Agent Usage
```bash
# Run the agent directly
adk run contributing/samples/cache_analysis/agent.py
# Web interface for debugging
adk web contributing/samples/cache_analysis
```
## Experiment Types
The script automatically determines the experiment type based on the model name:
### Models with "2.5" (e.g., gemini-2.5-flash)
- **Explicit Caching**: ADK explicit caching + Google's implicit caching
- **Implicit Only**: Google's built-in implicit caching alone
- **Measures**: Added benefit of explicit caching over Google's built-in implicit caching
### Other Models (e.g., gemini-2.0-flash-001, gemini-1.5-flash)
- **Cached**: ADK explicit context caching enabled
- **Uncached**: No caching (baseline comparison)
- **Measures**: Raw performance improvement from explicit caching vs no caching
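The model-name dispatch described above amounts to a single string check. A minimal sketch (the real script's logic may differ in detail):

```python
def experiment_type(model_name: str) -> tuple[str, str]:
  """Pick the pair of experiments to run for a given Gemini model name."""
  if "2.5" in model_name:
    # 2.5 models have implicit caching built in, so measure the added
    # benefit of explicit caching over the implicit baseline.
    return ("explicit+implicit", "implicit-only")
  # Other models: compare explicit caching against an uncached baseline.
  return ("cached", "uncached")
```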
## Tools Included
1. **analyze_data_patterns** - Statistical analysis and pattern recognition in datasets
2. **research_literature** - Academic and professional literature research with citations
3. **generate_test_scenarios** - Comprehensive test case generation and validation strategies
4. **benchmark_performance** - System performance measurement and bottleneck analysis
5. **optimize_system_performance** - Performance optimization recommendations and strategies
6. **analyze_security_vulnerabilities** - Security risk assessment and vulnerability analysis
7. **design_scalability_architecture** - Scalable system architecture design and planning
## Expected Results
### Performance vs Cost Trade-offs
**Note**: This sample uses a tool-heavy agent that may show different performance characteristics than simple text-based agents.
### Performance Improvements
- **Simple Text Agents**: Typically see 30-70% latency reduction with caching
- **Tool-Heavy Agents**: May experience higher latency due to cache setup overhead, but still provide cost benefits
- **Gemini 2.5 Flash**: Compares explicit ADK caching against Google's built-in implicit caching
### Cost Savings
- **Input Token Cost**: 75% reduction for cached content (25% of normal cost)
- **Typical Savings**: 30-60% on input costs for multi-turn conversations
- **Tool-Heavy Workloads**: Cost savings often outweigh latency trade-offs
### Token Metrics
- **Cached Content Token Count**: Non-zero values indicating successful cache hits
- **Cache Hit Ratio**: Proportion of tokens served from cache vs fresh computation
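The arithmetic behind these figures is easy to check. The sketch below assumes a placeholder rate of 1 cost unit per input token, with cached tokens billed at 25% of that rate as stated above:

```python
def input_cost(prompt_tokens: int, cached_tokens: int, rate: float = 1.0) -> float:
  """Input cost when cached tokens are billed at 25% of the normal rate."""
  fresh_tokens = prompt_tokens - cached_tokens
  return fresh_tokens * rate + cached_tokens * rate * 0.25


def cache_hit_ratio(prompt_tokens: int, cached_tokens: int) -> float:
  """Proportion of prompt tokens served from cache."""
  return cached_tokens / prompt_tokens if prompt_tokens else 0.0
```

A fully cached 1,000-token prompt costs 250 units instead of 1,000, i.e. the 75% input-cost reduction quoted above.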
## Troubleshooting
### Zero Cached Tokens
If `cached_content_token_count` is always 0:
- Verify model names match exactly (e.g., `gemini-2.0-flash-001`)
- Check that cache configuration `min_tokens` threshold is met
- Ensure proper App-based configuration is used
### Session Errors
If seeing "Session not found" errors:
- Verify `runner.app_name` is used for session creation
- Check App vs Agent object usage in InMemoryRunner initialization
## Technical Implementation
This sample demonstrates:
- **Modern App Architecture**: App-level cache configuration following ADK best practices
- **Integration Testing**: Comprehensive cache functionality validation
- **Performance Analysis**: Detailed metrics collection and comparison methodology
- **Error Handling**: Robust session management and cache invalidation handling
@@ -0,0 +1,17 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
__all__ = ['agent']
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,272 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for cache analysis experiments."""
import asyncio
import time
from typing import Any
from typing import Dict
from typing import List
from google.adk.runners import InMemoryRunner
async def call_agent_async(
runner: InMemoryRunner, user_id: str, session_id: str, prompt: str
) -> Dict[str, Any]:
"""Call agent asynchronously and return response with token usage."""
from google.genai import types
response_parts = []
token_usage = {
"prompt_token_count": 0,
"candidates_token_count": 0,
"cached_content_token_count": 0,
"total_token_count": 0,
}
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=types.Content(parts=[types.Part(text=prompt)], role="user"),
):
if event.content and event.content.parts:
for part in event.content.parts:
if hasattr(part, "text") and part.text:
response_parts.append(part.text)
    # Collect token usage information (each field may be absent or None)
    if event.usage_metadata:
      for field in token_usage:
        value = getattr(event.usage_metadata, field, None)
        if value:
          token_usage[field] += value
response_text = "".join(response_parts)
return {"response_text": response_text, "token_usage": token_usage}
def get_test_prompts() -> List[str]:
"""Get a standardized set of test prompts for cache analysis experiments.
Designed for consistent behavior:
- Prompts 1-5: Will NOT trigger function calls (general questions)
- Prompts 6-10: Will trigger function calls (specific tool requests)
"""
return [
# === PROMPTS THAT WILL NOT TRIGGER FUNCTION CALLS ===
# (General questions that don't match specific tool descriptions)
"Hello, what can you do for me?",
(
"What is artificial intelligence and how does it work in modern"
" applications?"
),
"Explain the difference between machine learning and deep learning.",
"What are the main challenges in implementing AI systems at scale?",
"How do recommendation systems work in modern e-commerce platforms?",
# === PROMPTS THAT WILL TRIGGER FUNCTION CALLS ===
# (Specific requests with all required parameters clearly specified)
(
"Use benchmark_performance with system_name='E-commerce Platform',"
" metrics=['latency', 'throughput'], duration='standard',"
" load_profile='realistic'."
),
(
"Call analyze_user_behavior_patterns with"
" user_segment='premium_customers', time_period='last_30_days',"
" metrics=['engagement', 'conversion']."
),
(
"Run market_research_analysis for industry='fintech',"
" focus_areas=['user_experience', 'security'],"
" report_depth='comprehensive'."
),
(
"Execute competitive_analysis with competitors=['Netflix',"
" 'Disney+'], analysis_type='feature_comparison',"
" output_format='detailed'."
),
(
"Perform content_performance_evaluation on content_type='video',"
" platform='social_media', success_metrics=['views', 'engagement']."
),
]
async def run_experiment_batch(
agent_name: str,
runner: InMemoryRunner,
user_id: str,
session_id: str,
prompts: List[str],
experiment_name: str,
request_delay: float = 2.0,
) -> Dict[str, Any]:
"""Run a batch of prompts and collect cache metrics."""
results = []
print(f"🧪 Running {experiment_name}")
print(f"Agent: {agent_name}")
print(f"Session: {session_id}")
print(f"Prompts: {len(prompts)}")
print(f"Request delay: {request_delay}s between calls")
print("-" * 60)
for i, prompt in enumerate(prompts, 1):
print(f"[{i}/{len(prompts)}] Running test prompt...")
print(f"Prompt: {prompt[:100]}...")
try:
agent_response = await call_agent_async(
runner, user_id, session_id, prompt
)
result = {
"prompt_number": i,
"prompt": prompt,
"response_length": len(agent_response["response_text"]),
"success": True,
"error": None,
"token_usage": agent_response["token_usage"],
}
# Extract token usage for individual prompt statistics
prompt_tokens = agent_response["token_usage"].get("prompt_token_count", 0)
cached_tokens = agent_response["token_usage"].get(
"cached_content_token_count", 0
)
print(
"✅ Completed (Response:"
f" {len(agent_response['response_text'])} chars)"
)
print(
f" 📊 Tokens - Prompt: {prompt_tokens:,}, Cached: {cached_tokens:,}"
)
except Exception as e:
result = {
"prompt_number": i,
"prompt": prompt,
"response_length": 0,
"success": False,
"error": str(e),
"token_usage": {
"prompt_token_count": 0,
"candidates_token_count": 0,
"cached_content_token_count": 0,
"total_token_count": 0,
},
}
print(f"❌ Failed: {e}")
results.append(result)
# Configurable pause between requests to avoid API overload
if i < len(prompts): # Don't sleep after the last request
print(f" ⏸️ Waiting {request_delay}s before next request...")
await asyncio.sleep(request_delay)
successful_requests = sum(1 for r in results if r["success"])
# Calculate cache statistics for this batch
total_prompt_tokens = sum(
r.get("token_usage", {}).get("prompt_token_count", 0) for r in results
)
total_cached_tokens = sum(
r.get("token_usage", {}).get("cached_content_token_count", 0)
for r in results
)
# Calculate cache hit ratio
if total_prompt_tokens > 0:
cache_hit_ratio = (total_cached_tokens / total_prompt_tokens) * 100
else:
cache_hit_ratio = 0.0
# Calculate cache utilization
requests_with_cache_hits = sum(
1
for r in results
if r.get("token_usage", {}).get("cached_content_token_count", 0) > 0
)
cache_utilization_ratio = (
(requests_with_cache_hits / len(prompts)) * 100 if prompts else 0.0
)
# Average cached tokens per request
avg_cached_tokens_per_request = (
total_cached_tokens / len(prompts) if prompts else 0.0
)
summary = {
"experiment_name": experiment_name,
"agent_name": agent_name,
"total_requests": len(prompts),
"successful_requests": successful_requests,
"results": results,
"cache_statistics": {
"cache_hit_ratio_percent": cache_hit_ratio,
"cache_utilization_ratio_percent": cache_utilization_ratio,
"total_prompt_tokens": total_prompt_tokens,
"total_cached_tokens": total_cached_tokens,
"avg_cached_tokens_per_request": avg_cached_tokens_per_request,
"requests_with_cache_hits": requests_with_cache_hits,
},
}
print("-" * 60)
print(f"{experiment_name} completed:")
print(f" Total requests: {len(prompts)}")
print(f" Successful: {successful_requests}/{len(prompts)}")
print(" 📊 BATCH CACHE STATISTICS:")
print(
f" Cache Hit Ratio: {cache_hit_ratio:.1f}%"
f" ({total_cached_tokens:,} / {total_prompt_tokens:,} tokens)"
)
print(
f" Cache Utilization: {cache_utilization_ratio:.1f}%"
f" ({requests_with_cache_hits}/{len(prompts)} requests)"
)
print(f" Avg Cached Tokens/Request: {avg_cached_tokens_per_request:.0f}")
print()
return summary