feat: add add_events_to_memory facade for event-delta

Adds BaseMemoryService.add_events_to_memory(session, events=..., custom_metadata=...) and CallbackContext.add_events_to_memory(events=..., custom_metadata=...) so callers can add memories from an explicit subset of ADK events.

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 868261578
This commit is contained in:
George Weale
2026-02-10 12:14:11 -08:00
committed by Copybara-Service
parent de79bf12b5
commit 59e88972ae
7 changed files with 661 additions and 14 deletions
+32
View File
@@ -14,6 +14,8 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
@@ -28,6 +30,7 @@ if TYPE_CHECKING:
from ..artifacts.base_artifact_service import ArtifactVersion
from ..auth.auth_credential import AuthCredential
from ..auth.auth_tool import AuthConfig
from ..events.event import Event
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext
@@ -219,3 +222,32 @@ class CallbackContext(ReadonlyContext):
await self._invocation_context.memory_service.add_session_to_memory(
self._invocation_context.session
)
async def add_events_to_memory(
self,
*,
events: Sequence[Event],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds an explicit list of events to the memory service.
Uses this callback's current session identifiers as memory scope.
Args:
events: Explicit events to add to memory.
custom_metadata: Optional standard metadata for memory generation.
Raises:
ValueError: If memory service is not available.
"""
if self._invocation_context.memory_service is None:
raise ValueError(
"Cannot add events to memory: memory service is not available."
)
await self._invocation_context.memory_service.add_events_to_memory(
app_name=self._invocation_context.session.app_name,
user_id=self._invocation_context.session.user_id,
session_id=self._invocation_context.session.id,
events=events,
custom_metadata=custom_metadata,
)
+38 -3
View File
@@ -17,6 +17,8 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from collections.abc import Mapping
from collections.abc import Sequence
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -25,6 +27,7 @@ from pydantic import Field
from .memory_entry import MemoryEntry
if TYPE_CHECKING:
from ..events.event import Event
from ..sessions.session import Session
@@ -41,15 +44,15 @@ class SearchMemoryResponse(BaseModel):
class BaseMemoryService(ABC):
"""Base class for memory services.
The service provides functionalities to ingest sessions into memory so that
the memory can be used for user queries.
The service provides functionality to ingest conversation history into memory
so that it can be used for user queries.
"""
@abstractmethod
async def add_session_to_memory(
self,
session: Session,
):
) -> None:
"""Adds a session to the memory service.
A session may be added multiple times during its lifetime.
@@ -58,6 +61,38 @@ class BaseMemoryService(ABC):
session: The session to add.
"""
async def add_events_to_memory(
self,
*,
app_name: str,
user_id: str,
events: Sequence[Event],
session_id: str | None = None,
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds an explicit list of events to the memory service.
This is intended for cases where callers want to persist only a subset of
events (e.g., the latest turn), rather than re-ingesting the full session.
Implementations should treat `events` as an incremental update (delta) and
must not assume it represents the full session.
Implementations may ignore `session_id` if it is not applicable.
Args:
app_name: The application name for memory scope.
user_id: The user ID for memory scope.
events: The events to add to memory.
session_id: Optional session ID for memory scope/partitioning.
custom_metadata: Optional, portable metadata for memory generation. Prefer
this for service-specific fields (e.g., TTL) that may later become
first-class API parameters.
"""
raise NotImplementedError(
"This memory service does not support adding event deltas. "
"Call add_session_to_memory(session) to ingest the full session."
)
@abstractmethod
async def search_memory(
self,
@@ -13,6 +13,8 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
import re
import threading
from typing import TYPE_CHECKING
@@ -28,8 +30,10 @@ if TYPE_CHECKING:
from ..events.event import Event
from ..sessions.session import Session
_UNKNOWN_SESSION_ID = '__unknown_session_id__'
def _user_key(app_name: str, user_id: str):
def _user_key(app_name: str, user_id: str) -> str:
return f'{app_name}/{user_id}'
@@ -56,7 +60,7 @@ class InMemoryMemoryService(BaseMemoryService):
"""
@override
async def add_session_to_memory(self, session: Session):
async def add_session_to_memory(self, session: Session) -> None:
user_key = _user_key(session.app_name, session.user_id)
with self._lock:
@@ -67,6 +71,35 @@ class InMemoryMemoryService(BaseMemoryService):
if event.content and event.content.parts
]
@override
async def add_events_to_memory(
self,
*,
app_name: str,
user_id: str,
events: Sequence[Event],
session_id: str | None = None,
custom_metadata: Mapping[str, object] | None = None,
) -> None:
_ = custom_metadata
user_key = _user_key(app_name, user_id)
scoped_session_id = session_id or _UNKNOWN_SESSION_ID
events_to_add = [
event for event in events if event.content and event.content.parts
]
with self._lock:
self._session_events[user_key] = self._session_events.get(user_key, {})
existing_events = self._session_events[user_key].get(
scoped_session_id, []
)
existing_ids = {event.id for event in existing_events}
for event in events_to_add:
if event.id not in existing_ids:
existing_events.append(event)
existing_ids.add(event.id)
self._session_events[user_key][scoped_session_id] = existing_events
@override
async def search_memory(
self, *, app_name: str, user_id: str, query: str
@@ -14,6 +14,9 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
from datetime import datetime
import logging
from typing import Optional
from typing import TYPE_CHECKING
@@ -29,10 +32,35 @@ from .memory_entry import MemoryEntry
if TYPE_CHECKING:
import vertexai
from ..events.event import Event
from ..sessions.session import Session
logger = logging.getLogger('google_adk.' + __name__)
_GENERATE_MEMORIES_CONFIG_KEYS = frozenset({
'disable_consolidation',
'disable_memory_revisions',
'http_options',
'metadata',
'metadata_merge_strategy',
'revision_expire_time',
'revision_labels',
'revision_ttl',
'wait_for_completion',
})
def _supports_generate_memories_metadata() -> bool:
"""Returns whether installed Vertex SDK supports config.metadata."""
try:
from vertexai._genai.types import common as vertex_common_types
except ImportError:
return False
return (
'metadata'
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
)
class VertexAiMemoryBankService(BaseMemoryService):
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
@@ -77,28 +105,61 @@ class VertexAiMemoryBankService(BaseMemoryService):
)
@override
async def add_session_to_memory(self, session: Session):
async def add_session_to_memory(self, session: Session) -> None:
await self._add_events_to_memory_from_events(
app_name=session.app_name,
user_id=session.user_id,
events_to_process=session.events,
)
@override
async def add_events_to_memory(
self,
*,
app_name: str,
user_id: str,
events: Sequence[Event],
session_id: str | None = None,
custom_metadata: Mapping[str, object] | None = None,
) -> None:
_ = session_id
await self._add_events_to_memory_from_events(
app_name=app_name,
user_id=user_id,
events_to_process=events,
custom_metadata=custom_metadata,
)
async def _add_events_to_memory_from_events(
self,
*,
app_name: str,
user_id: str,
events_to_process: Sequence[Event],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')
events = []
for event in session.events:
direct_events = []
for event in events_to_process:
if _should_filter_out_event(event.content):
continue
if event.content:
events.append({
direct_events.append({
'content': event.content.model_dump(exclude_none=True, mode='json')
})
if events:
if direct_events:
api_client = self._get_api_client()
config = _build_generate_memories_config(custom_metadata)
operation = await api_client.agent_engines.memories.generate(
name='reasoningEngines/' + self._agent_engine_id,
direct_contents_source={'events': events},
direct_contents_source={'events': direct_events},
scope={
'app_name': session.app_name,
'user_id': session.user_id,
'app_name': app_name,
'user_id': user_id,
},
config={'wait_for_completion': False},
config=config,
)
logger.info('Generate memory response received.')
logger.debug('Generate memory response: %s', operation)
@@ -168,3 +229,120 @@ def _should_filter_out_event(content: types.Content) -> bool:
if part.text or part.inline_data or part.file_data:
return False
return True
def _build_generate_memories_config(
custom_metadata: Mapping[str, object] | None,
) -> dict[str, object]:
"""Builds a valid memories.generate config from caller metadata."""
config: dict[str, object] = {'wait_for_completion': False}
supports_metadata = _supports_generate_memories_metadata()
if not custom_metadata:
return config
logger.debug('Memory generation metadata: %s', custom_metadata)
metadata_by_key: dict[str, object] = {}
for key, value in custom_metadata.items():
if key == 'ttl':
if value is None:
continue
if custom_metadata.get('revision_ttl') is None:
config['revision_ttl'] = value
continue
if key == 'metadata':
if value is None:
continue
if not supports_metadata:
logger.warning(
'Ignoring metadata because installed Vertex SDK does not support'
' config.metadata.'
)
continue
if isinstance(value, Mapping):
config['metadata'] = _build_vertex_metadata(value)
else:
logger.warning(
'Ignoring metadata because custom_metadata["metadata"] is not a'
' mapping.'
)
continue
if key in _GENERATE_MEMORIES_CONFIG_KEYS:
if value is None:
continue
config[key] = value
else:
metadata_by_key[key] = value
if not metadata_by_key:
return config
if not supports_metadata:
logger.warning(
'Ignoring custom metadata keys %s because installed Vertex SDK does '
'not support config.metadata.',
sorted(metadata_by_key.keys()),
)
return config
existing_metadata = config.get('metadata')
if existing_metadata is None:
config['metadata'] = _build_vertex_metadata(metadata_by_key)
return config
if isinstance(existing_metadata, Mapping):
merged_metadata = dict(existing_metadata)
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
config['metadata'] = merged_metadata
return config
logger.warning(
'Ignoring custom metadata keys %s because config.metadata is not a'
' mapping.',
sorted(metadata_by_key.keys()),
)
return config
def _build_vertex_metadata(
metadata_by_key: Mapping[str, object],
) -> dict[str, object]:
"""Converts metadata values to Vertex MemoryMetadataValue objects."""
vertex_metadata: dict[str, object] = {}
for key, value in metadata_by_key.items():
converted_value = _to_vertex_metadata_value(key, value)
if converted_value is None:
continue
vertex_metadata[key] = converted_value
return vertex_metadata
def _to_vertex_metadata_value(
key: str,
value: object,
) -> dict[str, object] | None:
"""Converts a metadata value to Vertex MemoryMetadataValue shape."""
if isinstance(value, bool):
return {'bool_value': value}
if isinstance(value, (int, float)):
return {'double_value': float(value)}
if isinstance(value, str):
return {'string_value': value}
if isinstance(value, datetime):
return {'timestamp_value': value}
if isinstance(value, Mapping):
if value.keys() <= {
'bool_value',
'double_value',
'string_value',
'timestamp_value',
}:
return dict(value)
return {'string_value': str(dict(value))}
if value is None:
logger.warning(
'Ignoring custom metadata key %s because its value is None.',
key,
)
return None
return {'string_value': str(value)}
@@ -373,6 +373,46 @@ class TestCallbackContextAddSessionToMemory:
await context.add_session_to_memory()
class TestCallbackContextAddEventsToMemory:
"""Tests add_events_to_memory in CallbackContext."""
@pytest.mark.asyncio
async def test_add_events_to_memory_success(self, mock_invocation_context):
"""Tests that add_events_to_memory calls the memory service correctly."""
memory_service = AsyncMock()
mock_invocation_context.memory_service = memory_service
test_event = MagicMock()
context = CallbackContext(mock_invocation_context)
await context.add_events_to_memory(
events=[test_event],
custom_metadata={"ttl": "6000s"},
)
memory_service.add_events_to_memory.assert_called_once_with(
app_name=mock_invocation_context.session.app_name,
user_id=mock_invocation_context.session.user_id,
session_id=mock_invocation_context.session.id,
events=[test_event],
custom_metadata={"ttl": "6000s"},
)
@pytest.mark.asyncio
async def test_add_events_to_memory_no_service_raises(
self, mock_invocation_context
):
"""Tests that add_events_to_memory raises ValueError with no service."""
mock_invocation_context.memory_service = None
context = CallbackContext(mock_invocation_context)
with pytest.raises(
ValueError,
match=r"Cannot add events to memory: memory service is not available\.",
):
await context.add_events_to_memory(events=[MagicMock()])
class TestToolContextAddSessionToMemory:
"""Test the add_session_to_memory method in ToolContext."""
@@ -118,6 +118,116 @@ async def test_add_session_to_memory():
assert session_memory[MOCK_SESSION_1.id][1].id == 'event-1c'
@pytest.mark.asyncio
async def test_add_events_to_memory_with_explicit_events():
"""Tests that add_events_to_memory can ingest an explicit event list."""
memory_service = InMemoryMemoryService()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION_1.app_name,
user_id=MOCK_SESSION_1.user_id,
session_id=MOCK_SESSION_1.id,
events=[MOCK_SESSION_1.events[0]],
)
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
session_memory = memory_service._session_events[user_key]
assert len(session_memory[MOCK_SESSION_1.id]) == 1
assert session_memory[MOCK_SESSION_1.id][0].id == 'event-1a'
@pytest.mark.asyncio
async def test_add_events_to_memory_without_session_id_uses_default_bucket():
"""Tests add_events_to_memory when no session_id is provided."""
memory_service = InMemoryMemoryService()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION_1.app_name,
user_id=MOCK_SESSION_1.user_id,
events=[MOCK_SESSION_1.events[0]],
)
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
session_memory = memory_service._session_events[user_key]
assert len(session_memory) == 1
unknown_session_events = next(iter(session_memory.values()))
assert len(unknown_session_events) == 1
assert unknown_session_events[0].id == 'event-1a'
@pytest.mark.asyncio
async def test_add_events_to_memory_alias_is_supported():
"""Tests that add_events_to_memory remains a compatibility alias."""
memory_service = InMemoryMemoryService()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION_1.app_name,
user_id=MOCK_SESSION_1.user_id,
session_id=MOCK_SESSION_1.id,
events=[MOCK_SESSION_1.events[0]],
)
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
session_memory = memory_service._session_events[user_key]
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
'event-1a'
]
@pytest.mark.asyncio
async def test_add_events_to_memory_appends_without_replacing():
"""Tests that add_events_to_memory appends events rather than replacing."""
memory_service = InMemoryMemoryService()
await memory_service.add_session_to_memory(MOCK_SESSION_1)
new_event = Event(
id='event-1d',
invocation_id='inv-6',
author='user',
timestamp=12348,
content=types.Content(parts=[types.Part(text='A new fact.')]),
)
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION_1.app_name,
user_id=MOCK_SESSION_1.user_id,
session_id=MOCK_SESSION_1.id,
events=[new_event],
)
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
session_memory = memory_service._session_events[user_key]
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
'event-1a',
'event-1c',
'event-1d',
]
@pytest.mark.asyncio
async def test_add_events_to_memory_deduplicates_event_ids():
"""Tests that duplicate event IDs are not appended multiple times."""
memory_service = InMemoryMemoryService()
await memory_service.add_session_to_memory(MOCK_SESSION_1)
duplicate_event = Event(
id='event-1a',
invocation_id='inv-7',
author='user',
timestamp=12349,
content=types.Content(parts=[types.Part(text='Updated duplicate text.')]),
)
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION_1.app_name,
user_id=MOCK_SESSION_1.user_id,
session_id=MOCK_SESSION_1.id,
events=[duplicate_event],
)
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
session_memory = memory_service._session_events[user_key]
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
'event-1a',
'event-1c',
]
@pytest.mark.asyncio
async def test_add_session_with_no_events_to_memory():
"""Tests that adding a session with no events does not cause an error."""
@@ -23,11 +23,19 @@ from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankSe
from google.adk.sessions.session import Session
from google.genai import types
import pytest
from vertexai._genai.types import common as vertex_common_types
MOCK_APP_NAME = 'test-app'
MOCK_USER_ID = 'test-user'
def _supports_generate_memories_metadata() -> bool:
return (
'metadata'
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
)
class _AsyncListIterator:
"""Minimal async iterator wrapper for list-like results."""
@@ -156,6 +164,217 @@ async def test_add_session_to_memory(mock_vertexai_client):
)
@pytest.mark.asyncio
async def test_add_events_to_memory_with_explicit_events_and_metadata(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
session_id=MOCK_SESSION.id,
events=[MOCK_SESSION.events[0]],
custom_metadata={'ttl': '6000s', 'source': 'agent'},
)
expected_config = {
'wait_for_completion': False,
'revision_ttl': '6000s',
}
if _supports_generate_memories_metadata():
expected_config['metadata'] = {'source': {'string_value': 'agent'}}
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config=expected_config,
)
generate_config = (
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
'config'
]
)
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
@pytest.mark.asyncio
async def test_add_events_to_memory_without_session_id(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
events=[MOCK_SESSION.events[0]],
)
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config={'wait_for_completion': False},
)
generate_config = (
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
'config'
]
)
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
@pytest.mark.asyncio
async def test_add_events_to_memory_merges_metadata_field_and_unknown_keys(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
session_id=MOCK_SESSION.id,
events=[MOCK_SESSION.events[0]],
custom_metadata={
'metadata': {'origin': 'unit-test'},
'source': 'agent',
},
)
expected_config = {'wait_for_completion': False}
if _supports_generate_memories_metadata():
expected_config['metadata'] = {
'origin': {'string_value': 'unit-test'},
'source': {'string_value': 'agent'},
}
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config=expected_config,
)
generate_config = (
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
'config'
]
)
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
@pytest.mark.asyncio
async def test_add_events_to_memory_none_wait_for_completion_keeps_default(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
session_id=MOCK_SESSION.id,
events=[MOCK_SESSION.events[0]],
custom_metadata={'wait_for_completion': None},
)
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config={'wait_for_completion': False},
)
generate_config = (
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
'config'
]
)
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
@pytest.mark.asyncio
async def test_add_events_to_memory_ttl_used_when_revision_ttl_is_none(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
session_id=MOCK_SESSION.id,
events=[MOCK_SESSION.events[0]],
custom_metadata={
'ttl': '6000s',
'revision_ttl': None,
},
)
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config={
'wait_for_completion': False,
'revision_ttl': '6000s',
},
)
generate_config = (
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
'config'
]
)
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
@pytest.mark.asyncio
async def test_add_events_to_memory_with_filtered_events_skips_rpc(
mock_vertexai_client,
):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_events_to_memory(
app_name=MOCK_SESSION.app_name,
user_id=MOCK_SESSION.user_id,
session_id=MOCK_SESSION.id,
events=[MOCK_SESSION.events[1], MOCK_SESSION.events[2]],
)
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
@pytest.mark.asyncio
async def test_add_empty_session_to_memory(mock_vertexai_client):
memory_service = mock_vertex_ai_memory_bank_service()