diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index db864b3a..29681622 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -275,36 +275,47 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = self._get_reasoning_engine_id(app_name) api_client = self._get_api_client() - path = f'reasoningEngines/{reasoning_engine_id}/sessions' - if user_id: - parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='') - path = path + f'?filter=user_id={parsed_user_id}' - - list_sessions_api_response = await api_client.async_request( - http_method='GET', - path=path, - request_dict={}, - ) - list_sessions_api_response = _convert_api_response( - list_sessions_api_response - ) - - # Handles empty response case - if not list_sessions_api_response or list_sessions_api_response.get( - 'httpHeaders', None - ): - return ListSessionsResponse() - + base_path = f'reasoningEngines/{reasoning_engine_id}/sessions' sessions = [] - for api_session in list_sessions_api_response['sessions']: - session = Session( - app_name=app_name, - user_id=user_id, - id=api_session['name'].split('/')[-1], - state=api_session.get('sessionState', {}), - last_update_time=isoparse(api_session['updateTime']).timestamp(), + page_token = None + while True: + path = base_path + query_params = {} + if user_id: + query_params['filter'] = f'user_id="{user_id}"' + if page_token: + query_params['pageToken'] = page_token + + if query_params: + path = f'{path}?{urllib.parse.urlencode(query_params)}' + + list_sessions_api_response = await api_client.async_request( + http_method='GET', + path=path, + request_dict={}, ) - sessions.append(session) + converted_api_response = _convert_api_response(list_sessions_api_response) + + # Handles empty response case + if not converted_api_response or converted_api_response.get( + 'httpHeaders', None + ): + break + + for api_session in converted_api_response.get('sessions', []): + session = Session( + app_name=app_name, + user_id=user_id, + id=api_session['name'].split('/')[-1], + state=api_session.get('sessionState', {}), + last_update_time=isoparse(api_session['updateTime']).timestamp(), + ) + sessions.append(session) + + page_token = converted_api_response.get('nextPageToken') + if not page_token: + break + return ListSessionsResponse(sessions=sessions) async def delete_session( diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 9601c93f..f72394c4 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -19,6 +19,7 @@ from typing import List from typing import Optional from typing import Tuple from unittest import mock +from urllib import parse from dateutil.parser import isoparse from google.adk.events.event import Event @@ -107,6 +108,22 @@ MOCK_EVENT_JSON_3 = [ 'timestamp': '2024-12-12T12:12:12.123456Z', }, ] +MOCK_SESSION_JSON_PAGE1 = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/page1' + ), + 'updateTime': '2024-12-15T12:12:12.123456Z', + 'userId': 'user_with_pages', +} +MOCK_SESSION_JSON_PAGE2 = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/page2' + ), + 'updateTime': '2024-12-16T12:12:12.123456Z', + 'userId': 'user_with_pages', +} MOCK_SESSION = Session( app_name='123', @@ -157,9 +174,7 @@ MOCK_SESSION_2 = Session( SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' -SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string - r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' -) +SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$' EVENTS_REGEX = ( r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?' ) @@ -188,12 +203,28 @@ class MockApiClient: else: raise ValueError(f'Session not found: {session_id}') elif re.match(SESSIONS_REGEX, path): - match = re.match(SESSIONS_REGEX, path) + parsed_url = parse.urlparse(path) + query_params = parse.parse_qs(parsed_url.query) + filter_val = query_params.get('filter', [''])[0] + user_id_match = re.search(r'user_id="([^"]+)"', filter_val) + if not user_id_match: + raise ValueError(f'Could not find user_id in filter: {filter_val}') + user_id = user_id_match.group(1) + + if user_id == 'user_with_pages': + page_token = query_params.get('pageToken', [None])[0] + if page_token == 'my_token': + return {'sessions': [MOCK_SESSION_JSON_PAGE2]} + else: + return { + 'sessions': [MOCK_SESSION_JSON_PAGE1], + 'nextPageToken': 'my_token', + } return { 'sessions': [ session for session in self.session_dict.values() - if session['userId'] == match.group(2) + if session['userId'] == user_id ], } elif re.match(EVENTS_REGEX, path): @@ -271,6 +302,8 @@ def mock_get_api_client(): '1': MOCK_SESSION_JSON_1, '2': MOCK_SESSION_JSON_2, '3': MOCK_SESSION_JSON_3, + 'page1': MOCK_SESSION_JSON_PAGE1, + 'page2': MOCK_SESSION_JSON_PAGE2, } api_client.event_dict = { '1': (MOCK_EVENT_JSON, None), @@ -358,6 +391,18 @@ async def test_list_sessions(): assert sessions.sessions[1].id == '2' +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_list_sessions_with_pagination(): + session_service = mock_vertex_ai_session_service() + sessions = await session_service.list_sessions( + app_name='123', user_id='user_with_pages' + ) + assert len(sessions.sessions) == 2 + assert sessions.sessions[0].id == 'page1' + assert sessions.sessions[1].id == 'page2' + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_create_session():