You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: add add_events_to_memory facade for event-delta
Adds BaseMemoryService.add_events_to_memory(session, events=..., custom_metadata=...) and CallbackContext.add_events_to_memory(events=..., custom_metadata=...) so callers can add memories from an explicit subset of ADK events. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 868261578
This commit is contained in:
committed by
Copybara-Service
parent
de79bf12b5
commit
59e88972ae
@@ -14,6 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -28,6 +30,7 @@ if TYPE_CHECKING:
|
||||
from ..artifacts.base_artifact_service import ArtifactVersion
|
||||
from ..auth.auth_credential import AuthCredential
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
from .invocation_context import InvocationContext
|
||||
@@ -219,3 +222,32 @@ class CallbackContext(ReadonlyContext):
|
||||
await self._invocation_context.memory_service.add_session_to_memory(
|
||||
self._invocation_context.session
|
||||
)
|
||||
|
||||
async def add_events_to_memory(
|
||||
self,
|
||||
*,
|
||||
events: Sequence[Event],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds an explicit list of events to the memory service.
|
||||
|
||||
Uses this callback's current session identifiers as memory scope.
|
||||
|
||||
Args:
|
||||
events: Explicit events to add to memory.
|
||||
custom_metadata: Optional standard metadata for memory generation.
|
||||
|
||||
Raises:
|
||||
ValueError: If memory service is not available.
|
||||
"""
|
||||
if self._invocation_context.memory_service is None:
|
||||
raise ValueError(
|
||||
"Cannot add events to memory: memory service is not available."
|
||||
)
|
||||
await self._invocation_context.memory_service.add_events_to_memory(
|
||||
app_name=self._invocation_context.session.app_name,
|
||||
user_id=self._invocation_context.session.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
events=events,
|
||||
custom_metadata=custom_metadata,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,8 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -25,6 +27,7 @@ from pydantic import Field
|
||||
from .memory_entry import MemoryEntry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
|
||||
|
||||
@@ -41,15 +44,15 @@ class SearchMemoryResponse(BaseModel):
|
||||
class BaseMemoryService(ABC):
|
||||
"""Base class for memory services.
|
||||
|
||||
The service provides functionalities to ingest sessions into memory so that
|
||||
the memory can be used for user queries.
|
||||
The service provides functionality to ingest conversation history into memory
|
||||
so that it can be used for user queries.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def add_session_to_memory(
|
||||
self,
|
||||
session: Session,
|
||||
):
|
||||
) -> None:
|
||||
"""Adds a session to the memory service.
|
||||
|
||||
A session may be added multiple times during its lifetime.
|
||||
@@ -58,6 +61,38 @@ class BaseMemoryService(ABC):
|
||||
session: The session to add.
|
||||
"""
|
||||
|
||||
async def add_events_to_memory(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
events: Sequence[Event],
|
||||
session_id: str | None = None,
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Adds an explicit list of events to the memory service.
|
||||
|
||||
This is intended for cases where callers want to persist only a subset of
|
||||
events (e.g., the latest turn), rather than re-ingesting the full session.
|
||||
|
||||
Implementations should treat `events` as an incremental update (delta) and
|
||||
must not assume it represents the full session.
|
||||
Implementations may ignore `session_id` if it is not applicable.
|
||||
|
||||
Args:
|
||||
app_name: The application name for memory scope.
|
||||
user_id: The user ID for memory scope.
|
||||
events: The events to add to memory.
|
||||
session_id: Optional session ID for memory scope/partitioning.
|
||||
custom_metadata: Optional, portable metadata for memory generation. Prefer
|
||||
this for service-specific fields (e.g., TTL) that may later become
|
||||
first-class API parameters.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"This memory service does not support adding event deltas. "
|
||||
"Call add_session_to_memory(session) to ingest the full session."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def search_memory(
|
||||
self,
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
import re
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -28,8 +30,10 @@ if TYPE_CHECKING:
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
|
||||
_UNKNOWN_SESSION_ID = '__unknown_session_id__'
|
||||
|
||||
def _user_key(app_name: str, user_id: str):
|
||||
|
||||
def _user_key(app_name: str, user_id: str) -> str:
|
||||
return f'{app_name}/{user_id}'
|
||||
|
||||
|
||||
@@ -56,7 +60,7 @@ class InMemoryMemoryService(BaseMemoryService):
|
||||
"""
|
||||
|
||||
@override
|
||||
async def add_session_to_memory(self, session: Session):
|
||||
async def add_session_to_memory(self, session: Session) -> None:
|
||||
user_key = _user_key(session.app_name, session.user_id)
|
||||
|
||||
with self._lock:
|
||||
@@ -67,6 +71,35 @@ class InMemoryMemoryService(BaseMemoryService):
|
||||
if event.content and event.content.parts
|
||||
]
|
||||
|
||||
@override
|
||||
async def add_events_to_memory(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
events: Sequence[Event],
|
||||
session_id: str | None = None,
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
_ = custom_metadata
|
||||
user_key = _user_key(app_name, user_id)
|
||||
scoped_session_id = session_id or _UNKNOWN_SESSION_ID
|
||||
events_to_add = [
|
||||
event for event in events if event.content and event.content.parts
|
||||
]
|
||||
|
||||
with self._lock:
|
||||
self._session_events[user_key] = self._session_events.get(user_key, {})
|
||||
existing_events = self._session_events[user_key].get(
|
||||
scoped_session_id, []
|
||||
)
|
||||
existing_ids = {event.id for event in existing_events}
|
||||
for event in events_to_add:
|
||||
if event.id not in existing_ids:
|
||||
existing_events.append(event)
|
||||
existing_ids.add(event.id)
|
||||
self._session_events[user_key][scoped_session_id] = existing_events
|
||||
|
||||
@override
|
||||
async def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -29,10 +32,35 @@ from .memory_entry import MemoryEntry
|
||||
if TYPE_CHECKING:
|
||||
import vertexai
|
||||
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
_GENERATE_MEMORIES_CONFIG_KEYS = frozenset({
|
||||
'disable_consolidation',
|
||||
'disable_memory_revisions',
|
||||
'http_options',
|
||||
'metadata',
|
||||
'metadata_merge_strategy',
|
||||
'revision_expire_time',
|
||||
'revision_labels',
|
||||
'revision_ttl',
|
||||
'wait_for_completion',
|
||||
})
|
||||
|
||||
|
||||
def _supports_generate_memories_metadata() -> bool:
|
||||
"""Returns whether installed Vertex SDK supports config.metadata."""
|
||||
try:
|
||||
from vertexai._genai.types import common as vertex_common_types
|
||||
except ImportError:
|
||||
return False
|
||||
return (
|
||||
'metadata'
|
||||
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
|
||||
)
|
||||
|
||||
|
||||
class VertexAiMemoryBankService(BaseMemoryService):
|
||||
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
|
||||
@@ -77,28 +105,61 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
)
|
||||
|
||||
@override
|
||||
async def add_session_to_memory(self, session: Session):
|
||||
async def add_session_to_memory(self, session: Session) -> None:
|
||||
await self._add_events_to_memory_from_events(
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
events_to_process=session.events,
|
||||
)
|
||||
|
||||
@override
|
||||
async def add_events_to_memory(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
events: Sequence[Event],
|
||||
session_id: str | None = None,
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
_ = session_id
|
||||
await self._add_events_to_memory_from_events(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
events_to_process=events,
|
||||
custom_metadata=custom_metadata,
|
||||
)
|
||||
|
||||
async def _add_events_to_memory_from_events(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
events_to_process: Sequence[Event],
|
||||
custom_metadata: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
if not self._agent_engine_id:
|
||||
raise ValueError('Agent Engine ID is required for Memory Bank.')
|
||||
|
||||
events = []
|
||||
for event in session.events:
|
||||
direct_events = []
|
||||
for event in events_to_process:
|
||||
if _should_filter_out_event(event.content):
|
||||
continue
|
||||
if event.content:
|
||||
events.append({
|
||||
direct_events.append({
|
||||
'content': event.content.model_dump(exclude_none=True, mode='json')
|
||||
})
|
||||
if events:
|
||||
if direct_events:
|
||||
api_client = self._get_api_client()
|
||||
config = _build_generate_memories_config(custom_metadata)
|
||||
operation = await api_client.agent_engines.memories.generate(
|
||||
name='reasoningEngines/' + self._agent_engine_id,
|
||||
direct_contents_source={'events': events},
|
||||
direct_contents_source={'events': direct_events},
|
||||
scope={
|
||||
'app_name': session.app_name,
|
||||
'user_id': session.user_id,
|
||||
'app_name': app_name,
|
||||
'user_id': user_id,
|
||||
},
|
||||
config={'wait_for_completion': False},
|
||||
config=config,
|
||||
)
|
||||
logger.info('Generate memory response received.')
|
||||
logger.debug('Generate memory response: %s', operation)
|
||||
@@ -168,3 +229,120 @@ def _should_filter_out_event(content: types.Content) -> bool:
|
||||
if part.text or part.inline_data or part.file_data:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _build_generate_memories_config(
|
||||
custom_metadata: Mapping[str, object] | None,
|
||||
) -> dict[str, object]:
|
||||
"""Builds a valid memories.generate config from caller metadata."""
|
||||
config: dict[str, object] = {'wait_for_completion': False}
|
||||
supports_metadata = _supports_generate_memories_metadata()
|
||||
if not custom_metadata:
|
||||
return config
|
||||
|
||||
logger.debug('Memory generation metadata: %s', custom_metadata)
|
||||
|
||||
metadata_by_key: dict[str, object] = {}
|
||||
for key, value in custom_metadata.items():
|
||||
if key == 'ttl':
|
||||
if value is None:
|
||||
continue
|
||||
if custom_metadata.get('revision_ttl') is None:
|
||||
config['revision_ttl'] = value
|
||||
continue
|
||||
if key == 'metadata':
|
||||
if value is None:
|
||||
continue
|
||||
if not supports_metadata:
|
||||
logger.warning(
|
||||
'Ignoring metadata because installed Vertex SDK does not support'
|
||||
' config.metadata.'
|
||||
)
|
||||
continue
|
||||
if isinstance(value, Mapping):
|
||||
config['metadata'] = _build_vertex_metadata(value)
|
||||
else:
|
||||
logger.warning(
|
||||
'Ignoring metadata because custom_metadata["metadata"] is not a'
|
||||
' mapping.'
|
||||
)
|
||||
continue
|
||||
if key in _GENERATE_MEMORIES_CONFIG_KEYS:
|
||||
if value is None:
|
||||
continue
|
||||
config[key] = value
|
||||
else:
|
||||
metadata_by_key[key] = value
|
||||
|
||||
if not metadata_by_key:
|
||||
return config
|
||||
|
||||
if not supports_metadata:
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because installed Vertex SDK does '
|
||||
'not support config.metadata.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
return config
|
||||
|
||||
existing_metadata = config.get('metadata')
|
||||
if existing_metadata is None:
|
||||
config['metadata'] = _build_vertex_metadata(metadata_by_key)
|
||||
return config
|
||||
|
||||
if isinstance(existing_metadata, Mapping):
|
||||
merged_metadata = dict(existing_metadata)
|
||||
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
|
||||
config['metadata'] = merged_metadata
|
||||
return config
|
||||
|
||||
logger.warning(
|
||||
'Ignoring custom metadata keys %s because config.metadata is not a'
|
||||
' mapping.',
|
||||
sorted(metadata_by_key.keys()),
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def _build_vertex_metadata(
|
||||
metadata_by_key: Mapping[str, object],
|
||||
) -> dict[str, object]:
|
||||
"""Converts metadata values to Vertex MemoryMetadataValue objects."""
|
||||
vertex_metadata: dict[str, object] = {}
|
||||
for key, value in metadata_by_key.items():
|
||||
converted_value = _to_vertex_metadata_value(key, value)
|
||||
if converted_value is None:
|
||||
continue
|
||||
vertex_metadata[key] = converted_value
|
||||
return vertex_metadata
|
||||
|
||||
|
||||
def _to_vertex_metadata_value(
|
||||
key: str,
|
||||
value: object,
|
||||
) -> dict[str, object] | None:
|
||||
"""Converts a metadata value to Vertex MemoryMetadataValue shape."""
|
||||
if isinstance(value, bool):
|
||||
return {'bool_value': value}
|
||||
if isinstance(value, (int, float)):
|
||||
return {'double_value': float(value)}
|
||||
if isinstance(value, str):
|
||||
return {'string_value': value}
|
||||
if isinstance(value, datetime):
|
||||
return {'timestamp_value': value}
|
||||
if isinstance(value, Mapping):
|
||||
if value.keys() <= {
|
||||
'bool_value',
|
||||
'double_value',
|
||||
'string_value',
|
||||
'timestamp_value',
|
||||
}:
|
||||
return dict(value)
|
||||
return {'string_value': str(dict(value))}
|
||||
if value is None:
|
||||
logger.warning(
|
||||
'Ignoring custom metadata key %s because its value is None.',
|
||||
key,
|
||||
)
|
||||
return None
|
||||
return {'string_value': str(value)}
|
||||
|
||||
@@ -373,6 +373,46 @@ class TestCallbackContextAddSessionToMemory:
|
||||
await context.add_session_to_memory()
|
||||
|
||||
|
||||
class TestCallbackContextAddEventsToMemory:
|
||||
"""Tests add_events_to_memory in CallbackContext."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_success(self, mock_invocation_context):
|
||||
"""Tests that add_events_to_memory calls the memory service correctly."""
|
||||
memory_service = AsyncMock()
|
||||
mock_invocation_context.memory_service = memory_service
|
||||
test_event = MagicMock()
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
await context.add_events_to_memory(
|
||||
events=[test_event],
|
||||
custom_metadata={"ttl": "6000s"},
|
||||
)
|
||||
|
||||
memory_service.add_events_to_memory.assert_called_once_with(
|
||||
app_name=mock_invocation_context.session.app_name,
|
||||
user_id=mock_invocation_context.session.user_id,
|
||||
session_id=mock_invocation_context.session.id,
|
||||
events=[test_event],
|
||||
custom_metadata={"ttl": "6000s"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_no_service_raises(
|
||||
self, mock_invocation_context
|
||||
):
|
||||
"""Tests that add_events_to_memory raises ValueError with no service."""
|
||||
mock_invocation_context.memory_service = None
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Cannot add events to memory: memory service is not available\.",
|
||||
):
|
||||
await context.add_events_to_memory(events=[MagicMock()])
|
||||
|
||||
|
||||
class TestToolContextAddSessionToMemory:
|
||||
"""Test the add_session_to_memory method in ToolContext."""
|
||||
|
||||
|
||||
@@ -118,6 +118,116 @@ async def test_add_session_to_memory():
|
||||
assert session_memory[MOCK_SESSION_1.id][1].id == 'event-1c'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_with_explicit_events():
|
||||
"""Tests that add_events_to_memory can ingest an explicit event list."""
|
||||
memory_service = InMemoryMemoryService()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION_1.app_name,
|
||||
user_id=MOCK_SESSION_1.user_id,
|
||||
session_id=MOCK_SESSION_1.id,
|
||||
events=[MOCK_SESSION_1.events[0]],
|
||||
)
|
||||
|
||||
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
|
||||
session_memory = memory_service._session_events[user_key]
|
||||
assert len(session_memory[MOCK_SESSION_1.id]) == 1
|
||||
assert session_memory[MOCK_SESSION_1.id][0].id == 'event-1a'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_without_session_id_uses_default_bucket():
|
||||
"""Tests add_events_to_memory when no session_id is provided."""
|
||||
memory_service = InMemoryMemoryService()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION_1.app_name,
|
||||
user_id=MOCK_SESSION_1.user_id,
|
||||
events=[MOCK_SESSION_1.events[0]],
|
||||
)
|
||||
|
||||
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
|
||||
session_memory = memory_service._session_events[user_key]
|
||||
assert len(session_memory) == 1
|
||||
unknown_session_events = next(iter(session_memory.values()))
|
||||
assert len(unknown_session_events) == 1
|
||||
assert unknown_session_events[0].id == 'event-1a'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_alias_is_supported():
|
||||
"""Tests that add_events_to_memory remains a compatibility alias."""
|
||||
memory_service = InMemoryMemoryService()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION_1.app_name,
|
||||
user_id=MOCK_SESSION_1.user_id,
|
||||
session_id=MOCK_SESSION_1.id,
|
||||
events=[MOCK_SESSION_1.events[0]],
|
||||
)
|
||||
|
||||
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
|
||||
session_memory = memory_service._session_events[user_key]
|
||||
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
|
||||
'event-1a'
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_appends_without_replacing():
|
||||
"""Tests that add_events_to_memory appends events rather than replacing."""
|
||||
memory_service = InMemoryMemoryService()
|
||||
await memory_service.add_session_to_memory(MOCK_SESSION_1)
|
||||
|
||||
new_event = Event(
|
||||
id='event-1d',
|
||||
invocation_id='inv-6',
|
||||
author='user',
|
||||
timestamp=12348,
|
||||
content=types.Content(parts=[types.Part(text='A new fact.')]),
|
||||
)
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION_1.app_name,
|
||||
user_id=MOCK_SESSION_1.user_id,
|
||||
session_id=MOCK_SESSION_1.id,
|
||||
events=[new_event],
|
||||
)
|
||||
|
||||
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
|
||||
session_memory = memory_service._session_events[user_key]
|
||||
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
|
||||
'event-1a',
|
||||
'event-1c',
|
||||
'event-1d',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_deduplicates_event_ids():
|
||||
"""Tests that duplicate event IDs are not appended multiple times."""
|
||||
memory_service = InMemoryMemoryService()
|
||||
await memory_service.add_session_to_memory(MOCK_SESSION_1)
|
||||
|
||||
duplicate_event = Event(
|
||||
id='event-1a',
|
||||
invocation_id='inv-7',
|
||||
author='user',
|
||||
timestamp=12349,
|
||||
content=types.Content(parts=[types.Part(text='Updated duplicate text.')]),
|
||||
)
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION_1.app_name,
|
||||
user_id=MOCK_SESSION_1.user_id,
|
||||
session_id=MOCK_SESSION_1.id,
|
||||
events=[duplicate_event],
|
||||
)
|
||||
|
||||
user_key = f'{MOCK_APP_NAME}/{MOCK_USER_ID}'
|
||||
session_memory = memory_service._session_events[user_key]
|
||||
assert [event.id for event in session_memory[MOCK_SESSION_1.id]] == [
|
||||
'event-1a',
|
||||
'event-1c',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_session_with_no_events_to_memory():
|
||||
"""Tests that adding a session with no events does not cause an error."""
|
||||
|
||||
@@ -23,11 +23,19 @@ from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankSe
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
import pytest
|
||||
from vertexai._genai.types import common as vertex_common_types
|
||||
|
||||
MOCK_APP_NAME = 'test-app'
|
||||
MOCK_USER_ID = 'test-user'
|
||||
|
||||
|
||||
def _supports_generate_memories_metadata() -> bool:
|
||||
return (
|
||||
'metadata'
|
||||
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
|
||||
)
|
||||
|
||||
|
||||
class _AsyncListIterator:
|
||||
"""Minimal async iterator wrapper for list-like results."""
|
||||
|
||||
@@ -156,6 +164,217 @@ async def test_add_session_to_memory(mock_vertexai_client):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_with_explicit_events_and_metadata(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
session_id=MOCK_SESSION.id,
|
||||
events=[MOCK_SESSION.events[0]],
|
||||
custom_metadata={'ttl': '6000s', 'source': 'agent'},
|
||||
)
|
||||
|
||||
expected_config = {
|
||||
'wait_for_completion': False,
|
||||
'revision_ttl': '6000s',
|
||||
}
|
||||
if _supports_generate_memories_metadata():
|
||||
expected_config['metadata'] = {'source': {'string_value': 'agent'}}
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
|
||||
name='reasoningEngines/123',
|
||||
direct_contents_source={
|
||||
'events': [
|
||||
{
|
||||
'content': {
|
||||
'parts': [{'text': 'test_content'}],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config=expected_config,
|
||||
)
|
||||
generate_config = (
|
||||
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_without_session_id(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
events=[MOCK_SESSION.events[0]],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
|
||||
name='reasoningEngines/123',
|
||||
direct_contents_source={
|
||||
'events': [
|
||||
{
|
||||
'content': {
|
||||
'parts': [{'text': 'test_content'}],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config={'wait_for_completion': False},
|
||||
)
|
||||
generate_config = (
|
||||
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_merges_metadata_field_and_unknown_keys(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
session_id=MOCK_SESSION.id,
|
||||
events=[MOCK_SESSION.events[0]],
|
||||
custom_metadata={
|
||||
'metadata': {'origin': 'unit-test'},
|
||||
'source': 'agent',
|
||||
},
|
||||
)
|
||||
|
||||
expected_config = {'wait_for_completion': False}
|
||||
if _supports_generate_memories_metadata():
|
||||
expected_config['metadata'] = {
|
||||
'origin': {'string_value': 'unit-test'},
|
||||
'source': {'string_value': 'agent'},
|
||||
}
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
|
||||
name='reasoningEngines/123',
|
||||
direct_contents_source={
|
||||
'events': [
|
||||
{
|
||||
'content': {
|
||||
'parts': [{'text': 'test_content'}],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config=expected_config,
|
||||
)
|
||||
generate_config = (
|
||||
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_none_wait_for_completion_keeps_default(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
session_id=MOCK_SESSION.id,
|
||||
events=[MOCK_SESSION.events[0]],
|
||||
custom_metadata={'wait_for_completion': None},
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
|
||||
name='reasoningEngines/123',
|
||||
direct_contents_source={
|
||||
'events': [
|
||||
{
|
||||
'content': {
|
||||
'parts': [{'text': 'test_content'}],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config={'wait_for_completion': False},
|
||||
)
|
||||
generate_config = (
|
||||
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_ttl_used_when_revision_ttl_is_none(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
session_id=MOCK_SESSION.id,
|
||||
events=[MOCK_SESSION.events[0]],
|
||||
custom_metadata={
|
||||
'ttl': '6000s',
|
||||
'revision_ttl': None,
|
||||
},
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
|
||||
name='reasoningEngines/123',
|
||||
direct_contents_source={
|
||||
'events': [
|
||||
{
|
||||
'content': {
|
||||
'parts': [{'text': 'test_content'}],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
|
||||
config={
|
||||
'wait_for_completion': False,
|
||||
'revision_ttl': '6000s',
|
||||
},
|
||||
)
|
||||
generate_config = (
|
||||
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
|
||||
'config'
|
||||
]
|
||||
)
|
||||
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_events_to_memory_with_filtered_events_skips_rpc(
|
||||
mock_vertexai_client,
|
||||
):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
await memory_service.add_events_to_memory(
|
||||
app_name=MOCK_SESSION.app_name,
|
||||
user_id=MOCK_SESSION.user_id,
|
||||
session_id=MOCK_SESSION.id,
|
||||
events=[MOCK_SESSION.events[1], MOCK_SESSION.events[2]],
|
||||
)
|
||||
|
||||
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_empty_session_to_memory(mock_vertexai_client):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
|
||||
Reference in New Issue
Block a user