You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
c66245a3b8
1. add a context cache config in app level which will apply to all agents in the app 2. pass on cache config through invocation context to llm_reqeust 3. store cache metadata in llm_response 4. lookup old cache metadata from latest event for reusing old cache 5. create new cache if old cache cannot be reused PiperOrigin-RevId: 809158578
315 lines
10 KiB
Python
315 lines
10 KiB
Python
# Copyright 2025 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for CacheMetadata."""
|
|
|
|
import time
|
|
|
|
from google.adk.models.cache_metadata import CacheMetadata
|
|
from pydantic import ValidationError
|
|
import pytest
|
|
|
|
|
|
class TestCacheMetadata:
|
|
"""Test suite for CacheMetadata."""
|
|
|
|
def test_required_fields(self):
|
|
"""Test that all required fields must be provided."""
|
|
# Valid creation with all required fields
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=5,
|
|
cached_contents_count=3,
|
|
)
|
|
|
|
assert (
|
|
metadata.cache_name
|
|
== "projects/123/locations/us-central1/cachedContents/456"
|
|
)
|
|
assert metadata.expire_time > time.time()
|
|
assert metadata.fingerprint == "abc123"
|
|
assert metadata.invocations_used == 5
|
|
assert metadata.cached_contents_count == 3
|
|
assert metadata.created_at is None # Optional field
|
|
|
|
def test_optional_created_at(self):
|
|
"""Test that created_at is optional."""
|
|
current_time = time.time()
|
|
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=3,
|
|
cached_contents_count=2,
|
|
created_at=current_time,
|
|
)
|
|
|
|
assert metadata.created_at == current_time
|
|
|
|
def test_invocations_used_validation(self):
|
|
"""Test invocations_used validation constraints."""
|
|
# Valid: zero or positive
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=0,
|
|
cached_contents_count=1,
|
|
)
|
|
assert metadata.invocations_used == 0
|
|
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=10,
|
|
cached_contents_count=1,
|
|
)
|
|
assert metadata.invocations_used == 10
|
|
|
|
# Invalid: negative
|
|
with pytest.raises(ValidationError) as exc_info:
|
|
CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=-1,
|
|
cached_contents_count=1,
|
|
)
|
|
assert "greater than or equal to 0" in str(exc_info.value)
|
|
|
|
def test_cached_contents_count_validation(self):
|
|
"""Test cached_contents_count validation constraints."""
|
|
# Valid: zero or positive
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=0,
|
|
)
|
|
assert metadata.cached_contents_count == 0
|
|
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=10,
|
|
)
|
|
assert metadata.cached_contents_count == 10
|
|
|
|
# Invalid: negative
|
|
with pytest.raises(ValidationError) as exc_info:
|
|
CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=-1,
|
|
)
|
|
assert "greater than or equal to 0" in str(exc_info.value)
|
|
|
|
def test_expire_soon_property(self):
|
|
"""Test expire_soon property."""
|
|
# Cache that expires in 10 minutes (should not expire soon)
|
|
future_time = time.time() + 600 # 10 minutes
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=future_time,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=1,
|
|
)
|
|
assert not metadata.expire_soon
|
|
|
|
# Cache that expires in 1 minute (should expire soon)
|
|
soon_time = time.time() + 60 # 1 minute
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=soon_time,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=1,
|
|
)
|
|
assert metadata.expire_soon
|
|
|
|
def test_str_representation(self):
|
|
"""Test string representation."""
|
|
current_time = time.time()
|
|
expire_time = current_time + 1800 # 30 minutes
|
|
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/test456",
|
|
expire_time=expire_time,
|
|
fingerprint="abc123",
|
|
invocations_used=7,
|
|
cached_contents_count=4,
|
|
)
|
|
|
|
str_repr = str(metadata)
|
|
assert "test456" in str_repr # Cache ID
|
|
assert "used 7 invocations" in str_repr
|
|
assert "cached 4 contents" in str_repr
|
|
assert "expires in" in str_repr
|
|
|
|
def test_immutability(self):
|
|
"""Test that CacheMetadata is immutable (frozen)."""
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=5,
|
|
cached_contents_count=3,
|
|
)
|
|
|
|
# Should not be able to modify fields
|
|
with pytest.raises(ValidationError):
|
|
metadata.invocations_used = 10
|
|
|
|
def test_model_config(self):
|
|
"""Test that model config is set correctly."""
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=5,
|
|
cached_contents_count=3,
|
|
)
|
|
|
|
assert metadata.model_config["extra"] == "forbid"
|
|
assert metadata.model_config["frozen"] == True
|
|
|
|
def test_field_descriptions(self):
|
|
"""Test that fields have proper descriptions."""
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=5,
|
|
cached_contents_count=3,
|
|
)
|
|
schema = metadata.model_json_schema()
|
|
|
|
assert "invocations_used" in schema["properties"]
|
|
assert (
|
|
"Number of invocations"
|
|
in schema["properties"]["invocations_used"]["description"]
|
|
)
|
|
|
|
assert "cached_contents_count" in schema["properties"]
|
|
assert (
|
|
"Number of contents"
|
|
in schema["properties"]["cached_contents_count"]["description"]
|
|
)
|
|
|
|
def test_realistic_cache_scenarios(self):
|
|
"""Test realistic cache scenarios."""
|
|
current_time = time.time()
|
|
|
|
# Fresh cache
|
|
fresh_cache = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/fresh123",
|
|
expire_time=current_time + 1800,
|
|
fingerprint="fresh_fingerprint",
|
|
invocations_used=1,
|
|
cached_contents_count=5,
|
|
created_at=current_time,
|
|
)
|
|
assert fresh_cache.invocations_used == 1
|
|
assert not fresh_cache.expire_soon
|
|
|
|
# Well-used cache
|
|
used_cache = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/used456",
|
|
expire_time=current_time + 600,
|
|
fingerprint="used_fingerprint",
|
|
invocations_used=8,
|
|
cached_contents_count=3,
|
|
created_at=current_time - 1200,
|
|
)
|
|
assert used_cache.invocations_used == 8
|
|
|
|
# Expiring cache
|
|
expiring_cache = CacheMetadata(
|
|
cache_name=(
|
|
"projects/123/locations/us-central1/cachedContents/expiring789"
|
|
),
|
|
expire_time=current_time + 60, # 1 minute
|
|
fingerprint="expiring_fingerprint",
|
|
invocations_used=15,
|
|
cached_contents_count=10,
|
|
)
|
|
assert expiring_cache.expire_soon
|
|
|
|
def test_cache_name_extraction(self):
|
|
"""Test cache name ID extraction in string representation."""
|
|
metadata = CacheMetadata(
|
|
cache_name=(
|
|
"projects/123/locations/us-central1/cachedContents/extracted_id"
|
|
),
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=1,
|
|
cached_contents_count=2,
|
|
)
|
|
|
|
str_repr = str(metadata)
|
|
assert "extracted_id" in str_repr
|
|
|
|
def test_no_performance_metrics(self):
|
|
"""Test that performance metrics are not in CacheMetadata."""
|
|
metadata = CacheMetadata(
|
|
cache_name="projects/123/locations/us-central1/cachedContents/456",
|
|
expire_time=time.time() + 1800,
|
|
fingerprint="abc123",
|
|
invocations_used=5,
|
|
cached_contents_count=3,
|
|
)
|
|
|
|
# Verify that token counts are NOT in CacheMetadata
|
|
# (they should be in LlmResponse.usage_metadata)
|
|
assert not hasattr(metadata, "cached_tokens")
|
|
assert not hasattr(metadata, "total_tokens")
|
|
assert not hasattr(metadata, "prompt_tokens")
|
|
|
|
def test_missing_required_fields(self):
|
|
"""Test validation when required fields are missing."""
|
|
# Test each required field
|
|
required_fields = [
|
|
"cache_name",
|
|
"expire_time",
|
|
"fingerprint",
|
|
"invocations_used",
|
|
"cached_contents_count",
|
|
]
|
|
|
|
base_args = {
|
|
"cache_name": "projects/123/locations/us-central1/cachedContents/456",
|
|
"expire_time": time.time() + 1800,
|
|
"fingerprint": "abc123",
|
|
"invocations_used": 1,
|
|
"cached_contents_count": 2,
|
|
}
|
|
|
|
for field in required_fields:
|
|
args = base_args.copy()
|
|
del args[field]
|
|
|
|
with pytest.raises(ValidationError):
|
|
CacheMetadata(**args)
|