diff --git a/pyproject.toml b/pyproject.toml index 7550349c..4559ff45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "click>=8.1.8, <9.0.0", # For CLI tools "fastapi>=0.115.0, <1.0.0", # FastAPI framework "google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery - "google-cloud-aiplatform[agent_engines]>=1.95.1, <2.0.0", # For VertexAI integrations, e.g. example store. + "google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0",# For VertexAI integrations, e.g. example store. "google-cloud-bigtable>=2.32.0", # For Bigtable database "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 69629eb9..03f4c392 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -14,16 +14,13 @@ from __future__ import annotations -import json import logging -from typing import Any -from typing import Dict from typing import Optional from typing import TYPE_CHECKING -from google.genai import Client from google.genai import types from typing_extensions import override +import vertexai from .base_memory_service import BaseMemoryService from .base_memory_service import SearchMemoryResponse @@ -59,8 +56,6 @@ class VertexAiMemoryBankService(BaseMemoryService): @override async def add_session_to_memory(self, session: Session): - api_client = self._get_api_client() - if not self._agent_engine_id: raise ValueError('Agent Engine ID is required for Memory Bank.') @@ -72,62 +67,53 @@ class VertexAiMemoryBankService(BaseMemoryService): events.append({ 'content': event.content.model_dump(exclude_none=True, mode='json') }) - request_dict = { - 'direct_contents_source': { - 'events': events, - }, - 'scope': { - 'app_name': session.app_name, - 'user_id': session.user_id, - }, - } - if events: - api_response = await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', - request_dict=request_dict, + client = self._get_api_client() + operation = client.agent_engines.memories.generate( + name='reasoningEngines/' + self._agent_engine_id, + direct_contents_source={'events': events}, + scope={ + 'app_name': session.app_name, + 'user_id': session.user_id, + }, + config={'wait_for_completion': False}, ) logger.info('Generate memory response received.') - logger.debug('Generate memory response: %s', api_response) + logger.debug('Generate memory response: %s', operation) else: logger.info('No events to add to memory.') @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - api_client = self._get_api_client() + if not self._agent_engine_id: + raise ValueError('Agent Engine ID is required for Memory Bank.') - api_response = await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve', - request_dict={ - 'scope': { - 'app_name': app_name, - 'user_id': user_id, - }, - 'similarity_search_params': { - 'search_query': query, - }, + client = self._get_api_client() + retrieved_memories_iterator = client.agent_engines.memories.retrieve( + name='reasoningEngines/' + self._agent_engine_id, + scope={ + 'app_name': app_name, + 'user_id': user_id, + }, + similarity_search_params={ + 'search_query': query, }, ) - api_response = _convert_api_response(api_response) - logger.info('Search memory response received.') - logger.debug('Search memory response: %s', api_response) - if not api_response or not api_response.get('retrievedMemories', None): - return SearchMemoryResponse() + logger.info('Search memory response received.') memory_events = [] - for memory in api_response.get('retrievedMemories', []): + for retrieved_memory in retrieved_memories_iterator: # TODO: add more complex error handling + logger.debug('Retrieved memory: %s', retrieved_memory) memory_events.append( MemoryEntry( author='user', content=types.Content( - parts=[types.Part(text=memory.get('memory').get('fact'))], + parts=[types.Part(text=retrieved_memory.memory.fact)], role='user', ), - timestamp=memory.get('updateTime'), + timestamp=retrieved_memory.memory.update_time.isoformat(), ) ) return SearchMemoryResponse(memories=memory_events) @@ -141,17 +127,7 @@ class VertexAiMemoryBankService(BaseMemoryService): Returns: An API client for the given project and location. """ - client = Client( - vertexai=True, project=self._project, location=self._location - ) - return client._api_client - - -def _convert_api_response(api_response) -> Dict[str, Any]: - """Converts the API response to a JSON object based on the type.""" - if hasattr(api_response, 'body'): - return json.loads(api_response.body) - return api_response + return vertexai.Client(project=self._project, location=self._location) def _should_filter_out_event(content: types.Content) -> bool: diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 2916b442..c47023df 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re -from typing import Any +from datetime import datetime from unittest import mock from google.adk.events.event import Event @@ -70,48 +69,6 @@ MOCK_SESSION_WITH_EMPTY_EVENTS = Session( ) -RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' -GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' - - -class MockApiClient: - """Mocks the API Client.""" - - def __init__(self) -> None: - """Initializes MockClient.""" - self.async_request = mock.AsyncMock() - self.async_request.side_effect = self._mock_async_request - - async def _mock_async_request( - self, http_method: str, path: str, request_dict: dict[str, Any] - ): - """Mocks the API Client request method.""" - if http_method == 'POST': - if re.match(GENERATE_MEMORIES_REGEX, path): - return {} - elif re.match(RETRIEVE_MEMORIES_REGEX, path): - if ( - request_dict.get('scope', None) - and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME - ): - return { - 'retrievedMemories': [ - { - 'memory': { - 'fact': 'test_content', - }, - 'updateTime': '2024-12-12T12:12:12.123456Z', - }, - ], - } - else: - return {'retrievedMemories': []} - else: - raise ValueError(f'Unsupported path: {path}') - else: - raise ValueError(f'Unsupported http method: {http_method}') - - def mock_vertex_ai_memory_bank_service(): """Creates a mock Vertex AI Memory Bank service for testing.""" return VertexAiMemoryBankService( @@ -122,67 +79,86 @@ def mock_vertex_ai_memory_bank_service(): @pytest.fixture -def mock_get_api_client(): - api_client = MockApiClient() +def mock_vertexai_client(): with mock.patch( - 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client', - return_value=api_client, - ): - yield api_client + 'google.adk.memory.vertex_ai_memory_bank_service.vertexai.Client' + ) as mock_client_constructor: + mock_client = mock.MagicMock() + mock_client.agent_engines.memories.generate = mock.MagicMock() + mock_client.agent_engines.memories.retrieve = mock.MagicMock() + mock_client_constructor.return_value = mock_client + yield mock_client @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_add_session_to_memory(mock_get_api_client): +async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() await memory_service.add_session_to_memory(MOCK_SESSION) - mock_get_api_client.async_request.assert_awaited_once_with( - http_method='POST', - path='reasoningEngines/123/memories:generate', - request_dict={ - 'direct_contents_source': { - 'events': [ - { - 'content': { - 'parts': [ - {'text': 'test_content'}, - ], - }, - }, - ], - }, - 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_contents_source={ + 'events': [ + { + 'content': { + 'parts': [{'text': 'test_content'}], + } + } + ] }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, ) @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_add_empty_session_to_memory(mock_get_api_client): +async def test_add_empty_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) - mock_get_api_client.async_request.assert_not_called() + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_search_memory(mock_get_api_client): +async def test_search_memory(mock_vertexai_client): + retrieved_memory = mock.MagicMock() + retrieved_memory.memory.fact = 'test_content' + retrieved_memory.memory.update_time = datetime( + 2024, 12, 12, 12, 12, 12, 123456 + ) + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = [ + retrieved_memory + ] memory_service = mock_vertex_ai_memory_bank_service() result = await memory_service.search_memory( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' ) - mock_get_api_client.async_request.assert_awaited_once_with( - http_method='POST', - path='reasoningEngines/123/memories:retrieve', - request_dict={ - 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, - 'similarity_search_params': {'search_query': 'query'}, - }, + mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with( + name='reasoningEngines/123', + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + similarity_search_params={'search_query': 'query'}, ) assert len(result.memories) == 1 assert result.memories[0].content.parts[0].text == 'test_content' + + +@pytest.mark.asyncio +async def test_search_memory_empty_results(mock_vertexai_client): + mock_vertexai_client.agent_engines.memories.retrieve.return_value = [] + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with( + name='reasoningEngines/123', + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + similarity_search_params={'search_query': 'query'}, + ) + + assert len(result.memories) == 0