From f27a9cfb87caecb8d52967c50637ed5ad541cd07 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 18 Feb 2026 17:05:46 -0800 Subject: [PATCH] fix: Expand add_memory to accept MemoryEntry The `add_memory` methods in `Context` and `BaseMemoryService` now accept `MemoryEntry` objects in addition to strings. The Vertex AI Memory Bank service implementation is updated to handle these new types Co-authored-by: George Weale PiperOrigin-RevId: 872108561 --- src/google/adk/agents/context.py | 3 +- src/google/adk/memory/base_memory_service.py | 2 +- .../memory/vertex_ai_memory_bank_service.py | 226 +++++++++++++---- .../unittests/agents/test_callback_context.py | 35 ++- tests/unittests/agents/test_context.py | 35 ++- .../test_vertex_ai_memory_bank_service.py | 238 +++++++++++++++++- 6 files changed, 479 insertions(+), 60 deletions(-) diff --git a/src/google/adk/agents/context.py b/src/google/adk/agents/context.py index 70dfa05f..2e805d31 100644 --- a/src/google/adk/agents/context.py +++ b/src/google/adk/agents/context.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from ..events.event import Event from ..events.event_actions import EventActions from ..memory.base_memory_service import SearchMemoryResponse + from ..memory.memory_entry import MemoryEntry from ..sessions.state import State from ..tools.tool_confirmation import ToolConfirmation from .invocation_context import InvocationContext @@ -366,7 +367,7 @@ class Context(ReadonlyContext): async def add_memory( self, *, - memories: Sequence[str], + memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds explicit memory items directly to the memory service. diff --git a/src/google/adk/memory/base_memory_service.py b/src/google/adk/memory/base_memory_service.py index 0e831b41..55b4e8d0 100644 --- a/src/google/adk/memory/base_memory_service.py +++ b/src/google/adk/memory/base_memory_service.py @@ -99,7 +99,7 @@ class BaseMemoryService(ABC): *, app_name: str, user_id: str, - memories: Sequence[str], + memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds explicit memory items directly to the memory service. diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 0c341a6a..7bb18efa 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -57,6 +57,7 @@ _CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({ 'expire_time', 'http_options', 'metadata', + 'revision_labels', 'revision_expire_time', 'revision_ttl', 'topics', @@ -215,7 +216,7 @@ class VertexAiMemoryBankService(BaseMemoryService): *, app_name: str, user_id: str, - memories: Sequence[str], + memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds explicit memory items via Vertex memories.create.""" @@ -267,20 +268,29 @@ class VertexAiMemoryBankService(BaseMemoryService): *, app_name: str, user_id: str, - memories: Sequence[str], + memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds direct memory items without server-side extraction.""" if not self._agent_engine_id: raise ValueError('Agent Engine ID is required for Memory Bank.') - memory_texts = _validate_memory_texts(memories) + normalized_memories = _normalize_memories_for_create(memories) api_client = self._get_api_client() - config = _build_create_memory_config(custom_metadata) - for memory_text in memory_texts: + for index, memory in enumerate(normalized_memories): + memory_fact = _memory_entry_to_fact(memory, index=index) + memory_metadata = _merge_custom_metadata_for_memory( + custom_metadata=custom_metadata, + memory=memory, + ) + memory_revision_labels = _revision_labels_for_memory(memory) + config = _build_create_memory_config( + memory_metadata, + memory_revision_labels=memory_revision_labels, + ) operation = await api_client.agent_engines.memories.create( name='reasoningEngines/' + self._agent_engine_id, - fact=memory_text, + fact=memory_fact, scope={ 'app_name': app_name, 'user_id': user_id, @@ -431,18 +441,21 @@ def _build_generate_memories_config( def _build_create_memory_config( custom_metadata: Mapping[str, object] | None, + *, + memory_revision_labels: Mapping[str, str] | None = None, ) -> dict[str, object]: """Builds a valid memories.create config from caller metadata.""" config: dict[str, object] = {'wait_for_completion': False} supports_metadata = _supports_create_memory_metadata() config_keys = _get_create_memory_config_keys() - if not custom_metadata: - return config + supports_revision_labels = 'revision_labels' in config_keys - logger.debug('Memory creation metadata: %s', custom_metadata) + if custom_metadata: + logger.debug('Memory creation metadata: %s', custom_metadata) metadata_by_key: dict[str, object] = {} - for key, value in custom_metadata.items(): + custom_revision_labels: dict[str, str] = {} + for key, value in (custom_metadata or {}).items(): if key == 'metadata': if value is None: continue @@ -460,6 +473,16 @@ def _build_create_memory_config( ' mapping.' ) continue + if key == 'revision_labels': + if value is None: + continue + extracted_labels = _extract_revision_labels( + value, + source='custom_metadata["revision_labels"]', + ) + if extracted_labels: + custom_revision_labels.update(extracted_labels) + continue if key in config_keys: if value is None: continue @@ -467,56 +490,155 @@ def _build_create_memory_config( else: metadata_by_key[key] = value - if not metadata_by_key: - return config + if metadata_by_key: + if not supports_metadata: + logger.warning( + 'Ignoring custom metadata keys %s because installed Vertex SDK does ' + 'not support create config.metadata.', + sorted(metadata_by_key.keys()), + ) + else: + existing_metadata = config.get('metadata') + if existing_metadata is None: + config['metadata'] = _build_vertex_metadata(metadata_by_key) + elif isinstance(existing_metadata, Mapping): + merged_metadata = dict(existing_metadata) + merged_metadata.update(_build_vertex_metadata(metadata_by_key)) + config['metadata'] = merged_metadata + else: + logger.warning( + 'Ignoring custom metadata keys %s because config.metadata is not a' + ' mapping.', + sorted(metadata_by_key.keys()), + ) - if not supports_metadata: - logger.warning( - 'Ignoring custom metadata keys %s because installed Vertex SDK does ' - 'not support create config.metadata.', - sorted(metadata_by_key.keys()), - ) - return config - - existing_metadata = config.get('metadata') - if existing_metadata is None: - config['metadata'] = _build_vertex_metadata(metadata_by_key) - return config - - if isinstance(existing_metadata, Mapping): - merged_metadata = dict(existing_metadata) - merged_metadata.update(_build_vertex_metadata(metadata_by_key)) - config['metadata'] = merged_metadata - return config - - logger.warning( - 'Ignoring custom metadata keys %s because config.metadata is not a' - ' mapping.', - sorted(metadata_by_key.keys()), - ) + revision_labels = dict(custom_revision_labels) + if memory_revision_labels: + revision_labels.update(memory_revision_labels) + if revision_labels: + if supports_revision_labels: + config['revision_labels'] = revision_labels + else: + logger.warning( + 'Ignoring revision labels %s because installed Vertex SDK does not ' + 'support create config.revision_labels.', + sorted(revision_labels.keys()), + ) return config -def _validate_memory_texts( - memories: Sequence[str], -) -> list[str]: - """Validates direct textual memory items passed to add_memory.""" +def _normalize_memories_for_create( + memories: Sequence[MemoryEntry], +) -> list[MemoryEntry]: + """Validates add_memory inputs.""" if isinstance(memories, str): - raise TypeError('memories must be a sequence of strings.') + raise TypeError('memories must be a sequence of memory items.') if not isinstance(memories, Sequence): - raise TypeError('memories must be a sequence of strings.') - memory_texts: list[str] = [] - for index, raw_memory in enumerate(memories): - if not isinstance(raw_memory, str): - raise TypeError(f'memories[{index}] must be a string.') - memory_text = raw_memory.strip() - if not memory_text: - raise ValueError(f'memories[{index}] must not be empty.') - memory_texts.append(memory_text) + raise TypeError('memories must be a sequence of memory items.') - if not memory_texts: + validated_memories: list[MemoryEntry] = [] + for index, raw_memory in enumerate(memories): + if not isinstance(raw_memory, MemoryEntry): + raise TypeError(f'memories[{index}] must be a MemoryEntry.') + validated_memories.append(raw_memory) + if not validated_memories: raise ValueError('memories must contain at least one entry.') - return memory_texts + return validated_memories + + +def _memory_entry_to_fact( + memory: MemoryEntry, + *, + index: int, +) -> str: + """Builds a memories.create fact payload from MemoryEntry text content.""" + if _should_filter_out_event(memory.content): + raise ValueError(f'memories[{index}] must include text.') + + text_parts: list[str] = [] + for part in memory.content.parts: + if part.inline_data or part.file_data: + raise ValueError( + f'memories[{index}] must include text only; inline_data and ' + 'file_data are not supported.' + ) + + if not part.text: + continue + stripped_text = part.text.strip() + if stripped_text: + text_parts.append(stripped_text) + + if not text_parts: + raise ValueError(f'memories[{index}] must include non-whitespace text.') + return '\n'.join(text_parts) + + +def _merge_custom_metadata_for_memory( + *, + custom_metadata: Mapping[str, object] | None, + memory: MemoryEntry, +) -> Mapping[str, object] | None: + """Merges write-level metadata with MemoryEntry metadata.""" + merged_metadata: dict[str, object] = {} + + if custom_metadata: + merged_metadata.update(dict(custom_metadata)) + if memory.custom_metadata: + merged_metadata.update(memory.custom_metadata) + + if not merged_metadata: + return None + return merged_metadata + + +def _revision_labels_for_memory( + memory: MemoryEntry, +) -> Mapping[str, str] | None: + """Builds revision labels from MemoryEntry revision metadata.""" + revision_labels: dict[str, str] = {} + if memory.author is not None: + revision_labels['author'] = memory.author + if memory.timestamp is not None: + revision_labels['timestamp'] = memory.timestamp + + if not revision_labels: + return None + return revision_labels + + +def _extract_revision_labels( + value: object, + *, + source: str, +) -> Mapping[str, str] | None: + """Extracts revision labels from config metadata.""" + if not isinstance(value, Mapping): + logger.warning('Ignoring %s because it is not a mapping.', source) + return None + + revision_labels: dict[str, str] = {} + for key, label_value in value.items(): + if not isinstance(key, str): + logger.warning( + 'Ignoring revision label with non-string key %r from %s.', + key, + source, + ) + continue + if not isinstance(label_value, str): + logger.warning( + 'Ignoring revision label %s from %s because its value is not a ' + 'string.', + key, + source, + ) + continue + revision_labels[key] = label_value + + if not revision_labels: + return None + return revision_labels def _build_vertex_metadata( diff --git a/tests/unittests/agents/test_callback_context.py b/tests/unittests/agents/test_callback_context.py index 00e9d6b2..28465c2e 100644 --- a/tests/unittests/agents/test_callback_context.py +++ b/tests/unittests/agents/test_callback_context.py @@ -22,7 +22,9 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_tool import AuthConfig +from google.adk.memory.memory_entry import MemoryEntry from google.adk.tools.tool_context import ToolContext +from google.genai import types from google.genai.types import Part import pytest @@ -417,7 +419,9 @@ class TestCallbackContextAddEventsToMemory: """Tests that add_memory forwards memories and metadata.""" memory_service = AsyncMock() mock_invocation_context.memory_service = memory_service - memories = ["fact one"] + memories = [ + MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")])) + ] metadata = {"ttl": "6000s"} context = CallbackContext(mock_invocation_context) @@ -430,6 +434,27 @@ class TestCallbackContextAddEventsToMemory: custom_metadata=metadata, ) + @pytest.mark.asyncio + async def test_add_memory_accepts_memory_entries( + self, mock_invocation_context + ): + """Tests that add_memory forwards MemoryEntry inputs unchanged.""" + memory_service = AsyncMock() + mock_invocation_context.memory_service = memory_service + memory_entry = MemoryEntry( + content=types.Content(parts=[types.Part(text="fact one")]) + ) + + context = CallbackContext(mock_invocation_context) + await context.add_memory(memories=[memory_entry]) + + memory_service.add_memory.assert_called_once_with( + app_name=mock_invocation_context.session.app_name, + user_id=mock_invocation_context.session.user_id, + memories=[memory_entry], + custom_metadata=None, + ) + @pytest.mark.asyncio async def test_add_memory_no_service_raises(self, mock_invocation_context): """Tests that add_memory raises ValueError with no service.""" @@ -441,7 +466,13 @@ class TestCallbackContextAddEventsToMemory: ValueError, match=r"Cannot add memory: memory service is not available\.", ): - await context.add_memory(memories=["fact one"]) + await context.add_memory( + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text="fact one")]) + ) + ] + ) class TestToolContextAddSessionToMemory: diff --git a/tests/unittests/agents/test_context.py b/tests/unittests/agents/test_context.py index a2f57abe..1f9e67fb 100644 --- a/tests/unittests/agents/test_context.py +++ b/tests/unittests/agents/test_context.py @@ -22,7 +22,9 @@ from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_tool import AuthConfig from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry from google.adk.tools.tool_confirmation import ToolConfirmation +from google.genai import types from google.genai.types import Part import pytest @@ -492,7 +494,9 @@ class TestContextMemoryMethods: """Tests that add_memory forwards memories and metadata.""" memory_service = AsyncMock() mock_invocation_context.memory_service = memory_service - memories = ["fact one"] + memories = [ + MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")])) + ] metadata = {"ttl": "6000s"} context = Context(mock_invocation_context) @@ -505,6 +509,27 @@ class TestContextMemoryMethods: custom_metadata=metadata, ) + @pytest.mark.asyncio + async def test_add_memory_accepts_memory_entries( + self, mock_invocation_context + ): + """Tests that add_memory forwards MemoryEntry inputs unchanged.""" + memory_service = AsyncMock() + mock_invocation_context.memory_service = memory_service + memory_entry = MemoryEntry( + content=types.Content(parts=[types.Part(text="fact one")]) + ) + + context = Context(mock_invocation_context) + await context.add_memory(memories=[memory_entry]) + + memory_service.add_memory.assert_called_once_with( + app_name=mock_invocation_context.session.app_name, + user_id=mock_invocation_context.session.user_id, + memories=[memory_entry], + custom_metadata=None, + ) + async def test_add_memory_no_service_raises(self, mock_invocation_context): """Test that add_memory raises ValueError when no service.""" mock_invocation_context.memory_service = None @@ -515,4 +540,10 @@ class TestContextMemoryMethods: ValueError, match=r"Cannot add memory: memory service is not available\.", ): - await context.add_memory(memories=["fact one"]) + await context.add_memory( + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text="fact one")]) + ) + ] + ) diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index f69ebd03..6f342a08 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -20,6 +20,7 @@ from unittest import mock from google.adk.events.event import Event from google.adk.memory import vertex_ai_memory_bank_service as memory_service_module +from google.adk.memory.memory_entry import MemoryEntry from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from google.adk.sessions.session import Session from google.genai import types @@ -41,6 +42,13 @@ def _supports_create_memory_metadata() -> bool: return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields +def _supports_create_memory_revision_labels() -> bool: + return ( + 'revision_labels' + in vertex_common_types.AgentEngineMemoryConfig.model_fields + ) + + class _AsyncListIterator: """Minimal async iterator wrapper for list-like results.""" @@ -165,6 +173,33 @@ def test_build_create_memory_config_uses_runtime_config_keys(): } +def test_build_create_memory_config_merges_revision_labels_when_supported(): + with ( + mock.patch.object( + memory_service_module, + '_get_create_memory_config_keys', + return_value=frozenset({'wait_for_completion', 'revision_labels'}), + ), + mock.patch.object( + memory_service_module, + '_supports_create_memory_metadata', + return_value=False, + ), + ): + config = memory_service_module._build_create_memory_config( + {'revision_labels': {'source': 'global'}}, + memory_revision_labels={'author': 'agent'}, + ) + + assert config == { + 'wait_for_completion': False, + 'revision_labels': { + 'source': 'global', + 'author': 'agent', + }, + } + + @pytest.fixture def mock_vertexai_client(): with mock.patch('vertexai.Client') as mock_client_constructor: @@ -437,7 +472,14 @@ async def test_add_memory_calls_create( await memory_service.add_memory( app_name=MOCK_SESSION.app_name, user_id=MOCK_SESSION.user_id, - memories=['fact one', 'fact two'], + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + ], custom_metadata={ 'ttl': '6000s', 'source': 'agent', @@ -476,6 +518,176 @@ async def test_add_memory_calls_create( vertex_common_types.AgentEngineMemoryConfig(**create_config) +@pytest.mark.asyncio +async def test_add_memory_calls_create_with_memory_entry_metadata( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + author='agent', + timestamp='2026-02-13T14:46:21Z', + content=types.Content(parts=[types.Part(text='fact one')]), + custom_metadata={'source': 'entry'}, + ) + ], + custom_metadata={'ttl': '6000s', 'source': 'global'}, + ) + + expected_config = { + 'wait_for_completion': False, + 'ttl': '6000s', + } + if _supports_create_memory_metadata(): + expected_config['metadata'] = { + 'source': {'string_value': 'entry'}, + } + if _supports_create_memory_revision_labels(): + expected_config['revision_labels'] = { + 'author': 'agent', + 'timestamp': '2026-02-13T14:46:21Z', + } + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_awaited_once_with( + name='reasoningEngines/123', + fact='fact one', + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + create_config = ( + mock_vertexai_client.agent_engines.memories.create.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.AgentEngineMemoryConfig(**create_config) + + +@pytest.mark.asyncio +async def test_add_memory_calls_create_with_multimodal_content( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + ValueError, + match=( + r'memories\[0\] must include text only; inline_data and file_data ' + r'are not supported' + ), + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content( + parts=[ + types.Part(text='caption'), + types.Part( + file_data=types.FileData( + mime_type='image/png', + file_uri='gs://bucket/image.png', + ) + ), + ] + ) + ) + ], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_with_missing_text_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + ValueError, + match=r'memories\[0\] must include text', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(name='tool') + ) + ] + ) + ) + ], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_with_whitespace_only_text_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + ValueError, + match=r'memories\[0\] must include non-whitespace text', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry(content=types.Content(parts=[types.Part(text=' ')])) + ], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_with_whitespace_and_non_text_parts_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + ValueError, + match=( + r'memories\[0\] must include text only; inline_data and file_data ' + r'are not supported' + ), + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content( + parts=[ + types.Part(text=' '), + types.Part( + inline_data=types.Blob( + mime_type='image/png', + data=b'abc', + ) + ), + ] + ) + ) + ], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_memory_missing_memories_raises( mock_vertexai_client, @@ -498,7 +710,10 @@ async def test_add_memory_with_invalid_memory_type_raises( mock_vertexai_client, ): memory_service = mock_vertex_ai_memory_bank_service() - with pytest.raises(TypeError, match=r'memories\[0\] must be a string'): + with pytest.raises( + TypeError, + match=r'memories\[0\] must be a MemoryEntry', + ): await memory_service.add_memory( app_name=MOCK_SESSION.app_name, user_id=MOCK_SESSION.user_id, @@ -508,6 +723,25 @@ async def test_add_memory_with_invalid_memory_type_raises( mock_vertexai_client.agent_engines.memories.create.assert_not_called() +@pytest.mark.asyncio +async def test_add_memory_with_content_type_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + TypeError, + match=r'memories\[0\] must be a MemoryEntry', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[types.Content(parts=[types.Part(text='fact one')])], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_empty_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service()