From 90d4c19c5115c7af361effa8e12c248225ccf6ab Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 7 Oct 2025 09:57:53 -0700 Subject: [PATCH] feat: Migrate vertex ai session service to use agent engine sdk PiperOrigin-RevId: 816259798 --- src/google/adk/cli/fast_api.py | 4 +- .../adk/sessions/vertex_ai_session_service.py | 402 ++++++++---------- .../test_vertex_ai_session_service.py | 396 ++++++++++------- 3 files changed, 427 insertions(+), 375 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 02e617ed..326cab03 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -105,8 +105,8 @@ def get_fast_api_app( agent_engine_id = agent_engine_id_or_resource_name.split("/")[-1] else: envs.load_dotenv_for_agent("", agents_dir) - project = os.environ["GOOGLE_CLOUD_PROJECT"] - location = os.environ["GOOGLE_CLOUD_LOCATION"] + project = os.environ.get("GOOGLE_CLOUD_PROJECT", None) + location = os.environ.get("GOOGLE_CLOUD_LOCATION", None) agent_engine_id = agent_engine_id_or_resource_name return project, location, agent_engine_id diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 31dea1a0..def89b24 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -13,25 +13,23 @@ # limitations under the License. from __future__ import annotations -import json +import datetime import logging import os import re from typing import Any from typing import Dict from typing import Optional -import urllib.parse +from typing import Union -from dateutil import parser +from google.genai import types from google.genai.errors import ClientError from tenacity import retry from tenacity import retry_if_result -from tenacity import RetryError from tenacity import stop_after_attempt from tenacity import wait_exponential from typing_extensions import override - -from google import genai +import vertexai from . import _session_util from ..events.event import Event @@ -41,12 +39,11 @@ from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse from .session import Session -isoparse = parser.isoparse logger = logging.getLogger('google_adk.' + __name__) class VertexAiSessionService(BaseSessionService): - """Connects to the Vertex AI Agent Engine Session Service using GenAI API client. + """Connects to the Vertex AI Agent Engine Session Service using Agent Engine SDK. https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview """ @@ -68,20 +65,6 @@ class VertexAiSessionService(BaseSessionService): self._location = location self._agent_engine_id = agent_engine_id - async def _get_session_api_response( - self, - reasoning_engine_id: str, - session_id: str, - api_client: genai.ApiClient, - ): - get_session_api_response = await api_client.async_request( - http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, - ) - get_session_api_response = _convert_api_response(get_session_api_response) - return get_session_api_response - @override async def create_session( self, @@ -96,38 +79,35 @@ class VertexAiSessionService(BaseSessionService): 'User-provided Session id is not supported for' ' VertexAISessionService.' ) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) api_client = self._get_api_client() - session_json_dict = {'user_id': user_id} - if state: - session_json_dict['session_state'] = state + config = {'session_state': state} if state else {} - api_response = await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{reasoning_engine_id}/sessions', - request_dict=session_json_dict, - ) - api_response = _convert_api_response(api_response) - logger.info('Create session response received.') - logger.debug('Create session response: %s', api_response) - - session_id = api_response['name'].split('/')[-3] - operation_id = api_response['name'].split('/')[-1] if _is_vertex_express_mode(self._project, self._location): + config['wait_for_completion'] = False + api_response = api_client.agent_engines.sessions.create( + name=f'reasoningEngines/{reasoning_engine_id}', + user_id=user_id, + config=config, + ) + logger.info('Create session response received.') + session_id = api_response.name.split('/')[-3] + # Express mode doesn't support LRO, so we need to poll # the session resource. # TODO: remove this once LRO polling is supported in Express mode. @retry( - stop=stop_after_attempt(5), + stop=stop_after_attempt(6), wait=wait_exponential(multiplier=1, min=1, max=3), retry=retry_if_result(lambda response: not response), reraise=True, ) async def _poll_session_resource(): try: - return await self._get_session_api_response( - reasoning_engine_id, session_id, api_client + return api_client.agent_engines.sessions.get( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' ) except ClientError: logger.info('Polling session resource') @@ -137,45 +117,26 @@ class VertexAiSessionService(BaseSessionService): await _poll_session_resource() except Exception as exc: raise ValueError('Failed to create session.') from exc - else: - @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=1, max=3), - retry=retry_if_result( - lambda response: not response.get('done', False), - ), - reraise=True, + get_session_response = api_client.agent_engines.sessions.get( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' ) - async def _poll_lro(): - lro_response = await api_client.async_request( - http_method='GET', - path=f'operations/{operation_id}', - request_dict={}, - ) - lro_response = _convert_api_response(lro_response) - return lro_response + else: + api_response = api_client.agent_engines.sessions.create( + name=f'reasoningEngines/{reasoning_engine_id}', + user_id=user_id, + config=config, + ) + logger.debug('Create session response: %s', api_response) + get_session_response = api_response.response + session_id = get_session_response.name.split('/')[-1] - try: - await _poll_lro() - except RetryError as exc: - raise TimeoutError( - f'Timeout waiting for operation {operation_id} to complete.' - ) from exc - except Exception as exc: - raise ValueError('Failed to create session.') from exc - - get_session_api_response = await self._get_session_api_response( - reasoning_engine_id, session_id, api_client - ) session = Session( app_name=str(app_name), user_id=str(user_id), id=str(session_id), - state=get_session_api_response.get('sessionState', {}), - last_update_time=isoparse( - get_session_api_response['updateTime'] - ).timestamp(), + state=getattr(get_session_response, 'session_state', None) or {}, + last_update_time=get_session_response.update_time.timestamp(), ) return session @@ -192,79 +153,49 @@ class VertexAiSessionService(BaseSessionService): api_client = self._get_api_client() # Get session resource - get_session_api_response = await self._get_session_api_response( - reasoning_engine_id, session_id, api_client + get_session_response = api_client.agent_engines.sessions.get( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', ) - if get_session_api_response['userId'] != user_id: - raise ValueError(f'Session not found: {session_id}') + if get_session_response.user_id != user_id: + raise ValueError( + f'Session {session_id} does not belong to user {user_id}.' + ) - session_id = get_session_api_response['name'].split('/')[-1] - update_timestamp = isoparse( - get_session_api_response['updateTime'] - ).timestamp() + session_id_from_name = get_session_response.name.split('/')[-1] + update_timestamp = get_session_response.update_time.timestamp() session = Session( app_name=str(app_name), user_id=str(user_id), - id=str(session_id), - state=get_session_api_response.get('sessionState', {}), + id=str(session_id_from_name), + state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=update_timestamp, ) - list_events_api_response = await api_client.async_request( - http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', - request_dict={}, + list_events_kwargs = {} + if config and not config.num_recent_events and config.after_timestamp: + list_events_kwargs['config'] = { + 'filter': 'timestamp>="{}"'.format( + datetime.datetime.fromtimestamp( + config.after_timestamp, tz=datetime.timezone.utc + ).isoformat() + ) + } + + events_iterator = api_client.agent_engines.sessions.events.list( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', + **list_events_kwargs, ) - converted_api_response = _convert_api_response(list_events_api_response) - - # Handles empty response case where there are no events to fetch - if not converted_api_response or converted_api_response.get( - 'httpHeaders', None - ): - return session - - session.events += [ - _from_api_event(event) - for event in converted_api_response['sessionEvents'] - ] - - while converted_api_response.get('nextPageToken', None): - page_token = converted_api_response.get('nextPageToken', None) - list_events_api_response = await api_client.async_request( - http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}', - request_dict={}, - ) - converted_api_response = _convert_api_response(list_events_api_response) - - # Handles empty response case where there are no more events to fetch - if not converted_api_response or converted_api_response.get( - 'httpHeaders', None - ): - break - session.events += [ - _from_api_event(event) - for event in converted_api_response['sessionEvents'] - ] + session.events += [_from_api_event(event) for event in events_iterator] session.events = [ event for event in session.events if event.timestamp <= update_timestamp ] - session.events.sort(key=lambda event: event.timestamp) # Filter events based on config if config: if config.num_recent_events: session.events = session.events[-config.num_recent_events :] - elif config.after_timestamp: - i = len(session.events) - 1 - while i >= 0: - if session.events[i].timestamp < config.after_timestamp: - break - i -= 1 - if i >= 0: - session.events = session.events[i:] return session @@ -275,46 +206,22 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = self._get_reasoning_engine_id(app_name) api_client = self._get_api_client() - base_path = f'reasoningEngines/{reasoning_engine_id}/sessions' sessions = [] - page_token = None - while True: - path = base_path - query_params = {} - if user_id: - query_params['filter'] = f'user_id="{user_id}"' - if page_token: - query_params['pageToken'] = page_token + sessions_iterator = api_client.agent_engines.sessions.list( + name=f'reasoningEngines/{reasoning_engine_id}', + config={'filter': f'user_id="{user_id}"'}, + ) - if query_params: - path = f'{path}?{urllib.parse.urlencode(query_params)}' - - list_sessions_api_response = await api_client.async_request( - http_method='GET', - path=path, - request_dict={}, + for api_session in sessions_iterator: + sessions.append( + Session( + app_name=app_name, + user_id=user_id, + id=api_session.name.split('/')[-1], + state=getattr(api_session, 'session_state', None) or {}, + last_update_time=api_session.update_time.timestamp(), + ) ) - converted_api_response = _convert_api_response(list_sessions_api_response) - - # Handles empty response case - if not converted_api_response or converted_api_response.get( - 'httpHeaders', None - ): - break - - for api_session in converted_api_response.get('sessions', []): - session = Session( - app_name=app_name, - user_id=user_id, - id=api_session['name'].split('/')[-1], - state=api_session.get('sessionState', {}), - last_update_time=isoparse(api_session['updateTime']).timestamp(), - ) - sessions.append(session) - - page_token = converted_api_response.get('nextPageToken') - if not page_token: - break return ListSessionsResponse(sessions=sessions) @@ -325,10 +232,8 @@ class VertexAiSessionService(BaseSessionService): api_client = self._get_api_client() try: - await api_client.async_request( - http_method='DELETE', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, + api_client.agent_engines.sessions.delete( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', ) except Exception as e: logger.error('Error deleting session %s: %s', session_id, e) @@ -341,10 +246,52 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = self._get_reasoning_engine_id(session.app_name) api_client = self._get_api_client() - await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', - request_dict=_convert_event_to_json(event), + + config = {} + if event.content: + config['content'] = event.content.model_dump( + exclude_none=True, mode='json' + ) + if event.actions: + config['actions'] = { + 'skip_summarization': event.actions.skip_summarization, + 'state_delta': event.actions.state_delta, + 'artifact_delta': event.actions.artifact_delta, + 'transfer_agent': event.actions.transfer_to_agent, + 'escalate': event.actions.escalate, + 'requested_auth_configs': event.actions.requested_auth_configs, + } + if event.error_code: + config['error_code'] = event.error_code + if event.error_message: + config['error_message'] = event.error_message + + metadata_dict = { + 'partial': event.partial, + 'turn_complete': event.turn_complete, + 'interrupted': event.interrupted, + 'branch': event.branch, + 'custom_metadata': event.custom_metadata, + 'long_running_tool_ids': ( + list(event.long_running_tool_ids) + if event.long_running_tool_ids + else None + ), + } + if event.grounding_metadata: + metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump( + exclude_none=True, mode='json' + ) + config['event_metadata'] = metadata_dict + + api_client.agent_engines.sessions.events.append( + name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}', + author=event.author, + invocation_id=event.invocation_id, + timestamp=datetime.datetime.fromtimestamp( + event.timestamp, tz=datetime.timezone.utc + ), + config=config, ) return event @@ -368,24 +315,19 @@ class VertexAiSessionService(BaseSessionService): def _api_client_http_options_override( self, - ) -> Optional[genai.types.HttpOptions]: + ) -> Optional[Union[types.HttpOptions, types.HttpOptionsDict]]: return None - def _get_api_client(self) -> genai.client.BaseApiClient: + def _get_api_client(self) -> vertexai.Client: """Instantiates an API client for the given project and location. - It needs to be instantiated inside each request so that the event loop - management can be properly propagated. - Returns: An API client for the given project and location. """ - return genai.client.BaseApiClient( - vertexai=True, - project=self._project, - location=self._location, - http_options=self._api_client_http_options_override(), - ) + client = vertexai.Client(project=self._project, location=self._location) + if self._api_client_http_options_override(): + client.http_options = self._api_client_http_options_override() + return client def _is_vertex_express_mode( @@ -400,13 +342,6 @@ def _is_vertex_express_mode( ) -def _convert_api_response(api_response): - """Converts the API response to a JSON object based on the type.""" - if hasattr(api_response, 'body'): - return json.loads(api_response.body) - return api_response - - def _convert_event_to_json(event: Event) -> Dict[str, Any]: metadata_json = { 'partial': event.partial, @@ -460,47 +395,60 @@ def _convert_event_to_json(event: Event) -> Dict[str, Any]: return event_json -def _from_api_event(api_event: Dict[str, Any]) -> Event: - event_actions = EventActions() - if api_event.get('actions', None): - event_actions = EventActions( - skip_summarization=api_event['actions'].get('skipSummarization', None), - state_delta=api_event['actions'].get('stateDelta', {}), - artifact_delta=api_event['actions'].get('artifactDelta', {}), - transfer_to_agent=api_event['actions'].get('transferAgent', None), - escalate=api_event['actions'].get('escalate', None), - requested_auth_configs=api_event['actions'].get( - 'requestedAuthConfigs', {} - ), - ) +def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: + """Converts an API event object to an Event object.""" + actions = getattr(api_event_obj, 'actions', None) + if actions: + actions_dict = actions.model_dump(exclude_none=True, mode='python') + rename_map = {'transfer_agent': 'transfer_to_agent'} + renamed_actions_dict = { + rename_map.get(k, k): v for k, v in actions_dict.items() + } + event_actions = EventActions.model_validate(renamed_actions_dict) + else: + event_actions = EventActions() - event = Event( - id=api_event['name'].split('/')[-1], - invocation_id=api_event['invocationId'], - author=api_event['author'], - actions=event_actions, - content=_session_util.decode_content(api_event.get('content', None)), - timestamp=isoparse(api_event['timestamp']).timestamp(), - error_code=api_event.get('errorCode', None), - error_message=api_event.get('errorMessage', None), - ) - - if api_event.get('eventMetadata', None): - long_running_tool_ids_list = api_event['eventMetadata'].get( - 'longRunningToolIds', None + event_metadata = getattr(api_event_obj, 'event_metadata', None) + if event_metadata: + long_running_tool_ids_list = getattr( + event_metadata, 'long_running_tool_ids', None ) - event.partial = api_event['eventMetadata'].get('partial', None) - event.turn_complete = api_event['eventMetadata'].get('turnComplete', None) - event.interrupted = api_event['eventMetadata'].get('interrupted', None) - event.branch = api_event['eventMetadata'].get('branch', None) - event.custom_metadata = api_event['eventMetadata'].get( - 'customMetadata', None - ) - event.grounding_metadata = _session_util.decode_grounding_metadata( - api_event['eventMetadata'].get('groundingMetadata', None) - ) - event.long_running_tool_ids = ( + long_running_tool_ids = ( set(long_running_tool_ids_list) if long_running_tool_ids_list else None ) + partial = getattr(event_metadata, 'partial', None) + turn_complete = getattr(event_metadata, 'turn_complete', None) + interrupted = getattr(event_metadata, 'interrupted', None) + branch = getattr(event_metadata, 'branch', None) + custom_metadata = getattr(event_metadata, 'custom_metadata', None) + grounding_metadata = _session_util.decode_grounding_metadata( + getattr(event_metadata, 'grounding_metadata', None) + ) + else: + long_running_tool_ids = None + partial = None + turn_complete = None + interrupted = None + branch = None + custom_metadata = None + grounding_metadata = None - return event + return Event( + id=api_event_obj.name.split('/')[-1], + invocation_id=api_event_obj.invocation_id, + author=api_event_obj.author, + actions=event_actions, + content=_session_util.decode_content( + getattr(api_event_obj, 'content', None) + ), + timestamp=api_event_obj.timestamp.timestamp(), + error_code=getattr(api_event_obj, 'error_code', None), + error_message=getattr(api_event_obj, 'error_message', None), + partial=partial, + turn_complete=turn_complete, + interrupted=interrupted, + branch=branch, + custom_metadata=custom_metadata, + grounding_metadata=grounding_metadata, + long_running_tool_ids=long_running_tool_ids, + ) diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index f72394c4..fc9c2b5f 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import re -import this +import types from typing import Any from typing import List from typing import Optional from typing import Tuple from unittest import mock -from urllib import parse from dateutil.parser import isoparse from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService -from google.genai import types +from google.api_core import exceptions as api_core_exceptions +from google.genai import types as genai_types import pytest MOCK_SESSION_JSON_1 = { @@ -34,28 +36,28 @@ MOCK_SESSION_JSON_1 = { 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/1' ), - 'createTime': '2024-12-12T12:12:12.123456Z', - 'updateTime': '2024-12-12T12:12:12.123456Z', - 'sessionState': { + 'create_time': '2024-12-12T12:12:12.123456Z', + 'update_time': '2024-12-12T12:12:12.123456Z', + 'session_state': { 'key': {'value': 'test_value'}, }, - 'userId': 'user', + 'user_id': 'user', } MOCK_SESSION_JSON_2 = { 'name': ( 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/2' ), - 'updateTime': '2024-12-13T12:12:12.123456Z', - 'userId': 'user', + 'update_time': '2024-12-13T12:12:12.123456Z', + 'user_id': 'user', } MOCK_SESSION_JSON_3 = { 'name': ( 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/3' ), - 'updateTime': '2024-12-14T12:12:12.123456Z', - 'userId': 'user2', + 'update_time': '2024-12-14T12:12:12.123456Z', + 'user_id': 'user2', } MOCK_EVENT_JSON = [ { @@ -63,7 +65,7 @@ MOCK_EVENT_JSON = [ 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/1/events/123' ), - 'invocationId': '123', + 'invocation_id': '123', 'author': 'user', 'timestamp': '2024-12-12T12:12:12.123456Z', 'content': { @@ -72,17 +74,17 @@ MOCK_EVENT_JSON = [ ], }, 'actions': { - 'stateDelta': { + 'state_delta': { 'key': {'value': 'test_value'}, }, - 'transferAgent': 'agent', + 'transfer_agent': 'agent', }, - 'eventMetadata': { + 'event_metadata': { 'partial': False, - 'turnComplete': True, + 'turn_complete': True, 'interrupted': False, 'branch': '', - 'longRunningToolIds': ['tool1'], + 'long_running_tool_ids': ['tool1'], }, }, ] @@ -92,7 +94,7 @@ MOCK_EVENT_JSON_2 = [ 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/2/events/123' ), - 'invocationId': '222', + 'invocation_id': '222', 'author': 'user', 'timestamp': '2024-12-12T12:12:12.123456Z', }, @@ -103,9 +105,9 @@ MOCK_EVENT_JSON_3 = [ 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/2/events/456' ), - 'invocationId': '333', + 'invocation_id': '333', 'author': 'user', - 'timestamp': '2024-12-12T12:12:12.123456Z', + 'timestamp': '2024-12-12T12:12:13.123456Z', }, ] MOCK_SESSION_JSON_PAGE1 = { @@ -113,31 +115,33 @@ MOCK_SESSION_JSON_PAGE1 = { 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/page1' ), - 'updateTime': '2024-12-15T12:12:12.123456Z', - 'userId': 'user_with_pages', + 'update_time': '2024-12-15T12:12:12.123456Z', + 'user_id': 'user_with_pages', } MOCK_SESSION_JSON_PAGE2 = { 'name': ( 'projects/test-project/locations/test-location/' 'reasoningEngines/123/sessions/page2' ), - 'updateTime': '2024-12-16T12:12:12.123456Z', - 'userId': 'user_with_pages', + 'update_time': '2024-12-16T12:12:12.123456Z', + 'user_id': 'user_with_pages', } MOCK_SESSION = Session( app_name='123', user_id='user', id='1', - state=MOCK_SESSION_JSON_1['sessionState'], - last_update_time=isoparse(MOCK_SESSION_JSON_1['updateTime']).timestamp(), + state=MOCK_SESSION_JSON_1['session_state'], + last_update_time=isoparse(MOCK_SESSION_JSON_1['update_time']).timestamp(), events=[ Event( id='123', invocation_id='123', author='user', timestamp=isoparse(MOCK_EVENT_JSON[0]['timestamp']).timestamp(), - content=types.Content(parts=[types.Part(text='test_content')]), + content=genai_types.Content( + parts=[genai_types.Part(text='test_content')] + ), actions=EventActions( transfer_to_agent='agent', state_delta={'key': {'value': 'test_value'}}, @@ -155,7 +159,7 @@ MOCK_SESSION_2 = Session( app_name='123', user_id='user', id='2', - last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(), + last_update_time=isoparse(MOCK_SESSION_JSON_2['update_time']).timestamp(), events=[ Event( id='123', @@ -173,12 +177,52 @@ MOCK_SESSION_2 = Session( ) -SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' -SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$' -EVENTS_REGEX = ( - r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?' -) -LRO_REGEX = r'^operations/([^/]+)$' +class PydanticNamespace(types.SimpleNamespace): + + def model_dump(self, exclude_none=True, mode='python'): + d = {} + for k, v in self.__dict__.items(): + if exclude_none and v is None: + continue + if isinstance(v, PydanticNamespace): + d[k] = v.model_dump(exclude_none=exclude_none, mode=mode) + elif isinstance(v, list): + d[k] = [ + i.model_dump(exclude_none=exclude_none, mode=mode) + if isinstance(i, PydanticNamespace) + else i + for i in v + ] + else: + d[k] = v + return d + + +def _convert_to_object(data): + if isinstance(data, dict): + kwargs = {} + for key, value in data.items(): + if key in [ + 'timestamp', + 'update_time', + 'create_time', + ] and isinstance(value, str): + kwargs[key] = isoparse(value) + elif key in [ + 'session_state', + 'state_delta', + 'artifact_delta', + 'custom_metadata', + 'requested_auth_configs', + ]: + kwargs[key] = value + else: + kwargs[key] = _convert_to_object(value) + return PydanticNamespace(**kwargs) + elif isinstance(data, list): + return [_convert_to_object(item) for item in data] + else: + return data class MockApiClient: @@ -186,112 +230,126 @@ class MockApiClient: def __init__(self) -> None: """Initializes MockClient.""" - this.session_dict: dict[str, Any] = {} - this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {} + self.session_dict: dict[str, Any] = {} + self.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {} + self.agent_engines = mock.Mock() + self.agent_engines.sessions.get.side_effect = self._get_session + self.agent_engines.sessions.list.side_effect = self._list_sessions + self.agent_engines.sessions.delete.side_effect = self._delete_session + self.agent_engines.sessions.create.side_effect = self._create_session + self.agent_engines.sessions.events.list.side_effect = self._list_events + self.agent_engines.sessions.events.append.side_effect = self._append_event - async def async_request( - self, http_method: str, path: str, request_dict: dict[str, Any] - ): - """Mocks the API Client request method""" - if http_method == 'GET': - if re.match(SESSION_REGEX, path): - match = re.match(SESSION_REGEX, path) - if match: - session_id = match.group(2) - if session_id in self.session_dict: - return self.session_dict[session_id] - else: - raise ValueError(f'Session not found: {session_id}') - elif re.match(SESSIONS_REGEX, path): - parsed_url = parse.urlparse(path) - query_params = parse.parse_qs(parsed_url.query) - filter_val = query_params.get('filter', [''])[0] - user_id_match = re.search(r'user_id="([^"]+)"', filter_val) - if not user_id_match: - raise ValueError(f'Could not find user_id in filter: {filter_val}') - user_id = user_id_match.group(1) + def _get_session(self, name: str): + session_id = name.split('/')[-1] + if session_id in self.session_dict: + return _convert_to_object(self.session_dict[session_id]) + raise api_core_exceptions.NotFound(f'Session not found: {session_id}') - if user_id == 'user_with_pages': - page_token = query_params.get('pageToken', [None])[0] - if page_token == 'my_token': - return {'sessions': [MOCK_SESSION_JSON_PAGE2]} - else: - return { - 'sessions': [MOCK_SESSION_JSON_PAGE1], - 'nextPageToken': 'my_token', - } - return { - 'sessions': [ - session - for session in self.session_dict.values() - if session['userId'] == user_id - ], - } - elif re.match(EVENTS_REGEX, path): - match = re.match(EVENTS_REGEX, path) - if match: - session_id = match.group(2) - if match.group(3): - page_token = match.group(3) - if page_token == 'my_token': - response = {'sessionEvents': MOCK_EVENT_JSON_3} - response['nextPageToken'] = 'my_token2' - return response - else: - return {} - events_tuple = self.event_dict.get(session_id, ([], None)) - response = {'sessionEvents': events_tuple[0]} - if events_tuple[1]: - response['nextPageToken'] = events_tuple[1] - return response - elif re.match(LRO_REGEX, path): - # Mock long-running operation as completed - return { - 'name': path, - 'done': True, - 'response': self.session_dict['4'], # Return the created session - } - else: - raise ValueError(f'Unsupported path: {path}') - elif http_method == 'POST': - new_session_id = '4' - self.session_dict[new_session_id] = { - 'name': ( - 'projects/test-project/locations/test-location/' - 'reasoningEngines/123/sessions/' - + new_session_id - ), - 'userId': request_dict['user_id'], - 'sessionState': request_dict.get('session_state', {}), - 'updateTime': '2024-12-12T12:12:12.123456Z', - } - return { - 'name': ( - 'projects/test_project/locations/test_location/' - 'reasoningEngines/123/sessions/' - + new_session_id - + '/operations/111' - ), - 'done': False, - } - elif http_method == 'DELETE': - match = re.match(SESSION_REGEX, path) + def _list_sessions(self, name: str, config: dict[str, Any]): + filter_val = config.get('filter', '') + user_id_match = re.search(r'user_id="([^"]+)"', filter_val) + if not user_id_match: + raise ValueError(f'Could not find user_id in filter: {filter_val}') + user_id = user_id_match.group(1) + + if user_id == 'user_with_pages': + return [ + _convert_to_object(MOCK_SESSION_JSON_PAGE1), + _convert_to_object(MOCK_SESSION_JSON_PAGE2), + ] + return [ + _convert_to_object(session) + for session in self.session_dict.values() + if session['user_id'] == user_id + ] + + def _delete_session(self, name: str): + session_id = name.split('/')[-1] + self.session_dict.pop(session_id) + + def _create_session(self, name: str, user_id: str, config: dict[str, Any]): + new_session_id = '4' + self.session_dict[new_session_id] = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/' + + new_session_id + ), + 'user_id': user_id, + 'session_state': config.get('session_state', {}), + 'update_time': '2024-12-12T12:12:12.123456Z', + } + return _convert_to_object({ + 'name': ( + 'projects/test_project/locations/test_location/' + 'reasoningEngines/123/sessions/' + + new_session_id + + '/operations/111' + ), + 'done': True, + 'response': self.session_dict['4'], + }) + + def _list_events(self, name: str, **kwargs): + session_id = name.split('/')[-1] + events = [] + if session_id in self.event_dict: + events_tuple = self.event_dict[session_id] + events.extend(events_tuple[0]) + if events_tuple[1] == 'my_token': + events.extend(MOCK_EVENT_JSON_3) + + config = kwargs.get('config', {}) + filter_str = config.get('filter', None) + if filter_str: + match = re.search(r'timestamp>="([^"]+)"', filter_str) if match: - self.session_dict.pop(match.group(2)) + after_timestamp_str = match.group(1) + after_timestamp = isoparse(after_timestamp_str) + events = [ + event + for event in events + if isoparse(event['timestamp']) >= after_timestamp + ] + return [_convert_to_object(event) for event in events] + + def _append_event( + self, + name: str, + author: str, + invocation_id: str, + timestamp: Any, + config: dict[str, Any], + ): + session_id = name.split('/')[-1] + event_list, token = self.event_dict.get(session_id, ([], None)) + event_id = str(len(event_list) + 1000) # generate unique ID + + event_timestamp_str = timestamp.isoformat().replace('+00:00', 'Z') + event_json = { + 'name': f'{name}/events/{event_id}', + 'invocation_id': invocation_id, + 'author': author, + 'timestamp': event_timestamp_str, + } + event_json.update(config) + + if session_id in self.session_dict: + self.session_dict[session_id]['update_time'] = event_timestamp_str + + if session_id in self.event_dict: + self.event_dict[session_id][0].append(event_json) else: - raise ValueError(f'Unsupported http method: {http_method}') + self.event_dict[session_id] = ([event_json], None) def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): """Creates a mock Vertex AI Session service for testing.""" - if agent_engine_id: - return VertexAiSessionService( - project='test-project', - location='test-location', - agent_engine_id=agent_engine_id, - ) return VertexAiSessionService( - project='test-project', location='test-location' + project='test-project', + location='test-location', + agent_engine_id=agent_engine_id, ) @@ -306,8 +364,8 @@ def mock_get_api_client(): 'page2': MOCK_SESSION_JSON_PAGE2, } api_client.event_dict = { - '1': (MOCK_EVENT_JSON, None), - '2': (MOCK_EVENT_JSON_2, 'my_token'), + '1': (copy.deepcopy(MOCK_EVENT_JSON), None), + '2': (copy.deepcopy(MOCK_EVENT_JSON_2), 'my_token'), } with mock.patch( 'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client', @@ -320,30 +378,24 @@ def mock_get_api_client(): @pytest.mark.usefixtures('mock_get_api_client') @pytest.mark.parametrize('agent_engine_id', [None, '123']) async def test_get_empty_session(agent_engine_id): - if agent_engine_id: - session_service = mock_vertex_ai_session_service(agent_engine_id) - else: - session_service = mock_vertex_ai_session_service() - with pytest.raises(ValueError) as excinfo: + session_service = mock_vertex_ai_session_service(agent_engine_id) + with pytest.raises(api_core_exceptions.NotFound) as excinfo: await session_service.get_session( app_name='123', user_id='user', session_id='0' ) - assert str(excinfo.value) == 'Session not found: 0' + assert str(excinfo.value) == '404 Session not found: 0' @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') @pytest.mark.parametrize('agent_engine_id', [None, '123']) async def test_get_another_user_session(agent_engine_id): - if agent_engine_id: - session_service = mock_vertex_ai_session_service(agent_engine_id) - else: - session_service = mock_vertex_ai_session_service() + session_service = mock_vertex_ai_session_service(agent_engine_id) with pytest.raises(ValueError) as excinfo: await session_service.get_session( app_name='123', user_id='user2', session_id='1' ) - assert str(excinfo.value) == 'Session not found: 1' + assert str(excinfo.value) == 'Session 1 does not belong to user user2.' @pytest.mark.asyncio @@ -361,11 +413,11 @@ async def test_get_and_delete_session(): await session_service.delete_session( app_name='123', user_id='user', session_id='1' ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(api_core_exceptions.NotFound) as excinfo: await session_service.get_session( app_name='123', user_id='user', session_id='1' ) - assert str(excinfo.value) == 'Session not found: 1' + assert str(excinfo.value) == '404 Session not found: 1' @pytest.mark.asyncio @@ -381,6 +433,23 @@ async def test_get_session_with_page_token(): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_get_session_with_after_timestamp_filter(): + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', + user_id='user', + session_id='2', + config=GetSessionConfig( + after_timestamp=isoparse('2024-12-12T12:12:13.0Z').timestamp() + ), + ) + assert session is not None + assert len(session.events) == 1 + assert session.events[0].id == '456' + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_list_sessions(): @@ -435,3 +504,38 @@ async def test_create_session_with_custom_session_id(): assert str(excinfo.value) == ( 'User-provided Session id is not supported for VertexAISessionService.' ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event(): + session_service = mock_vertex_ai_session_service() + session_before_append = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + event_to_append = Event( + invocation_id='new_invocation', + author='model', + timestamp=1734005533.0, + content=genai_types.Content(parts=[genai_types.Part(text='new_content')]), + actions=EventActions( + transfer_to_agent='another_agent', + state_delta={'new_key': 'new_value'}, + skip_summarization=True, + ), + error_code='1', + error_message='test_error', + branch='test_branch', + custom_metadata={'custom': 'data'}, + long_running_tool_ids={'tool2'}, + ) + + await session_service.append_event(session_before_append, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + + assert len(retrieved_session.events) == 2 + event_to_append.id = retrieved_session.events[1].id + assert retrieved_session.events[1] == event_to_append