From 59e88972ae4f10274444593db0607f40cfcc597e Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 10 Feb 2026 12:14:11 -0800 Subject: [PATCH] feat: add add_events_to_memory facade for event-delta Adds BaseMemoryService.add_events_to_memory(session, events=..., custom_metadata=...) and CallbackContext.add_events_to_memory(events=..., custom_metadata=...) so callers can add memories from an explicit subset of ADK events. Co-authored-by: George Weale PiperOrigin-RevId: 868261578 --- src/google/adk/agents/callback_context.py | 32 +++ src/google/adk/memory/base_memory_service.py | 41 +++- .../adk/memory/in_memory_memory_service.py | 37 ++- .../memory/vertex_ai_memory_bank_service.py | 196 +++++++++++++++- .../unittests/agents/test_callback_context.py | 40 ++++ .../memory/test_in_memory_memory_service.py | 110 +++++++++ .../test_vertex_ai_memory_bank_service.py | 219 ++++++++++++++++++ 7 files changed, 661 insertions(+), 14 deletions(-) diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py index c9ca750d..d733540b 100644 --- a/src/google/adk/agents/callback_context.py +++ b/src/google/adk/agents/callback_context.py @@ -14,6 +14,8 @@ from __future__ import annotations +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any from typing import Optional from typing import TYPE_CHECKING @@ -28,6 +30,7 @@ if TYPE_CHECKING: from ..artifacts.base_artifact_service import ArtifactVersion from ..auth.auth_credential import AuthCredential from ..auth.auth_tool import AuthConfig + from ..events.event import Event from ..events.event_actions import EventActions from ..sessions.state import State from .invocation_context import InvocationContext @@ -219,3 +222,32 @@ class CallbackContext(ReadonlyContext): await self._invocation_context.memory_service.add_session_to_memory( self._invocation_context.session ) + + async def add_events_to_memory( + self, + *, + events: Sequence[Event], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds an explicit list of events to the memory service. + + Uses this callback's current session identifiers as memory scope. + + Args: + events: Explicit events to add to memory. + custom_metadata: Optional standard metadata for memory generation. + + Raises: + ValueError: If memory service is not available. + """ + if self._invocation_context.memory_service is None: + raise ValueError( + "Cannot add events to memory: memory service is not available." + ) + await self._invocation_context.memory_service.add_events_to_memory( + app_name=self._invocation_context.session.app_name, + user_id=self._invocation_context.session.user_id, + session_id=self._invocation_context.session.id, + events=events, + custom_metadata=custom_metadata, + ) diff --git a/src/google/adk/memory/base_memory_service.py b/src/google/adk/memory/base_memory_service.py index 231a1e04..a4ed22a3 100644 --- a/src/google/adk/memory/base_memory_service.py +++ b/src/google/adk/memory/base_memory_service.py @@ -17,6 +17,8 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod +from collections.abc import Mapping +from collections.abc import Sequence from typing import TYPE_CHECKING from pydantic import BaseModel @@ -25,6 +27,7 @@ from pydantic import Field from .memory_entry import MemoryEntry if TYPE_CHECKING: + from ..events.event import Event from ..sessions.session import Session @@ -41,15 +44,15 @@ class SearchMemoryResponse(BaseModel): class BaseMemoryService(ABC): """Base class for memory services. - The service provides functionalities to ingest sessions into memory so that - the memory can be used for user queries. + The service provides functionality to ingest conversation history into memory + so that it can be used for user queries. """ @abstractmethod async def add_session_to_memory( self, session: Session, - ): + ) -> None: """Adds a session to the memory service. A session may be added multiple times during its lifetime. @@ -58,6 +61,38 @@ class BaseMemoryService(ABC): session: The session to add. """ + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds an explicit list of events to the memory service. + + This is intended for cases where callers want to persist only a subset of + events (e.g., the latest turn), rather than re-ingesting the full session. + + Implementations should treat `events` as an incremental update (delta) and + must not assume it represents the full session. + Implementations may ignore `session_id` if it is not applicable. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for memory scope/partitioning. + custom_metadata: Optional, portable metadata for memory generation. Prefer + this for service-specific fields (e.g., TTL) that may later become + first-class API parameters. + """ + raise NotImplementedError( + "This memory service does not support adding event deltas. " + "Call add_session_to_memory(session) to ingest the full session." + ) + @abstractmethod async def search_memory( self, diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py index 306cedb9..02276598 100644 --- a/src/google/adk/memory/in_memory_memory_service.py +++ b/src/google/adk/memory/in_memory_memory_service.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +from collections.abc import Mapping +from collections.abc import Sequence import re import threading from typing import TYPE_CHECKING @@ -28,8 +30,10 @@ if TYPE_CHECKING: from ..events.event import Event from ..sessions.session import Session +_UNKNOWN_SESSION_ID = '__unknown_session_id__' -def _user_key(app_name: str, user_id: str): + +def _user_key(app_name: str, user_id: str) -> str: return f'{app_name}/{user_id}' @@ -56,7 +60,7 @@ class InMemoryMemoryService(BaseMemoryService): """ @override - async def add_session_to_memory(self, session: Session): + async def add_session_to_memory(self, session: Session) -> None: user_key = _user_key(session.app_name, session.user_id) with self._lock: @@ -67,6 +71,35 @@ class InMemoryMemoryService(BaseMemoryService): if event.content and event.content.parts ] + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + _ = custom_metadata + user_key = _user_key(app_name, user_id) + scoped_session_id = session_id or _UNKNOWN_SESSION_ID + events_to_add = [ + event for event in events if event.content and event.content.parts + ] + + with self._lock: + self._session_events[user_key] = self._session_events.get(user_key, {}) + existing_events = self._session_events[user_key].get( + scoped_session_id, [] + ) + existing_ids = {event.id for event in existing_events} + for event in events_to_add: + if event.id not in existing_ids: + existing_events.append(event) + existing_ids.add(event.id) + self._session_events[user_key][scoped_session_id] = existing_events + @override async def search_memory( self, *, app_name: str, user_id: str, query: str diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 53f5ffeb..a33095e1 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -14,6 +14,9 @@ from __future__ import annotations +from collections.abc import Mapping +from collections.abc import Sequence +from datetime import datetime import logging from typing import Optional from typing import TYPE_CHECKING @@ -29,10 +32,35 @@ from .memory_entry import MemoryEntry if TYPE_CHECKING: import vertexai + from ..events.event import Event from ..sessions.session import Session logger = logging.getLogger('google_adk.' + __name__) +_GENERATE_MEMORIES_CONFIG_KEYS = frozenset({ + 'disable_consolidation', + 'disable_memory_revisions', + 'http_options', + 'metadata', + 'metadata_merge_strategy', + 'revision_expire_time', + 'revision_labels', + 'revision_ttl', + 'wait_for_completion', +}) + + +def _supports_generate_memories_metadata() -> bool: + """Returns whether installed Vertex SDK supports config.metadata.""" + try: + from vertexai._genai.types import common as vertex_common_types + except ImportError: + return False + return ( + 'metadata' + in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields + ) + class VertexAiMemoryBankService(BaseMemoryService): """Implementation of the BaseMemoryService using Vertex AI Memory Bank.""" @@ -77,28 +105,61 @@ class VertexAiMemoryBankService(BaseMemoryService): ) @override - async def add_session_to_memory(self, session: Session): + async def add_session_to_memory(self, session: Session) -> None: + await self._add_events_to_memory_from_events( + app_name=session.app_name, + user_id=session.user_id, + events_to_process=session.events, + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + _ = session_id + await self._add_events_to_memory_from_events( + app_name=app_name, + user_id=user_id, + events_to_process=events, + custom_metadata=custom_metadata, + ) + + async def _add_events_to_memory_from_events( + self, + *, + app_name: str, + user_id: str, + events_to_process: Sequence[Event], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: if not self._agent_engine_id: raise ValueError('Agent Engine ID is required for Memory Bank.') - events = [] - for event in session.events: + direct_events = [] + for event in events_to_process: if _should_filter_out_event(event.content): continue if event.content: - events.append({ + direct_events.append({ 'content': event.content.model_dump(exclude_none=True, mode='json') }) - if events: + if direct_events: api_client = self._get_api_client() + config = _build_generate_memories_config(custom_metadata) operation = await api_client.agent_engines.memories.generate( name='reasoningEngines/' + self._agent_engine_id, - direct_contents_source={'events': events}, + direct_contents_source={'events': direct_events}, scope={ - 'app_name': session.app_name, - 'user_id': session.user_id, + 'app_name': app_name, + 'user_id': user_id, }, - config={'wait_for_completion': False}, + config=config, ) logger.info('Generate memory response received.') logger.debug('Generate memory response: %s', operation) @@ -168,3 +229,120 @@ def _should_filter_out_event(content: types.Content) -> bool: if part.text or part.inline_data or part.file_data: return False return True + + +def _build_generate_memories_config( + custom_metadata: Mapping[str, object] | None, +) -> dict[str, object]: + """Builds a valid memories.generate config from caller metadata.""" + config: dict[str, object] = {'wait_for_completion': False} + supports_metadata = _supports_generate_memories_metadata() + if not custom_metadata: + return config + + logger.debug('Memory generation metadata: %s', custom_metadata) + + metadata_by_key: dict[str, object] = {} + for key, value in custom_metadata.items(): + if key == 'ttl': + if value is None: + continue + if custom_metadata.get('revision_ttl') is None: + config['revision_ttl'] = value + continue + if key == 'metadata': + if value is None: + continue + if not supports_metadata: + logger.warning( + 'Ignoring metadata because installed Vertex SDK does not support' + ' config.metadata.' + ) + continue + if isinstance(value, Mapping): + config['metadata'] = _build_vertex_metadata(value) + else: + logger.warning( + 'Ignoring metadata because custom_metadata["metadata"] is not a' + ' mapping.' + ) + continue + if key in _GENERATE_MEMORIES_CONFIG_KEYS: + if value is None: + continue + config[key] = value + else: + metadata_by_key[key] = value + + if not metadata_by_key: + return config + + if not supports_metadata: + logger.warning( + 'Ignoring custom metadata keys %s because installed Vertex SDK does ' + 'not support config.metadata.', + sorted(metadata_by_key.keys()), + ) + return config + + existing_metadata = config.get('metadata') + if existing_metadata is None: + config['metadata'] = _build_vertex_metadata(metadata_by_key) + return config + + if isinstance(existing_metadata, Mapping): + merged_metadata = dict(existing_metadata) + merged_metadata.update(_build_vertex_metadata(metadata_by_key)) + config['metadata'] = merged_metadata + return config + + logger.warning( + 'Ignoring custom metadata keys %s because config.metadata is not a' + ' mapping.', + sorted(metadata_by_key.keys()), + ) + return config + + +def _build_vertex_metadata( + metadata_by_key: Mapping[str, object], +) -> dict[str, object]: + """Converts metadata values to Vertex MemoryMetadataValue objects.""" + vertex_metadata: dict[str, object] = {} + for key, value in metadata_by_key.items(): + converted_value = _to_vertex_metadata_value(key, value) + if converted_value is None: + continue + vertex_metadata[key] = converted_value + return vertex_metadata + + +def _to_vertex_metadata_value( + key: str, + value: object, +) -> dict[str, object] | None: + """Converts a metadata value to Vertex MemoryMetadataValue shape.""" + if isinstance(value, bool): + return {'bool_value': value} + if isinstance(value, (int, float)): + return {'double_value': float(value)} + if isinstance(value, str): + return {'string_value': value} + if isinstance(value, datetime): + return {'timestamp_value': value} + if isinstance(value, Mapping): + if value.keys() <= { + 'bool_value', + 'double_value', + 'string_value', + 'timestamp_value', + }: + return dict(value) + return {'string_value': str(dict(value))} + if value is None: + logger.warning( + 'Ignoring custom metadata key %s because its value is None.', + key, + ) + return None + return {'string_value': str(value)} diff --git a/tests/unittests/agents/test_callback_context.py b/tests/unittests/agents/test_callback_context.py index 5c0c6118..2a4241b6 100644 --- a/tests/unittests/agents/test_callback_context.py +++ b/tests/unittests/agents/test_callback_context.py @@ -373,6 +373,46 @@ class TestCallbackContextAddSessionToMemory: await context.add_session_to_memory() +class TestCallbackContextAddEventsToMemory: + """Tests add_events_to_memory in CallbackContext.""" + + @pytest.mark.asyncio + async def test_add_events_to_memory_success(self, mock_invocation_context): + """Tests that add_events_to_memory calls the memory service correctly.""" + memory_service = AsyncMock() + mock_invocation_context.memory_service = memory_service + test_event = MagicMock() + + context = CallbackContext(mock_invocation_context) + await context.add_events_to_memory( + events=[test_event], + custom_metadata={"ttl": "6000s"}, + ) + + memory_service.add_events_to_memory.assert_called_once_with( + app_name=mock_invocation_context.session.app_name, + user_id=mock_invocation_context.session.user_id, + session_id=mock_invocation_context.session.id, + events=[test_event], + custom_metadata={"ttl": "6000s"}, + ) + + @pytest.mark.asyncio + async def test_add_events_to_memory_no_service_raises( + self, mock_invocation_context + ): + """Tests that add_events_to_memory raises ValueError with no service.""" + mock_invocation_context.memory_service = None + + context = CallbackContext(mock_invocation_context) + + with pytest.raises( + ValueError, + match=r"Cannot add events to memory: memory service is not available\.", + ): + await context.add_events_to_memory(events=[MagicMock()]) + + class TestToolContextAddSessionToMemory: """Test the add_session_to_memory method in ToolContext.""" diff --git a/tests/unittests/memory/test_in_memory_memory_service.py b/tests/unittests/memory/test_in_memory_memory_service.py index 0b1b86be..d50692f0 100644 --- a/tests/unittests/memory/test_in_memory_memory_service.py +++ b/tests/unittests/memory/test_in_memory_memory_service.py @@ -118,6 +118,116 @@ async def test_add_session_to_memory(): assert session_memory[MOCK_SESSION_1.id][1].id == 'event-1c' +@pytest.mark.asyncio +async def test_add_events_to_memory_with_explicit_events(): + """Tests that add_events_to_memory can ingest an explicit event list.""" + memory_service = InMemoryMemoryService() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[MOCK_SESSION_1.events[0]], + ) + + user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}' + session_memory = memory_service._session_events[user_key] + assert len(session_memory[MOCK_SESSION_1.id]) == 1 + assert session_memory[MOCK_SESSION_1.id][0].id == 'event-1a' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_without_session_id_uses_default_bucket(): + """Tests add_events_to_memory when no session_id is provided.""" + memory_service = InMemoryMemoryService() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + events=[MOCK_SESSION_1.events[0]], + ) + + user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}' + session_memory = memory_service._session_events[user_key] + assert len(session_memory) == 1 + unknown_session_events = next(iter(session_memory.values())) + assert len(unknown_session_events) == 1 + assert unknown_session_events[0].id == 'event-1a' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_alias_is_supported(): + """Tests that add_events_to_memory remains a compatibility alias.""" + memory_service = InMemoryMemoryService() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[MOCK_SESSION_1.events[0]], + ) + + user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}' + session_memory = memory_service._session_events[user_key] + assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [ + 'event-1a' + ] + + +@pytest.mark.asyncio +async def test_add_events_to_memory_appends_without_replacing(): + """Tests that add_events_to_memory appends events rather than replacing.""" + memory_service = InMemoryMemoryService() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + new_event = Event( + id='event-1d', + invocation_id='inv-6', + author='user', + timestamp=12348, + content=types.Content(parts=[types.Part(text='A new fact.')]), + ) + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[new_event], + ) + + user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}' + session_memory = memory_service._session_events[user_key] + assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [ + 'event-1a', + 'event-1c', + 'event-1d', + ] + + +@pytest.mark.asyncio +async def test_add_events_to_memory_deduplicates_event_ids(): + """Tests that duplicate event IDs are not appended multiple times.""" + memory_service = InMemoryMemoryService() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + duplicate_event = Event( + id='event-1a', + invocation_id='inv-7', + author='user', + timestamp=12349, + content=types.Content(parts=[types.Part(text='Updated duplicate text.')]), + ) + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[duplicate_event], + ) + + user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}' + session_memory = memory_service._session_events[user_key] + assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [ + 'event-1a', + 'event-1c', + ] + + @pytest.mark.asyncio async def test_add_session_with_no_events_to_memory(): """Tests that adding a session with no events does not cause an error.""" diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 4db7905f..4bb74077 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -23,11 +23,19 @@ from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankSe from google.adk.sessions.session import Session from google.genai import types import pytest +from vertexai._genai.types import common as vertex_common_types MOCK_APP_NAME = 'test-app' MOCK_USER_ID = 'test-user' +def _supports_generate_memories_metadata() -> bool: + return ( + 'metadata' + in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields + ) + + class _AsyncListIterator: """Minimal async iterator wrapper for list-like results.""" @@ -156,6 +164,217 @@ async def test_add_session_to_memory(mock_vertexai_client): ) +@pytest.mark.asyncio +async def test_add_events_to_memory_with_explicit_events_and_metadata( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + session_id=MOCK_SESSION.id, + events=[MOCK_SESSION.events[0]], + custom_metadata={'ttl': '6000s', 'source': 'agent'}, + ) + + expected_config = { + 'wait_for_completion': False, + 'revision_ttl': '6000s', + } + if _supports_generate_memories_metadata(): + expected_config['metadata'] = {'source': {'string_value': 'agent'}} + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_events_to_memory_without_session_id( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + events=[MOCK_SESSION.events[0]], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ) + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_events_to_memory_merges_metadata_field_and_unknown_keys( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + session_id=MOCK_SESSION.id, + events=[MOCK_SESSION.events[0]], + custom_metadata={ + 'metadata': {'origin': 'unit-test'}, + 'source': 'agent', + }, + ) + + expected_config = {'wait_for_completion': False} + if _supports_generate_memories_metadata(): + expected_config['metadata'] = { + 'origin': {'string_value': 'unit-test'}, + 'source': {'string_value': 'agent'}, + } + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_events_to_memory_none_wait_for_completion_keeps_default( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + session_id=MOCK_SESSION.id, + events=[MOCK_SESSION.events[0]], + custom_metadata={'wait_for_completion': None}, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ) + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_events_to_memory_ttl_used_when_revision_ttl_is_none( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + session_id=MOCK_SESSION.id, + events=[MOCK_SESSION.events[0]], + custom_metadata={ + 'ttl': '6000s', + 'revision_ttl': None, + }, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={ + 'wait_for_completion': False, + 'revision_ttl': '6000s', + }, + ) + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_events_to_memory_with_filtered_events_skips_rpc( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + session_id=MOCK_SESSION.id, + events=[MOCK_SESSION.events[1], MOCK_SESSION.events[2]], + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + + @pytest.mark.asyncio async def test_add_empty_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service()