You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Expand add_memory to accept MemoryEntry
The `add_memory` methods in `Context` and `BaseMemoryService` now accept `MemoryEntry` objects in addition to strings. The Vertex AI Memory Bank service implementation is updated to handle these new types. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 872108561
This commit is contained in:
committed by
Copybara-Service
parent
2d8b6a2f5b
commit
f27a9cfb87
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from ..memory.base_memory_service import SearchMemoryResponse
|
||||
from ..memory.memory_entry import MemoryEntry
|
||||
from ..sessions.state import State
|
||||
from ..tools.tool_confirmation import ToolConfirmation
|
||||
from .invocation_context import InvocationContext
|
||||
@@ -366,7 +367,7 @@ class Context(ReadonlyContext):
|
||||
async def add_memory(
|
||||
self,
|
||||
*,
|
||||
memories: Sequence[str],
|
||||
memories: Sequence[MemoryEntry],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds explicit memory items directly to the memory service.
|
||||
|
||||
@@ -99,7 +99,7 @@ class BaseMemoryService(ABC):
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
memories: Sequence[str],
|
||||
memories: Sequence[MemoryEntry],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds explicit memory items directly to the memory service.
|
||||
|
||||
@@ -57,6 +57,7 @@ _CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({
|
||||
'expire_time',
|
||||
'http_options',
|
||||
'metadata',
|
||||
'revision_labels',
|
||||
'revision_expire_time',
|
||||
'revision_ttl',
|
||||
'topics',
|
||||
@@ -215,7 +216,7 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
memories: Sequence[str],
|
||||
memories: Sequence[MemoryEntry],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds explicit memory items via Vertex memories.create."""
|
||||
@@ -267,20 +268,29 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
memories: Sequence[str],
|
||||
memories: Sequence[MemoryEntry],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds direct memory items without server-side extraction."""
|
||||
if not self._agent_engine_id:
|
||||
raise ValueError('Agent Engine ID is required for Memory Bank.')
|
||||
|
||||
memory_texts = _validate_memory_texts(memories)
|
||||
normalized_memories = _normalize_memories_for_create(memories)
|
||||
api_client = self._get_api_client()
|
||||
config = _build_create_memory_config(custom_metadata)
|
||||
for memory_text in memory_texts:
|
||||
for index, memory in enumerate(normalized_memories):
|
||||
memory_fact = _memory_entry_to_fact(memory, index=index)
|
||||
memory_metadata = _merge_custom_metadata_for_memory(
|
||||
custom_metadata=custom_metadata,
|
||||
memory=memory,
|
||||
)
|
||||
memory_revision_labels = _revision_labels_for_memory(memory)
|
||||
config = _build_create_memory_config(
|
||||
memory_metadata,
|
||||
memory_revision_labels=memory_revision_labels,
|
||||
)
|
||||
operation = await api_client.agent_engines.memories.create(
|
||||
name='reasoningEngines/' + self._agent_engine_id,
|
||||
fact=memory_text,
|
||||
fact=memory_fact,
|
||||
scope={
|
||||
'app_name': app_name,
|
||||
'user_id': user_id,
|
||||
@@ -431,18 +441,21 @@ def _build_generate_memories_config(
|
||||
|
||||
def _build_create_memory_config(
|
||||
custom_metadata: Mapping[str, object] | None,
|
||||
*,
|
||||
memory_revision_labels: Mapping[str, str] | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""Builds a valid memories.create config from caller metadata."""
|
||||
config: dict[str, object] = {'wait_for_completion': False}
|
||||
supports_metadata = _supports_create_memory_metadata()
|
||||
config_keys = _get_create_memory_config_keys()
|
||||
if not custom_metadata:
|
||||
return config
|
||||
supports_revision_labels = 'revision_labels' in config_keys
|
||||
|
||||
logger.debug('Memory creation metadata: %s', custom_metadata)
|
||||
if custom_metadata:
|
||||
logger.debug('Memory creation metadata: %s', custom_metadata)
|
||||
|
||||
metadata_by_key: dict[str, object] = {}
|
||||
for key, value in custom_metadata.items():
|
||||
custom_revision_labels: dict[str, str] = {}
|
||||
for key, value in (custom_metadata or {}).items():
|
||||
if key == 'metadata':
|
||||
if value is None:
|
||||
continue
|
||||
@@ -460,6 +473,16 @@ def _build_create_memory_config(
|
||||
' mapping.'
|
||||
)
|
||||
continue
|
||||
if key == 'revision_labels':
|
||||
if value is None:
|
||||
continue
|
||||
extracted_labels = _extract_revision_labels(
|
||||
value,
|
||||
source='custom_metadata["revision_labels"]',
|
||||
)
|
||||
if extracted_labels:
|
||||
custom_revision_labels.update(extracted_labels)
|
||||
continue
|
||||
if key in config_keys:
|
||||
if value is None:
|
||||
continue
|
||||
@@ -467,56 +490,155 @@ def _build_create_memory_config(
|
||||
else:
|
||||
metadata_by_key[key] = value
|
||||
|
||||
if not metadata_by_key:
|
||||
return config
|
||||
if metadata_by_key:
|
||||
if not supports_metadata:
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because installed Vertex SDK does '
|
||||
'not support create config.metadata.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
else:
|
||||
existing_metadata = config.get('metadata')
|
||||
if existing_metadata is None:
|
||||
config['metadata'] = _build_vertex_metadata(metadata_by_key)
|
||||
elif isinstance(existing_metadata, Mapping):
|
||||
merged_metadata = dict(existing_metadata)
|
||||
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
|
||||
config['metadata'] = merged_metadata
|
||||
else:
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because config.metadata is not a'
|
||||
' mapping.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
|
||||
if not supports_metadata:
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because installed Vertex SDK does '
|
||||
'not support create config.metadata.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
return config
|
||||
|
||||
existing_metadata = config.get('metadata')
|
||||
if existing_metadata is None:
|
||||
config['metadata'] = _build_vertex_metadata(metadata_by_key)
|
||||
return config
|
||||
|
||||
if isinstance(existing_metadata, Mapping):
|
||||
merged_metadata = dict(existing_metadata)
|
||||
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
|
||||
config['metadata'] = merged_metadata
|
||||
return config
|
||||
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because config.metadata is not a'
|
||||
' mapping.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
revision_labels = dict(custom_revision_labels)
|
||||
if memory_revision_labels:
|
||||
revision_labels.update(memory_revision_labels)
|
||||
if revision_labels:
|
||||
if supports_revision_labels:
|
||||
config['revision_labels'] = revision_labels
|
||||
else:
|
||||
logger.warning(
|
||||
'Ignoring revision labels %s because installed Vertex SDK does not '
|
||||
'support create config.revision_labels.',
|
||||
sorted(revision_labels.keys()),
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def _validate_memory_texts(
|
||||
memories: Sequence[str],
|
||||
) -> list[str]:
|
||||
"""Validates direct textual memory items passed to add_memory."""
|
||||
def _normalize_memories_for_create(
|
||||
memories: Sequence[MemoryEntry],
|
||||
) -> list[MemoryEntry]:
|
||||
"""Validates add_memory inputs."""
|
||||
if isinstance(memories, str):
|
||||
raise TypeError('memories must be a sequence of strings.')
|
||||
raise TypeError('memories must be a sequence of memory items.')
|
||||
if not isinstance(memories, Sequence):
|
||||
raise TypeError('memories must be a sequence of strings.')
|
||||
memory_texts: list[str] = []
|
||||
for index, raw_memory in enumerate(memories):
|
||||
if not isinstance(raw_memory, str):
|
||||
raise TypeError(f'memories[{index}] must be a string.')
|
||||
memory_text = raw_memory.strip()
|
||||
if not memory_text:
|
||||
raise ValueError(f'memories[{index}] must not be empty.')
|
||||
memory_texts.append(memory_text)
|
||||
raise TypeError('memories must be a sequence of memory items.')
|
||||
|
||||
if not memory_texts:
|
||||
validated_memories: list[MemoryEntry] = []
|
||||
for index, raw_memory in enumerate(memories):
|
||||
if not isinstance(raw_memory, MemoryEntry):
|
||||
raise TypeError(f'memories[{index}] must be a MemoryEntry.')
|
||||
validated_memories.append(raw_memory)
|
||||
if not validated_memories:
|
||||
raise ValueError('memories must contain at least one entry.')
|
||||
return memory_texts
|
||||
return validated_memories
|
||||
|
||||
|
||||
def _memory_entry_to_fact(
    memory: MemoryEntry,
    *,
    index: int,
) -> str:
  """Builds a memories.create fact payload from MemoryEntry text content.

  Args:
    memory: The memory entry whose content is converted into a fact string.
    index: Position of the entry in the caller's sequence, used in errors.

  Returns:
    The stripped text of all text parts, joined with newlines.

  Raises:
    ValueError: If the content has no usable text or contains binary parts.
  """
  if _should_filter_out_event(memory.content):
    raise ValueError(f'memories[{index}] must include text.')

  collected: list[str] = []
  for part in memory.content.parts:
    # Binary payloads cannot be represented as a fact string.
    if part.inline_data or part.file_data:
      raise ValueError(
          f'memories[{index}] must include text only; inline_data and '
          'file_data are not supported.'
      )
    stripped = (part.text or '').strip()
    if stripped:
      collected.append(stripped)

  if not collected:
    raise ValueError(f'memories[{index}] must include non-whitespace text.')
  return '\n'.join(collected)
|
||||
|
||||
|
||||
def _merge_custom_metadata_for_memory(
|
||||
*,
|
||||
custom_metadata: Mapping[str, object] | None,
|
||||
memory: MemoryEntry,
|
||||
) -> Mapping[str, object] | None:
|
||||
"""Merges write-level metadata with MemoryEntry metadata."""
|
||||
merged_metadata: dict[str, object] = {}
|
||||
|
||||
if custom_metadata:
|
||||
merged_metadata.update(dict(custom_metadata))
|
||||
if memory.custom_metadata:
|
||||
merged_metadata.update(memory.custom_metadata)
|
||||
|
||||
if not merged_metadata:
|
||||
return None
|
||||
return merged_metadata
|
||||
|
||||
|
||||
def _revision_labels_for_memory(
|
||||
memory: MemoryEntry,
|
||||
) -> Mapping[str, str] | None:
|
||||
"""Builds revision labels from MemoryEntry revision metadata."""
|
||||
revision_labels: dict[str, str] = {}
|
||||
if memory.author is not None:
|
||||
revision_labels['author'] = memory.author
|
||||
if memory.timestamp is not None:
|
||||
revision_labels['timestamp'] = memory.timestamp
|
||||
|
||||
if not revision_labels:
|
||||
return None
|
||||
return revision_labels
|
||||
|
||||
|
||||
def _extract_revision_labels(
|
||||
value: object,
|
||||
*,
|
||||
source: str,
|
||||
) -> Mapping[str, str] | None:
|
||||
"""Extracts revision labels from config metadata."""
|
||||
if not isinstance(value, Mapping):
|
||||
logger.warning('Ignoring %s because it is not a mapping.', source)
|
||||
return None
|
||||
|
||||
revision_labels: dict[str, str] = {}
|
||||
for key, label_value in value.items():
|
||||
if not isinstance(key, str):
|
||||
logger.warning(
|
||||
'Ignoring revision label with non-string key %r from %s.',
|
||||
key,
|
||||
source,
|
||||
)
|
||||
continue
|
||||
if not isinstance(label_value, str):
|
||||
logger.warning(
|
||||
'Ignoring revision label %s from %s because its value is not a '
|
||||
'string.',
|
||||
key,
|
||||
source,
|
||||
)
|
||||
continue
|
||||
revision_labels[key] = label_value
|
||||
|
||||
if not revision_labels:
|
||||
return None
|
||||
return revision_labels
|
||||
|
||||
|
||||
def _build_vertex_metadata(
|
||||
|
||||
@@ -22,7 +22,9 @@ from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
from google.adk.memory.memory_entry import MemoryEntry
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
@@ -417,7 +419,9 @@ class TestCallbackContextAddEventsToMemory:
|
||||
"""Tests that add_memory forwards memories and metadata."""
|
||||
memory_service = AsyncMock()
|
||||
mock_invocation_context.memory_service = memory_service
|
||||
memories = ["fact one"]
|
||||
memories = [
|
||||
MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")]))
|
||||
]
|
||||
metadata = {"ttl": "6000s"}
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
@@ -430,6 +434,27 @@ class TestCallbackContextAddEventsToMemory:
|
||||
custom_metadata=metadata,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_accepts_memory_entries(
|
||||
self, mock_invocation_context
|
||||
):
|
||||
"""Tests that add_memory forwards MemoryEntry inputs unchanged."""
|
||||
memory_service = AsyncMock()
|
||||
mock_invocation_context.memory_service = memory_service
|
||||
memory_entry = MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text="fact one")])
|
||||
)
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
await context.add_memory(memories=[memory_entry])
|
||||
|
||||
memory_service.add_memory.assert_called_once_with(
|
||||
app_name=mock_invocation_context.session.app_name,
|
||||
user_id=mock_invocation_context.session.user_id,
|
||||
memories=[memory_entry],
|
||||
custom_metadata=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_no_service_raises(self, mock_invocation_context):
|
||||
"""Tests that add_memory raises ValueError with no service."""
|
||||
@@ -441,7 +466,13 @@ class TestCallbackContextAddEventsToMemory:
|
||||
ValueError,
|
||||
match=r"Cannot add memory: memory service is not available\.",
|
||||
):
|
||||
await context.add_memory(memories=["fact one"])
|
||||
await context.add_memory(
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text="fact one")])
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestToolContextAddSessionToMemory:
|
||||
|
||||
@@ -22,7 +22,9 @@ from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
from google.adk.memory.base_memory_service import SearchMemoryResponse
|
||||
from google.adk.memory.memory_entry import MemoryEntry
|
||||
from google.adk.tools.tool_confirmation import ToolConfirmation
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
@@ -492,7 +494,9 @@ class TestContextMemoryMethods:
|
||||
"""Tests that add_memory forwards memories and metadata."""
|
||||
memory_service = AsyncMock()
|
||||
mock_invocation_context.memory_service = memory_service
|
||||
memories = ["fact one"]
|
||||
memories = [
|
||||
MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")]))
|
||||
]
|
||||
metadata = {"ttl": "6000s"}
|
||||
|
||||
context = Context(mock_invocation_context)
|
||||
@@ -505,6 +509,27 @@ class TestContextMemoryMethods:
|
||||
custom_metadata=metadata,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_accepts_memory_entries(
|
||||
self, mock_invocation_context
|
||||
):
|
||||
"""Tests that add_memory forwards MemoryEntry inputs unchanged."""
|
||||
memory_service = AsyncMock()
|
||||
mock_invocation_context.memory_service = memory_service
|
||||
memory_entry = MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text="fact one")])
|
||||
)
|
||||
|
||||
context = Context(mock_invocation_context)
|
||||
await context.add_memory(memories=[memory_entry])
|
||||
|
||||
memory_service.add_memory.assert_called_once_with(
|
||||
app_name=mock_invocation_context.session.app_name,
|
||||
user_id=mock_invocation_context.session.user_id,
|
||||
memories=[memory_entry],
|
||||
custom_metadata=None,
|
||||
)
|
||||
|
||||
async def test_add_memory_no_service_raises(self, mock_invocation_context):
|
||||
"""Test that add_memory raises ValueError when no service."""
|
||||
mock_invocation_context.memory_service = None
|
||||
@@ -515,4 +540,10 @@ class TestContextMemoryMethods:
|
||||
ValueError,
|
||||
match=r"Cannot add memory: memory service is not available\.",
|
||||
):
|
||||
await context.add_memory(memories=["fact one"])
|
||||
await context.add_memory(
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text="fact one")])
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ from unittest import mock
|
||||
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.memory import vertex_ai_memory_bank_service as memory_service_module
|
||||
from google.adk.memory.memory_entry import MemoryEntry
|
||||
from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
@@ -41,6 +42,13 @@ def _supports_create_memory_metadata() -> bool:
|
||||
return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields
|
||||
|
||||
|
||||
def _supports_create_memory_revision_labels() -> bool:
  """Returns True when the installed SDK config model accepts revision_labels."""
  model_fields = vertex_common_types.AgentEngineMemoryConfig.model_fields
  return 'revision_labels' in model_fields
|
||||
|
||||
|
||||
class _AsyncListIterator:
|
||||
"""Minimal async iterator wrapper for list-like results."""
|
||||
|
||||
@@ -165,6 +173,33 @@ def test_build_create_memory_config_uses_runtime_config_keys():
|
||||
}
|
||||
|
||||
|
||||
def test_build_create_memory_config_merges_revision_labels_when_supported():
  """Metadata-supplied and per-memory revision labels are merged into config."""
  keys_patch = mock.patch.object(
      memory_service_module,
      '_get_create_memory_config_keys',
      return_value=frozenset({'wait_for_completion', 'revision_labels'}),
  )
  metadata_patch = mock.patch.object(
      memory_service_module,
      '_supports_create_memory_metadata',
      return_value=False,
  )
  with keys_patch, metadata_patch:
    config = memory_service_module._build_create_memory_config(
        {'revision_labels': {'source': 'global'}},
        memory_revision_labels={'author': 'agent'},
    )

  assert config == {
      'wait_for_completion': False,
      'revision_labels': {
          'source': 'global',
          'author': 'agent',
      },
  }
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_client():
|
||||
with mock.patch('vertexai.Client') as mock_client_constructor:
|
||||
@@ -437,7 +472,14 @@ async def test_add_memory_calls_create(
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=['fact one', 'fact two'],
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text='fact one')])
|
||||
),
|
||||
MemoryEntry(
|
||||
content=types.Content(parts=[types.Part(text='fact two')])
|
||||
),
|
||||
],
|
||||
custom_metadata={
|
||||
'ttl': '6000s',
|
||||
'source': 'agent',
|
||||
@@ -476,6 +518,176 @@ async def test_add_memory_calls_create(
|
||||
vertex_common_types.AgentEngineMemoryConfig(**create_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_calls_create_with_memory_entry_metadata(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
author='agent',
|
||||
timestamp='2026-02-13T14:46:21Z',
|
||||
content=types.Content(parts=[types.Part(text='fact one')]),
|
||||
custom_metadata={'source': 'entry'},
|
||||
)
|
||||
],
|
||||
custom_metadata={'ttl': '6000s', 'source': 'global'},
|
||||
)
|
||||
|
||||
expected_config = {
|
||||
'wait_for_completion': False,
|
||||
'ttl': '6000s',
|
||||
}
|
||||
if _supports_create_memory_metadata():
|
||||
expected_config['metadata'] = {
|
||||
'source': {'string_value': 'entry'},
|
||||
}
|
||||
if _supports_create_memory_revision_labels():
|
||||
expected_config['revision_labels'] = {
|
||||
'author': 'agent',
|
||||
'timestamp': '2026-02-13T14:46:21Z',
|
||||
}
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_awaited_once_with(
|
||||
name='reasoningEngines/123',
|
||||
fact='fact one',
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config=expected_config,
|
||||
)
|
||||
create_config = (
|
||||
mock_vertexai_client.agent_engines.memories.create.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.AgentEngineMemoryConfig(**create_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_calls_create_with_multimodal_content(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
r'memories\[0\] must include text only; inline_data and file_data '
|
||||
r'are not supported'
|
||||
),
|
||||
):
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(text='caption'),
|
||||
types.Part(
|
||||
file_data=types.FileData(
|
||||
mime_type='image/png',
|
||||
file_uri='gs://bucket/image.png',
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_with_missing_text_raises(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r'memories\[0\] must include text',
|
||||
):
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(name='tool')
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_memory_with_whitespace_only_text_raises(
    mock_vertexai_client,
):
  """add_memory rejects entries whose text parts are all whitespace."""
  memory_service = mock_vertex_ai_memory_bank_service()
  whitespace_entry = MemoryEntry(
      content=types.Content(parts=[types.Part(text=' ')])
  )
  with pytest.raises(
      ValueError,
      match=r'memories\[0\] must include non-whitespace text',
  ):
    await memory_service.add_memory(
        app_name=MOCK_SESSION.app_name,
        user_id=MOCK_SESSION.user_id,
        memories=[whitespace_entry],
    )

  # Neither extraction nor direct creation should have been attempted.
  mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
  mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_with_whitespace_and_non_text_parts_raises(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
r'memories\[0\] must include text only; inline_data and file_data '
|
||||
r'are not supported'
|
||||
),
|
||||
):
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(text=' '),
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
mime_type='image/png',
|
||||
data=b'abc',
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_missing_memories_raises(
|
||||
mock_vertexai_client,
|
||||
@@ -498,7 +710,10 @@ async def test_add_memory_with_invalid_memory_type_raises(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
with pytest.raises(TypeError, match=r'memories\[0\] must be a string'):
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=r'memories\[0\] must be a MemoryEntry',
|
||||
):
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
@@ -508,6 +723,25 @@ async def test_add_memory_with_invalid_memory_type_raises(
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_memory_with_content_type_raises(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=r'memories\[0\] must be a MemoryEntry',
|
||||
):
|
||||
await memory_service.add_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
memories=[types.Content(parts=[types.Part(text='fact one')])],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_empty_session_to_memory(mock_vertexai_client):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
|
||||
Reference in New Issue
Block a user