fix: Fix pagination of list_sessions in VertexAiSessionService

Resolves https://github.com/google/adk-python/issues/2860

PiperOrigin-RevId: 804511401
This commit is contained in:
Shangjie Chen
2025-09-08 11:13:58 -07:00
committed by Copybara-Service
parent bc6b5462a7
commit e63fe0c0eb
2 changed files with 89 additions and 33 deletions
@@ -275,36 +275,47 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
if user_id:
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
path = path + f'?filter=user_id={parsed_user_id}'
list_sessions_api_response = await api_client.async_request(
http_method='GET',
path=path,
request_dict={},
)
list_sessions_api_response = _convert_api_response(
list_sessions_api_response
)
# Handles empty response case
if not list_sessions_api_response or list_sessions_api_response.get(
'httpHeaders', None
):
return ListSessionsResponse()
base_path = f'reasoningEngines/{reasoning_engine_id}/sessions'
sessions = []
for api_session in list_sessions_api_response['sessions']:
session = Session(
app_name=app_name,
user_id=user_id,
id=api_session['name'].split('/')[-1],
state=api_session.get('sessionState', {}),
last_update_time=isoparse(api_session['updateTime']).timestamp(),
page_token = None
while True:
path = base_path
query_params = {}
if user_id:
query_params['filter'] = f'user_id="{user_id}"'
if page_token:
query_params['pageToken'] = page_token
if query_params:
path = f'{path}?{urllib.parse.urlencode(query_params)}'
list_sessions_api_response = await api_client.async_request(
http_method='GET',
path=path,
request_dict={},
)
sessions.append(session)
converted_api_response = _convert_api_response(list_sessions_api_response)
# Handles empty response case
if not converted_api_response or converted_api_response.get(
'httpHeaders', None
):
break
for api_session in converted_api_response.get('sessions', []):
session = Session(
app_name=app_name,
user_id=user_id,
id=api_session['name'].split('/')[-1],
state=api_session.get('sessionState', {}),
last_update_time=isoparse(api_session['updateTime']).timestamp(),
)
sessions.append(session)
page_token = converted_api_response.get('nextPageToken')
if not page_token:
break
return ListSessionsResponse(sessions=sessions)
async def delete_session(
@@ -19,6 +19,7 @@ from typing import List
from typing import Optional
from typing import Tuple
from unittest import mock
from urllib import parse
from dateutil.parser import isoparse
from google.adk.events.event import Event
@@ -107,6 +108,22 @@ MOCK_EVENT_JSON_3 = [
'timestamp': '2024-12-12T12:12:12.123456Z',
},
]
MOCK_SESSION_JSON_PAGE1 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/page1'
),
'updateTime': '2024-12-15T12:12:12.123456Z',
'userId': 'user_with_pages',
}
MOCK_SESSION_JSON_PAGE2 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/page2'
),
'updateTime': '2024-12-16T12:12:12.123456Z',
'userId': 'user_with_pages',
}
MOCK_SESSION = Session(
app_name='123',
@@ -157,9 +174,7 @@ MOCK_SESSION_2 = Session(
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
)
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$'
EVENTS_REGEX = (
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
)
@@ -188,12 +203,28 @@ class MockApiClient:
else:
raise ValueError(f'Session not found: {session_id}')
elif re.match(SESSIONS_REGEX, path):
match = re.match(SESSIONS_REGEX, path)
parsed_url = parse.urlparse(path)
query_params = parse.parse_qs(parsed_url.query)
filter_val = query_params.get('filter', [''])[0]
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
if not user_id_match:
raise ValueError(f'Could not find user_id in filter: {filter_val}')
user_id = user_id_match.group(1)
if user_id == 'user_with_pages':
page_token = query_params.get('pageToken', [None])[0]
if page_token == 'my_token':
return {'sessions': [MOCK_SESSION_JSON_PAGE2]}
else:
return {
'sessions': [MOCK_SESSION_JSON_PAGE1],
'nextPageToken': 'my_token',
}
return {
'sessions': [
session
for session in self.session_dict.values()
if session['userId'] == match.group(2)
if session['userId'] == user_id
],
}
elif re.match(EVENTS_REGEX, path):
@@ -271,6 +302,8 @@ def mock_get_api_client():
'1': MOCK_SESSION_JSON_1,
'2': MOCK_SESSION_JSON_2,
'3': MOCK_SESSION_JSON_3,
'page1': MOCK_SESSION_JSON_PAGE1,
'page2': MOCK_SESSION_JSON_PAGE2,
}
api_client.event_dict = {
'1': (MOCK_EVENT_JSON, None),
@@ -358,6 +391,18 @@ async def test_list_sessions():
assert sessions.sessions[1].id == '2'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions_with_pagination():
session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(
app_name='123', user_id='user_with_pages'
)
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == 'page1'
assert sessions.sessions[1].id == 'page2'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session():