You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Fix pagination of list_sessions in VertexAiSessionService
Resolves https://github.com/google/adk-python/issues/2860 PiperOrigin-RevId: 804511401
This commit is contained in:
committed by
Copybara-Service
parent
bc6b5462a7
commit
e63fe0c0eb
@@ -275,36 +275,47 @@ class VertexAiSessionService(BaseSessionService):
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
|
||||
if user_id:
|
||||
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
|
||||
path = path + f'?filter=user_id={parsed_user_id}'
|
||||
|
||||
list_sessions_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=path,
|
||||
request_dict={},
|
||||
)
|
||||
list_sessions_api_response = _convert_api_response(
|
||||
list_sessions_api_response
|
||||
)
|
||||
|
||||
# Handles empty response case
|
||||
if not list_sessions_api_response or list_sessions_api_response.get(
|
||||
'httpHeaders', None
|
||||
):
|
||||
return ListSessionsResponse()
|
||||
|
||||
base_path = f'reasoningEngines/{reasoning_engine_id}/sessions'
|
||||
sessions = []
|
||||
for api_session in list_sessions_api_response['sessions']:
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=api_session['name'].split('/')[-1],
|
||||
state=api_session.get('sessionState', {}),
|
||||
last_update_time=isoparse(api_session['updateTime']).timestamp(),
|
||||
page_token = None
|
||||
while True:
|
||||
path = base_path
|
||||
query_params = {}
|
||||
if user_id:
|
||||
query_params['filter'] = f'user_id="{user_id}"'
|
||||
if page_token:
|
||||
query_params['pageToken'] = page_token
|
||||
|
||||
if query_params:
|
||||
path = f'{path}?{urllib.parse.urlencode(query_params)}'
|
||||
|
||||
list_sessions_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=path,
|
||||
request_dict={},
|
||||
)
|
||||
sessions.append(session)
|
||||
converted_api_response = _convert_api_response(list_sessions_api_response)
|
||||
|
||||
# Handles empty response case
|
||||
if not converted_api_response or converted_api_response.get(
|
||||
'httpHeaders', None
|
||||
):
|
||||
break
|
||||
|
||||
for api_session in converted_api_response.get('sessions', []):
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=api_session['name'].split('/')[-1],
|
||||
state=api_session.get('sessionState', {}),
|
||||
last_update_time=isoparse(api_session['updateTime']).timestamp(),
|
||||
)
|
||||
sessions.append(session)
|
||||
|
||||
page_token = converted_api_response.get('nextPageToken')
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
async def delete_session(
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from unittest import mock
|
||||
from urllib import parse
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from google.adk.events.event import Event
|
||||
@@ -107,6 +108,22 @@ MOCK_EVENT_JSON_3 = [
|
||||
'timestamp': '2024-12-12T12:12:12.123456Z',
|
||||
},
|
||||
]
|
||||
MOCK_SESSION_JSON_PAGE1 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/page1'
|
||||
),
|
||||
'updateTime': '2024-12-15T12:12:12.123456Z',
|
||||
'userId': 'user_with_pages',
|
||||
}
|
||||
MOCK_SESSION_JSON_PAGE2 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/page2'
|
||||
),
|
||||
'updateTime': '2024-12-16T12:12:12.123456Z',
|
||||
'userId': 'user_with_pages',
|
||||
}
|
||||
|
||||
MOCK_SESSION = Session(
|
||||
app_name='123',
|
||||
@@ -157,9 +174,7 @@ MOCK_SESSION_2 = Session(
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
|
||||
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
|
||||
)
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$'
|
||||
EVENTS_REGEX = (
|
||||
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
|
||||
)
|
||||
@@ -188,12 +203,28 @@ class MockApiClient:
|
||||
else:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
elif re.match(SESSIONS_REGEX, path):
|
||||
match = re.match(SESSIONS_REGEX, path)
|
||||
parsed_url = parse.urlparse(path)
|
||||
query_params = parse.parse_qs(parsed_url.query)
|
||||
filter_val = query_params.get('filter', [''])[0]
|
||||
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
|
||||
if not user_id_match:
|
||||
raise ValueError(f'Could not find user_id in filter: {filter_val}')
|
||||
user_id = user_id_match.group(1)
|
||||
|
||||
if user_id == 'user_with_pages':
|
||||
page_token = query_params.get('pageToken', [None])[0]
|
||||
if page_token == 'my_token':
|
||||
return {'sessions': [MOCK_SESSION_JSON_PAGE2]}
|
||||
else:
|
||||
return {
|
||||
'sessions': [MOCK_SESSION_JSON_PAGE1],
|
||||
'nextPageToken': 'my_token',
|
||||
}
|
||||
return {
|
||||
'sessions': [
|
||||
session
|
||||
for session in self.session_dict.values()
|
||||
if session['userId'] == match.group(2)
|
||||
if session['userId'] == user_id
|
||||
],
|
||||
}
|
||||
elif re.match(EVENTS_REGEX, path):
|
||||
@@ -271,6 +302,8 @@ def mock_get_api_client():
|
||||
'1': MOCK_SESSION_JSON_1,
|
||||
'2': MOCK_SESSION_JSON_2,
|
||||
'3': MOCK_SESSION_JSON_3,
|
||||
'page1': MOCK_SESSION_JSON_PAGE1,
|
||||
'page2': MOCK_SESSION_JSON_PAGE2,
|
||||
}
|
||||
api_client.event_dict = {
|
||||
'1': (MOCK_EVENT_JSON, None),
|
||||
@@ -358,6 +391,18 @@ async def test_list_sessions():
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_list_sessions_with_pagination():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = await session_service.list_sessions(
|
||||
app_name='123', user_id='user_with_pages'
|
||||
)
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == 'page1'
|
||||
assert sessions.sessions[1].id == 'page2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_create_session():
|
||||
|
||||
Reference in New Issue
Block a user