From abc89d2c811ba00805f81b27a3a07d56bdf55a0b Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 24 Jun 2025 11:56:28 -0700 Subject: [PATCH] feat: Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint PiperOrigin-RevId: 775327151 --- src/google/adk/cli/cli_tools_click.py | 3 +- src/google/adk/cli/fast_api.py | 11 ++ src/google/adk/memory/__init__.py | 4 +- .../memory/vertex_ai_memory_bank_service.py | 147 ++++++++++++++++ .../test_vertex_ai_memory_bank_service.py | 158 ++++++++++++++++++ 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/memory/vertex_ai_memory_bank_service.py create mode 100644 tests/unittests/memory/test_vertex_ai_memory_bank_service.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 9923b46c..c0935cce 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -489,7 +489,8 @@ def adk_services_options(): type=str, help=( """Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service.""" + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Vertex AI Memory Bank Service. e.g. agentengine://12345""" ), default=None, ) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4b2ed6c2..abe1961e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -71,6 +71,7 @@ from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManag from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService from ..runners import Runner from ..sessions.database_session_service import DatabaseSessionService @@ -282,6 +283,16 @@ def get_fast_api_app( memory_service = VertexAiRagMemoryService( rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' ) + elif memory_service_uri.startswith("agentengine://"): + agent_engine_id = memory_service_uri.split("://")[1] + if not agent_engine_id: + raise click.ClickException("Agent engine id can not be empty.") + envs.load_dotenv_for_agent("", agents_dir) + memory_service = VertexAiMemoryBankService( + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, + ) else: raise click.ClickException( "Unsupported memory service URI: %s" % memory_service_uri diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index f2ac4f9b..915d7e51 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -15,12 +15,14 @@ import logging from .base_memory_service import BaseMemoryService from .in_memory_memory_service import InMemoryMemoryService +from .vertex_ai_memory_bank_service import VertexAiMemoryBankService logger = logging.getLogger('google_adk.' + __name__) __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', + 'VertexAiMemoryBankService', ] try: @@ -29,7 +31,7 @@ try: __all__.append('VertexAiRagMemoryService') except ImportError: logger.debug( - 'The Vertex sdk is not installed. If you want to use the' + 'The Vertex SDK is not installed. If you want to use the' ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py new file mode 100644 index 00000000..b5b70ab1 --- /dev/null +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import logging +from typing import Optional +from typing import TYPE_CHECKING + +from typing_extensions import override + +from google import genai + +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + + +class VertexAiMemoryBankService(BaseMemoryService): + """Implementation of the BaseMemoryService using Vertex AI Memory Bank.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, + ): + """Initializes a VertexAiMemoryBankService. + + Args: + project: The project ID of the Memory Bank to use. + location: The location of the Memory Bank to use. + agent_engine_id: The ID of the agent engine to use for the Memory Bank. + e.g. '456' in + 'projects/my-project/locations/us-central1/reasoningEngines/456'. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id + + @override + async def add_session_to_memory(self, session: Session): + api_client = self._get_api_client() + + if not self._agent_engine_id: + raise ValueError('Agent Engine ID is required for Memory Bank.') + + events = [] + for event in session.events: + if event.content and event.content.parts: + events.append({ + 'content': event.content.model_dump(exclude_none=True, mode='json') + }) + request_dict = { + 'direct_contents_source': { + 'events': events, + }, + 'scope': { + 'app_name': session.app_name, + 'user_id': session.user_id, + }, + } + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + + @override + async def search_memory(self, *, app_name: str, user_id: str, query: str): + api_client = self._get_api_client() + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve', + request_dict={ + 'scope': { + 'app_name': app_name, + 'user_id': user_id, + }, + 'similarity_search_params': { + 'search_query': query, + }, + }, + ) + api_response = _convert_api_response(api_response) + logger.info(f'Search memory response: {api_response}') + + if not api_response or not api_response.get('retrievedMemories', None): + return SearchMemoryResponse() + + memory_events = [] + for memory in api_response.get('retrievedMemories', []): + # TODO: add more complex error handling + memory_events.append( + MemoryEntry( + author='user', + content=genai.types.Content( + parts=[ + genai.types.Part(text=memory.get('memory').get('fact')) + ], + role='user', + ), + timestamp=memory.get('updateTime'), + ) + ) + return SearchMemoryResponse(memories=memory_events) + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + + Returns: + An API client for the given project and location. + """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client + + +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py new file mode 100644 index 00000000..27e2bbdd --- /dev/null +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -0,0 +1,158 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any +from unittest import mock + +from google.adk.events import Event +from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService +from google.adk.sessions import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='333', + last_update_time=22333, + events=[ + Event( + id='444', + invocation_id='123', + author='user', + timestamp=12345, + content=types.Content(parts=[types.Part(text='test_content')]), + ), + # Empty event, should be ignored + Event( + id='555', + invocation_id='456', + author='user', + timestamp=12345, + ), + ], +) + + +RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' +GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' + + +class MockApiClient: + """Mocks the API Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.async_request = mock.AsyncMock() + self.async_request.side_effect = self._mock_async_request + + async def _mock_async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): + """Mocks the API Client request method.""" + if http_method == 'POST': + if re.match(GENERATE_MEMORIES_REGEX, path): + return {} + elif re.match(RETRIEVE_MEMORIES_REGEX, path): + if ( + request_dict.get('scope', None) + and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME + ): + return { + 'retrievedMemories': [ + { + 'memory': { + 'fact': 'test_content', + }, + 'updateTime': '2024-12-12T12:12:12.123456Z', + }, + ], + } + else: + return {'retrievedMemories': []} + else: + raise ValueError(f'Unsupported path: {path}') + else: + raise ValueError(f'Unsupported http method: {http_method}') + + +def mock_vertex_ai_memory_bank_service(): + """Creates a mock Vertex AI Memory Bank service for testing.""" + return VertexAiMemoryBankService( + project='test-project', + location='test-location', + agent_engine_id='123', + ) + + +@pytest.fixture +def mock_get_api_client(): + api_client = MockApiClient() + with mock.patch( + 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client', + return_value=api_client, + ): + yield api_client + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:generate', + request_dict={ + 'direct_contents_source': { + 'events': [ + { + 'content': { + 'parts': [ + {'text': 'test_content'}, + ], + }, + }, + ], + }, + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_search_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:retrieve', + request_dict={ + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + 'similarity_search_params': {'search_query': 'query'}, + }, + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'test_content'