feat: Migrate VertexAiMemoryBankService to use Agent Engine SDK

PiperOrigin-RevId: 813104746
This commit is contained in:
Shangjie Chen
2025-09-29 23:14:04 -07:00
committed by Copybara-Service
parent ce9c39f5a8
commit 83fd045718
3 changed files with 85 additions and 133 deletions
+1 -1
View File
@@ -32,7 +32,7 @@ dependencies = [
"click>=8.1.8, <9.0.0", # For CLI tools
"fastapi>=0.115.0, <1.0.0", # FastAPI framework
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
"google-cloud-aiplatform[agent_engines]>=1.95.1, <2.0.0", # For VertexAI integrations, e.g. example store.
"google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0",# For VertexAI integrations, e.g. example store.
"google-cloud-bigtable>=2.32.0", # For Bigtable database
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
@@ -14,16 +14,13 @@
from __future__ import annotations
import json
import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import Client
from google.genai import types
from typing_extensions import override
import vertexai
from .base_memory_service import BaseMemoryService
from .base_memory_service import SearchMemoryResponse
@@ -59,8 +56,6 @@ class VertexAiMemoryBankService(BaseMemoryService):
@override
async def add_session_to_memory(self, session: Session):
api_client = self._get_api_client()
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')
@@ -72,62 +67,53 @@ class VertexAiMemoryBankService(BaseMemoryService):
events.append({
'content': event.content.model_dump(exclude_none=True, mode='json')
})
request_dict = {
'direct_contents_source': {
'events': events,
},
'scope': {
'app_name': session.app_name,
'user_id': session.user_id,
},
}
if events:
api_response = await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
request_dict=request_dict,
client = self._get_api_client()
operation = client.agent_engines.memories.generate(
name='reasoningEngines/' + self._agent_engine_id,
direct_contents_source={'events': events},
scope={
'app_name': session.app_name,
'user_id': session.user_id,
},
config={'wait_for_completion': False},
)
logger.info('Generate memory response received.')
logger.debug('Generate memory response: %s', api_response)
logger.debug('Generate memory response: %s', operation)
else:
logger.info('No events to add to memory.')
@override
async def search_memory(self, *, app_name: str, user_id: str, query: str):
api_client = self._get_api_client()
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')
api_response = await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
request_dict={
'scope': {
'app_name': app_name,
'user_id': user_id,
},
'similarity_search_params': {
'search_query': query,
},
client = self._get_api_client()
retrieved_memories_iterator = client.agent_engines.memories.retrieve(
name='reasoningEngines/' + self._agent_engine_id,
scope={
'app_name': app_name,
'user_id': user_id,
},
similarity_search_params={
'search_query': query,
},
)
api_response = _convert_api_response(api_response)
logger.info('Search memory response received.')
logger.debug('Search memory response: %s', api_response)
if not api_response or not api_response.get('retrievedMemories', None):
return SearchMemoryResponse()
logger.info('Search memory response received.')
memory_events = []
for memory in api_response.get('retrievedMemories', []):
for retrieved_memory in retrieved_memories_iterator:
# TODO: add more complex error handling
logger.debug('Retrieved memory: %s', retrieved_memory)
memory_events.append(
MemoryEntry(
author='user',
content=types.Content(
parts=[types.Part(text=memory.get('memory').get('fact'))],
parts=[types.Part(text=retrieved_memory.memory.fact)],
role='user',
),
timestamp=memory.get('updateTime'),
timestamp=retrieved_memory.memory.update_time.isoformat(),
)
)
return SearchMemoryResponse(memories=memory_events)
@@ -141,17 +127,7 @@ class VertexAiMemoryBankService(BaseMemoryService):
Returns:
An API client for the given project and location.
"""
client = Client(
vertexai=True, project=self._project, location=self._location
)
return client._api_client
def _convert_api_response(api_response) -> Dict[str, Any]:
"""Converts the API response to a JSON object based on the type."""
if hasattr(api_response, 'body'):
return json.loads(api_response.body)
return api_response
return vertexai.Client(project=self._project, location=self._location)
def _should_filter_out_event(content: types.Content) -> bool:
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any
from datetime import datetime
from unittest import mock
from google.adk.events.event import Event
@@ -70,48 +69,6 @@ MOCK_SESSION_WITH_EMPTY_EVENTS = Session(
)
RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
class MockApiClient:
"""Mocks the API Client."""
def __init__(self) -> None:
"""Initializes MockClient."""
self.async_request = mock.AsyncMock()
self.async_request.side_effect = self._mock_async_request
async def _mock_async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method."""
if http_method == 'POST':
if re.match(GENERATE_MEMORIES_REGEX, path):
return {}
elif re.match(RETRIEVE_MEMORIES_REGEX, path):
if (
request_dict.get('scope', None)
and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME
):
return {
'retrievedMemories': [
{
'memory': {
'fact': 'test_content',
},
'updateTime': '2024-12-12T12:12:12.123456Z',
},
],
}
else:
return {'retrievedMemories': []}
else:
raise ValueError(f'Unsupported path: {path}')
else:
raise ValueError(f'Unsupported http method: {http_method}')
def mock_vertex_ai_memory_bank_service():
"""Creates a mock Vertex AI Memory Bank service for testing."""
return VertexAiMemoryBankService(
@@ -122,67 +79,86 @@ def mock_vertex_ai_memory_bank_service():
@pytest.fixture
def mock_get_api_client():
api_client = MockApiClient()
def mock_vertexai_client():
with mock.patch(
'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client',
return_value=api_client,
):
yield api_client
'google.adk.memory.vertex_ai_memory_bank_service.vertexai.Client'
) as mock_client_constructor:
mock_client = mock.MagicMock()
mock_client.agent_engines.memories.generate = mock.MagicMock()
mock_client.agent_engines.memories.retrieve = mock.MagicMock()
mock_client_constructor.return_value = mock_client
yield mock_client
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_add_session_to_memory(mock_get_api_client):
async def test_add_session_to_memory(mock_vertexai_client):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_session_to_memory(MOCK_SESSION)
mock_get_api_client.async_request.assert_awaited_once_with(
http_method='POST',
path='reasoningEngines/123/memories:generate',
request_dict={
'direct_contents_source': {
'events': [
{
'content': {
'parts': [
{'text': 'test_content'},
],
},
},
],
},
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
name='reasoningEngines/123',
direct_contents_source={
'events': [
{
'content': {
'parts': [{'text': 'test_content'}],
}
}
]
},
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
config={'wait_for_completion': False},
)
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_add_empty_session_to_memory(mock_get_api_client):
async def test_add_empty_session_to_memory(mock_vertexai_client):
memory_service = mock_vertex_ai_memory_bank_service()
await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS)
mock_get_api_client.async_request.assert_not_called()
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_search_memory(mock_get_api_client):
async def test_search_memory(mock_vertexai_client):
retrieved_memory = mock.MagicMock()
retrieved_memory.memory.fact = 'test_content'
retrieved_memory.memory.update_time = datetime(
2024, 12, 12, 12, 12, 12, 123456
)
mock_vertexai_client.agent_engines.memories.retrieve.return_value = [
retrieved_memory
]
memory_service = mock_vertex_ai_memory_bank_service()
result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)
mock_get_api_client.async_request.assert_awaited_once_with(
http_method='POST',
path='reasoningEngines/123/memories:retrieve',
request_dict={
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
'similarity_search_params': {'search_query': 'query'},
},
mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with(
name='reasoningEngines/123',
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
similarity_search_params={'search_query': 'query'},
)
assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'test_content'
@pytest.mark.asyncio
async def test_search_memory_empty_results(mock_vertexai_client):
mock_vertexai_client.agent_engines.memories.retrieve.return_value = []
memory_service = mock_vertex_ai_memory_bank_service()
result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)
mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with(
name='reasoningEngines/123',
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
similarity_search_params={'search_query': 'query'},
)
assert len(result.memories) == 0