feat: Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint

PiperOrigin-RevId: 775327151
This commit is contained in:
Shangjie Chen
2025-06-24 11:56:28 -07:00
committed by Copybara-Service
parent 00cc8cd643
commit abc89d2c81
5 changed files with 321 additions and 2 deletions
+2 -1
View File
@@ -489,7 +489,8 @@ def adk_services_options():
type=str,
help=(
"""Optional. The URI of the memory service.
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service."""
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Vertex AI Memory Bank Service. e.g. agentengine://12345"""
),
default=None,
)
+11
View File
@@ -71,6 +71,7 @@ from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManag
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
@@ -282,6 +283,16 @@ def get_fast_api_app(
memory_service = VertexAiRagMemoryService(
rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
)
elif memory_service_uri.startswith("agentengine://"):
agent_engine_id = memory_service_uri.split("://")[1]
if not agent_engine_id:
raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agents_dir)
memory_service = VertexAiMemoryBankService(
project=os.environ["GOOGLE_CLOUD_PROJECT"],
location=os.environ["GOOGLE_CLOUD_LOCATION"],
agent_engine_id=agent_engine_id,
)
else:
raise click.ClickException(
"Unsupported memory service URI: %s" % memory_service_uri
+3 -1
View File
@@ -15,12 +15,14 @@ import logging
from .base_memory_service import BaseMemoryService
from .in_memory_memory_service import InMemoryMemoryService
from .vertex_ai_memory_bank_service import VertexAiMemoryBankService
logger = logging.getLogger('google_adk.' + __name__)
__all__ = [
'BaseMemoryService',
'InMemoryMemoryService',
'VertexAiMemoryBankService',
]
try:
@@ -29,7 +31,7 @@ try:
__all__.append('VertexAiRagMemoryService')
except ImportError:
logger.debug(
'The Vertex sdk is not installed. If you want to use the'
'The Vertex SDK is not installed. If you want to use the'
' VertexAiRagMemoryService please install it. If not, you can ignore this'
' warning.'
)
@@ -0,0 +1,147 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import logging
from typing import Optional
from typing import TYPE_CHECKING
from typing_extensions import override
from google import genai
from .base_memory_service import BaseMemoryService
from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry
if TYPE_CHECKING:
from ..sessions.session import Session
logger = logging.getLogger('google_adk.' + __name__)
class VertexAiMemoryBankService(BaseMemoryService):
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
def __init__(
self,
project: Optional[str] = None,
location: Optional[str] = None,
agent_engine_id: Optional[str] = None,
):
"""Initializes a VertexAiMemoryBankService.
Args:
project: The project ID of the Memory Bank to use.
location: The location of the Memory Bank to use.
agent_engine_id: The ID of the agent engine to use for the Memory Bank.
e.g. '456' in
'projects/my-project/locations/us-central1/reasoningEngines/456'.
"""
self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
@override
async def add_session_to_memory(self, session: Session):
api_client = self._get_api_client()
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')
events = []
for event in session.events:
if event.content and event.content.parts:
events.append({
'content': event.content.model_dump(exclude_none=True, mode='json')
})
request_dict = {
'direct_contents_source': {
'events': events,
},
'scope': {
'app_name': session.app_name,
'user_id': session.user_id,
},
}
api_response = await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
request_dict=request_dict,
)
logger.info(f'Generate memory response: {api_response}')
@override
async def search_memory(self, *, app_name: str, user_id: str, query: str):
api_client = self._get_api_client()
api_response = await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
request_dict={
'scope': {
'app_name': app_name,
'user_id': user_id,
},
'similarity_search_params': {
'search_query': query,
},
},
)
api_response = _convert_api_response(api_response)
logger.info(f'Search memory response: {api_response}')
if not api_response or not api_response.get('retrievedMemories', None):
return SearchMemoryResponse()
memory_events = []
for memory in api_response.get('retrievedMemories', []):
# TODO: add more complex error handling
memory_events.append(
MemoryEntry(
author='user',
content=genai.types.Content(
parts=[
genai.types.Part(text=memory.get('memory').get('fact'))
],
role='user',
),
timestamp=memory.get('updateTime'),
)
)
return SearchMemoryResponse(memories=memory_events)
def _get_api_client(self):
"""Instantiates an API client for the given project and location.
It needs to be instantiated inside each request so that the event loop
management can be properly propagated.
Returns:
An API client for the given project and location.
"""
client = genai.Client(
vertexai=True, project=self._project, location=self._location
)
return client._api_client
def _convert_api_response(api_response):
"""Converts the API response to a JSON object based on the type."""
if hasattr(api_response, 'body'):
return json.loads(api_response.body)
return api_response
@@ -0,0 +1,158 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any
from unittest import mock
from google.adk.events import Event
from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from google.adk.sessions import Session
from google.genai import types
import pytest
MOCK_APP_NAME = 'test-app'
MOCK_USER_ID = 'test-user'
MOCK_SESSION = Session(
app_name=MOCK_APP_NAME,
user_id=MOCK_USER_ID,
id='333',
last_update_time=22333,
events=[
Event(
id='444',
invocation_id='123',
author='user',
timestamp=12345,
content=types.Content(parts=[types.Part(text='test_content')]),
),
# Empty event, should be ignored
Event(
id='555',
invocation_id='456',
author='user',
timestamp=12345,
),
],
)
RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
class MockApiClient:
"""Mocks the API Client."""
def __init__(self) -> None:
"""Initializes MockClient."""
self.async_request = mock.AsyncMock()
self.async_request.side_effect = self._mock_async_request
async def _mock_async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method."""
if http_method == 'POST':
if re.match(GENERATE_MEMORIES_REGEX, path):
return {}
elif re.match(RETRIEVE_MEMORIES_REGEX, path):
if (
request_dict.get('scope', None)
and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME
):
return {
'retrievedMemories': [
{
'memory': {
'fact': 'test_content',
},
'updateTime': '2024-12-12T12:12:12.123456Z',
},
],
}
else:
return {'retrievedMemories': []}
else:
raise ValueError(f'Unsupported path: {path}')
else:
raise ValueError(f'Unsupported http method: {http_method}')
def mock_vertex_ai_memory_bank_service():
"""Creates a mock Vertex AI Memory Bank service for testing."""
return VertexAiMemoryBankService(
project='test-project',
location='test-location',
agent_engine_id='123',
)
@pytest.fixture
def mock_get_api_client():
api_client = MockApiClient()
with mock.patch(
'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client',
return_value=api_client,
):
yield api_client
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_add_session_to_memory(mock_get_api_client):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_session_to_memory(MOCK_SESSION)
mock_get_api_client.async_request.assert_awaited_once_with(
http_method='POST',
path='reasoningEngines/123/memories:generate',
request_dict={
'direct_contents_source': {
'events': [
{
'content': {
'parts': [
{'text': 'test_content'},
],
},
},
],
},
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
},
)
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_search_memory(mock_get_api_client):
memory_service = mock_vertex_ai_memory_bank_service()
result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)
mock_get_api_client.async_request.assert_awaited_once_with(
http_method='POST',
path='reasoningEngines/123/memories:retrieve',
request_dict={
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
'similarity_search_params': {'search_query': 'query'},
},
)
assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'test_content'