fix: Expand add_memory to accept MemoryEntry

The `add_memory` methods in `Context` and `BaseMemoryService` now accept `MemoryEntry` objects in addition to strings. The Vertex AI Memory Bank service implementation is updated to handle these new types

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 872108561
This commit is contained in:
George Weale
2026-02-18 17:05:46 -08:00
committed by Copybara-Service
parent 2d8b6a2f5b
commit f27a9cfb87
6 changed files with 479 additions and 60 deletions
+2 -1
View File
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from ..events.event import Event
from ..events.event_actions import EventActions
from ..memory.base_memory_service import SearchMemoryResponse
from ..memory.memory_entry import MemoryEntry
from ..sessions.state import State
from ..tools.tool_confirmation import ToolConfirmation
from .invocation_context import InvocationContext
@@ -366,7 +367,7 @@ class Context(ReadonlyContext):
async def add_memory(
self,
*,
memories: Sequence[str],
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds explicit memory items directly to the memory service.
+1 -1
View File
@@ -99,7 +99,7 @@ class BaseMemoryService(ABC):
*,
app_name: str,
user_id: str,
memories: Sequence[str],
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds explicit memory items directly to the memory service.
@@ -57,6 +57,7 @@ _CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({
'expire_time',
'http_options',
'metadata',
'revision_labels',
'revision_expire_time',
'revision_ttl',
'topics',
@@ -215,7 +216,7 @@ class VertexAiMemoryBankService(BaseMemoryService):
*,
app_name: str,
user_id: str,
memories: Sequence[str],
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds explicit memory items via Vertex memories.create."""
@@ -267,20 +268,29 @@ class VertexAiMemoryBankService(BaseMemoryService):
*,
app_name: str,
user_id: str,
memories: Sequence[str],
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds direct memory items without server-side extraction."""
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')
memory_texts = _validate_memory_texts(memories)
normalized_memories = _normalize_memories_for_create(memories)
api_client = self._get_api_client()
config = _build_create_memory_config(custom_metadata)
for memory_text in memory_texts:
for index, memory in enumerate(normalized_memories):
memory_fact = _memory_entry_to_fact(memory, index=index)
memory_metadata = _merge_custom_metadata_for_memory(
custom_metadata=custom_metadata,
memory=memory,
)
memory_revision_labels = _revision_labels_for_memory(memory)
config = _build_create_memory_config(
memory_metadata,
memory_revision_labels=memory_revision_labels,
)
operation = await api_client.agent_engines.memories.create(
name='reasoningEngines/' + self._agent_engine_id,
fact=memory_text,
fact=memory_fact,
scope={
'app_name': app_name,
'user_id': user_id,
@@ -431,18 +441,21 @@ def _build_generate_memories_config(
def _build_create_memory_config(
custom_metadata: Mapping[str, object] | None,
*,
memory_revision_labels: Mapping[str, str] | None = None,
) -> dict[str, object]:
"""Builds a valid memories.create config from caller metadata."""
config: dict[str, object] = {'wait_for_completion': False}
supports_metadata = _supports_create_memory_metadata()
config_keys = _get_create_memory_config_keys()
if not custom_metadata:
return config
supports_revision_labels = 'revision_labels' in config_keys
logger.debug('Memory creation metadata: %s', custom_metadata)
if custom_metadata:
logger.debug('Memory creation metadata: %s', custom_metadata)
metadata_by_key: dict[str, object] = {}
for key, value in custom_metadata.items():
custom_revision_labels: dict[str, str] = {}
for key, value in (custom_metadata or {}).items():
if key == 'metadata':
if value is None:
continue
@@ -460,6 +473,16 @@ def _build_create_memory_config(
' mapping.'
)
continue
if key == 'revision_labels':
if value is None:
continue
extracted_labels = _extract_revision_labels(
value,
source='custom_metadata["revision_labels"]',
)
if extracted_labels:
custom_revision_labels.update(extracted_labels)
continue
if key in config_keys:
if value is None:
continue
@@ -467,56 +490,155 @@ def _build_create_memory_config(
else:
metadata_by_key[key] = value
if not metadata_by_key:
return config
if metadata_by_key:
if not supports_metadata:
logger.warning(
'Ignoring custom metadata keys %s because installed Vertex SDK does '
'not support create config.metadata.',
sorted(metadata_by_key.keys()),
)
else:
existing_metadata = config.get('metadata')
if existing_metadata is None:
config['metadata'] = _build_vertex_metadata(metadata_by_key)
elif isinstance(existing_metadata, Mapping):
merged_metadata = dict(existing_metadata)
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
config['metadata'] = merged_metadata
else:
logger.warning(
'Ignoring custom metadata keys %s because config.metadata is not a'
' mapping.',
sorted(metadata_by_key.keys()),
)
if not supports_metadata:
logger.warning(
'Ignoring custom metadata keys %s because installed Vertex SDK does '
'not support create config.metadata.',
sorted(metadata_by_key.keys()),
)
return config
existing_metadata = config.get('metadata')
if existing_metadata is None:
config['metadata'] = _build_vertex_metadata(metadata_by_key)
return config
if isinstance(existing_metadata, Mapping):
merged_metadata = dict(existing_metadata)
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
config['metadata'] = merged_metadata
return config
logger.warning(
'Ignoring custom metadata keys %s because config.metadata is not a'
' mapping.',
sorted(metadata_by_key.keys()),
)
revision_labels = dict(custom_revision_labels)
if memory_revision_labels:
revision_labels.update(memory_revision_labels)
if revision_labels:
if supports_revision_labels:
config['revision_labels'] = revision_labels
else:
logger.warning(
'Ignoring revision labels %s because installed Vertex SDK does not '
'support create config.revision_labels.',
sorted(revision_labels.keys()),
)
return config
def _validate_memory_texts(
memories: Sequence[str],
) -> list[str]:
"""Validates direct textual memory items passed to add_memory."""
def _normalize_memories_for_create(
memories: Sequence[MemoryEntry],
) -> list[MemoryEntry]:
"""Validates add_memory inputs."""
if isinstance(memories, str):
raise TypeError('memories must be a sequence of strings.')
raise TypeError('memories must be a sequence of memory items.')
if not isinstance(memories, Sequence):
raise TypeError('memories must be a sequence of strings.')
memory_texts: list[str] = []
for index, raw_memory in enumerate(memories):
if not isinstance(raw_memory, str):
raise TypeError(f'memories[{index}] must be a string.')
memory_text = raw_memory.strip()
if not memory_text:
raise ValueError(f'memories[{index}] must not be empty.')
memory_texts.append(memory_text)
raise TypeError('memories must be a sequence of memory items.')
if not memory_texts:
validated_memories: list[MemoryEntry] = []
for index, raw_memory in enumerate(memories):
if not isinstance(raw_memory, MemoryEntry):
raise TypeError(f'memories[{index}] must be a MemoryEntry.')
validated_memories.append(raw_memory)
if not validated_memories:
raise ValueError('memories must contain at least one entry.')
return memory_texts
return validated_memories
def _memory_entry_to_fact(
memory: MemoryEntry,
*,
index: int,
) -> str:
"""Builds a memories.create fact payload from MemoryEntry text content."""
if _should_filter_out_event(memory.content):
raise ValueError(f'memories[{index}] must include text.')
text_parts: list[str] = []
for part in memory.content.parts:
if part.inline_data or part.file_data:
raise ValueError(
f'memories[{index}] must include text only; inline_data and '
'file_data are not supported.'
)
if not part.text:
continue
stripped_text = part.text.strip()
if stripped_text:
text_parts.append(stripped_text)
if not text_parts:
raise ValueError(f'memories[{index}] must include non-whitespace text.')
return '\n'.join(text_parts)
def _merge_custom_metadata_for_memory(
*,
custom_metadata: Mapping[str, object] | None,
memory: MemoryEntry,
) -> Mapping[str, object] | None:
"""Merges write-level metadata with MemoryEntry metadata."""
merged_metadata: dict[str, object] = {}
if custom_metadata:
merged_metadata.update(dict(custom_metadata))
if memory.custom_metadata:
merged_metadata.update(memory.custom_metadata)
if not merged_metadata:
return None
return merged_metadata
def _revision_labels_for_memory(
memory: MemoryEntry,
) -> Mapping[str, str] | None:
"""Builds revision labels from MemoryEntry revision metadata."""
revision_labels: dict[str, str] = {}
if memory.author is not None:
revision_labels['author'] = memory.author
if memory.timestamp is not None:
revision_labels['timestamp'] = memory.timestamp
if not revision_labels:
return None
return revision_labels
def _extract_revision_labels(
value: object,
*,
source: str,
) -> Mapping[str, str] | None:
"""Extracts revision labels from config metadata."""
if not isinstance(value, Mapping):
logger.warning('Ignoring %s because it is not a mapping.', source)
return None
revision_labels: dict[str, str] = {}
for key, label_value in value.items():
if not isinstance(key, str):
logger.warning(
'Ignoring revision label with non-string key %r from %s.',
key,
source,
)
continue
if not isinstance(label_value, str):
logger.warning(
'Ignoring revision label %s from %s because its value is not a '
'string.',
key,
source,
)
continue
revision_labels[key] = label_value
if not revision_labels:
return None
return revision_labels
def _build_vertex_metadata(
@@ -22,7 +22,9 @@ from google.adk.agents.callback_context import CallbackContext
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_tool import AuthConfig
from google.adk.memory.memory_entry import MemoryEntry
from google.adk.tools.tool_context import ToolContext
from google.genai import types
from google.genai.types import Part
import pytest
@@ -417,7 +419,9 @@ class TestCallbackContextAddEventsToMemory:
"""Tests that add_memory forwards memories and metadata."""
memory_service = AsyncMock()
mock_invocation_context.memory_service = memory_service
memories = ["fact one"]
memories = [
MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")]))
]
metadata = {"ttl": "6000s"}
context = CallbackContext(mock_invocation_context)
@@ -430,6 +434,27 @@ class TestCallbackContextAddEventsToMemory:
custom_metadata=metadata,
)
@pytest.mark.asyncio
async def test_add_memory_accepts_memory_entries(
self, mock_invocation_context
):
"""Tests that add_memory forwards MemoryEntry inputs unchanged."""
memory_service = AsyncMock()
mock_invocation_context.memory_service = memory_service
memory_entry = MemoryEntry(
content=types.Content(parts=[types.Part(text="fact one")])
)
context = CallbackContext(mock_invocation_context)
await context.add_memory(memories=[memory_entry])
memory_service.add_memory.assert_called_once_with(
app_name=mock_invocation_context.session.app_name,
user_id=mock_invocation_context.session.user_id,
memories=[memory_entry],
custom_metadata=None,
)
@pytest.mark.asyncio
async def test_add_memory_no_service_raises(self, mock_invocation_context):
"""Tests that add_memory raises ValueError with no service."""
@@ -441,7 +466,13 @@ class TestCallbackContextAddEventsToMemory:
ValueError,
match=r"Cannot add memory: memory service is not available\.",
):
await context.add_memory(memories=["fact one"])
await context.add_memory(
memories=[
MemoryEntry(
content=types.Content(parts=[types.Part(text="fact one")])
)
]
)
class TestToolContextAddSessionToMemory:
+33 -2
View File
@@ -22,7 +22,9 @@ from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_tool import AuthConfig
from google.adk.memory.base_memory_service import SearchMemoryResponse
from google.adk.memory.memory_entry import MemoryEntry
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.genai import types
from google.genai.types import Part
import pytest
@@ -492,7 +494,9 @@ class TestContextMemoryMethods:
"""Tests that add_memory forwards memories and metadata."""
memory_service = AsyncMock()
mock_invocation_context.memory_service = memory_service
memories = ["fact one"]
memories = [
MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")]))
]
metadata = {"ttl": "6000s"}
context = Context(mock_invocation_context)
@@ -505,6 +509,27 @@ class TestContextMemoryMethods:
custom_metadata=metadata,
)
@pytest.mark.asyncio
async def test_add_memory_accepts_memory_entries(
self, mock_invocation_context
):
"""Tests that add_memory forwards MemoryEntry inputs unchanged."""
memory_service = AsyncMock()
mock_invocation_context.memory_service = memory_service
memory_entry = MemoryEntry(
content=types.Content(parts=[types.Part(text="fact one")])
)
context = Context(mock_invocation_context)
await context.add_memory(memories=[memory_entry])
memory_service.add_memory.assert_called_once_with(
app_name=mock_invocation_context.session.app_name,
user_id=mock_invocation_context.session.user_id,
memories=[memory_entry],
custom_metadata=None,
)
async def test_add_memory_no_service_raises(self, mock_invocation_context):
"""Test that add_memory raises ValueError when no service."""
mock_invocation_context.memory_service = None
@@ -515,4 +540,10 @@ class TestContextMemoryMethods:
ValueError,
match=r"Cannot add memory: memory service is not available\.",
):
await context.add_memory(memories=["fact one"])
await context.add_memory(
memories=[
MemoryEntry(
content=types.Content(parts=[types.Part(text="fact one")])
)
]
)
@@ -20,6 +20,7 @@ from unittest import mock
from google.adk.events.event import Event
from google.adk.memory import vertex_ai_memory_bank_service as memory_service_module
from google.adk.memory.memory_entry import MemoryEntry
from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from google.adk.sessions.session import Session
from google.genai import types
@@ -41,6 +42,13 @@ def _supports_create_memory_metadata() -> bool:
return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields
def _supports_create_memory_revision_labels() -> bool:
return (
'revision_labels'
in vertex_common_types.AgentEngineMemoryConfig.model_fields
)
class _AsyncListIterator:
"""Minimal async iterator wrapper for list-like results."""
@@ -165,6 +173,33 @@ def test_build_create_memory_config_uses_runtime_config_keys():
}
def test_build_create_memory_config_merges_revision_labels_when_supported():
with (
mock.patch.object(
memory_service_module,
'_get_create_memory_config_keys',
return_value=frozenset({'wait_for_completion', 'revision_labels'}),
),
mock.patch.object(
memory_service_module,
'_supports_create_memory_metadata',
return_value=False,
),
):
config = memory_service_module._build_create_memory_config(
{'revision_labels': {'source': 'global'}},
memory_revision_labels={'author': 'agent'},
)
assert config == {
'wait_for_completion': False,
'revision_labels': {
'source': 'global',
'author': 'agent',
},
}
@pytest.fixture
def mock_vertexai_client():
with mock.patch('vertexai.Client') as mock_client_constructor:
@@ -437,7 +472,14 @@ async def test_add_memory_calls_create(
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=['fact one', 'fact two'],
memories=[
MemoryEntry(
content=types.Content(parts=[types.Part(text='fact one')])
),
MemoryEntry(
content=types.Content(parts=[types.Part(text='fact two')])
),
],
custom_metadata={
'ttl': '6000s',
'source': 'agent',
@@ -476,6 +518,176 @@ async def test_add_memory_calls_create(
vertex_common_types.AgentEngineMemoryConfig(**create_config)
@pytest.mark.asyncio
async def test_add_memory_calls_create_with_memory_entry_metadata(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[
MemoryEntry(
author='agent',
timestamp='2026-02-13T14:46:21Z',
content=types.Content(parts=[types.Part(text='fact one')]),
custom_metadata={'source': 'entry'},
)
],
custom_metadata={'ttl': '6000s', 'source': 'global'},
)
expected_config = {
'wait_for_completion': False,
'ttl': '6000s',
}
if _supports_create_memory_metadata():
expected_config['metadata'] = {
'source': {'string_value': 'entry'},
}
if _supports_create_memory_revision_labels():
expected_config['revision_labels'] = {
'author': 'agent',
'timestamp': '2026-02-13T14:46:21Z',
}
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_awaited_once_with(
name='reasoningEngines/123',
fact='fact one',
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config=expected_config,
)
create_config = (
mock_vertexai_client.agent_engines.memories.create.call_args.kwargs[
'config'
]
)
vertex_common_types.AgentEngineMemoryConfig(**create_config)
@pytest.mark.asyncio
async def test_add_memory_calls_create_with_multimodal_content(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(
ValueError,
match=(
r'memories\[0\] must include text only; inline_data and file_data '
r'are not supported'
),
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[
MemoryEntry(
content=types.Content(
parts=[
types.Part(text='caption'),
types.Part(
file_data=types.FileData(
mime_type='image/png',
file_uri='gs://bucket/image.png',
)
),
]
)
)
],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_memory_with_missing_text_raises(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(
ValueError,
match=r'memories\[0\] must include text',
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[
MemoryEntry(
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(name='tool')
)
]
)
)
],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_memory_with_whitespace_only_text_raises(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(
ValueError,
match=r'memories\[0\] must include non-whitespace text',
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[
MemoryEntry(content=types.Content(parts=[types.Part(text=' ')]))
],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_memory_with_whitespace_and_non_text_parts_raises(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(
ValueError,
match=(
r'memories\[0\] must include text only; inline_data and file_data '
r'are not supported'
),
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[
MemoryEntry(
content=types.Content(
parts=[
types.Part(text=' '),
types.Part(
inline_data=types.Blob(
mime_type='image/png',
data=b'abc',
)
),
]
)
)
],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_memory_missing_memories_raises(
mock_vertexai_client,
@@ -498,7 +710,10 @@ async def test_add_memory_with_invalid_memory_type_raises(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(TypeError, match=r'memories\[0\] must be a string'):
with pytest.raises(
TypeError,
match=r'memories\[0\] must be a MemoryEntry',
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
@@ -508,6 +723,25 @@ async def test_add_memory_with_invalid_memory_type_raises(
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_memory_with_content_type_raises(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
with pytest.raises(
TypeError,
match=r'memories\[0\] must be a MemoryEntry',
):
await memory_service.add_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
memories=[types.Content(parts=[types.Part(text='fact one')])],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
@pytest.mark.asyncio
async def test_add_empty_session_to_memory(mock_vertexai_client):
memory_service = mock_vertex_ai_memory_bank_service()