feat: Support returning all sessions when user_id is none in the request

resolves https://github.com/google/adk-python/issues/3154

PiperOrigin-RevId: 819417330
This commit is contained in:
Google Team Member
2025-10-14 15:10:03 -07:00
committed by Copybara-Service
parent 141318f775
commit f9c09ef075
6 changed files with 41 additions and 148 deletions
@@ -83,18 +83,9 @@ class BaseSessionService(abc.ABC):
@abc.abstractmethod
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
"""Lists all the sessions for a user.
Args:
app_name: The name of the app.
user_id: The ID of the user. If not provided, lists all sessions for all
users.
Returns:
A ListSessionsResponse containing the sessions.
"""
"""Lists all the sessions."""
@abc.abstractmethod
async def delete_session(
@@ -554,42 +554,30 @@ class DatabaseSessionService(BaseSessionService):
@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.database_session_factory() as sql_session:
query = sql_session.query(StorageSession).filter(
StorageSession.app_name == app_name
results = (
sql_session.query(StorageSession)
.filter(StorageSession.app_name == app_name)
.filter(StorageSession.user_id == user_id)
.all()
)
if user_id is not None:
query = query.filter(StorageSession.user_id == user_id)
results = query.all()
# Fetch app state from storage
# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
app_state = storage_app_state.state if storage_app_state else {}
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)
# Fetch user state(s) from storage
user_states_map = {}
if user_id is not None:
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)
if storage_user_state:
user_states_map[user_id] = storage_user_state.state
else:
all_user_states_for_app = (
sql_session.query(StorageUserState)
.filter(StorageUserState.app_name == app_name)
.all()
)
for storage_user_state in all_user_states_for_app:
user_states_map[storage_user_state.user_id] = storage_user_state.state
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
sessions = []
for storage_session in results:
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
sessions.append(storage_session.to_session(state=merged_state))
return ListSessionsResponse(sessions=sessions)
@@ -201,41 +201,31 @@ class InMemorySessionService(BaseSessionService):
@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
def list_sessions_sync(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
logger.warning('Deprecated. Please migrate to the async method.')
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
def _list_sessions_impl(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
return empty_response
if user_id is not None and user_id not in self.sessions[app_name]:
if user_id not in self.sessions[app_name]:
return empty_response
sessions_without_events = []
if user_id is None:
for user_id in self.sessions[app_name]:
for session_id in self.sessions[app_name][user_id]:
session = self.sessions[app_name][user_id][session_id]
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
else:
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
@override
@@ -200,25 +200,22 @@ class VertexAiSessionService(BaseSessionService):
@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
sessions = []
config = {}
if user_id is not None:
config['filter'] = f'user_id="{user_id}"'
sessions_iterator = api_client.agent_engines.sessions.list(
name=f'reasoningEngines/{reasoning_engine_id}',
config=config,
config={'filter': f'user_id="{user_id}"'},
)
for api_session in sessions_iterator:
sessions.append(
Session(
app_name=app_name,
user_id=api_session.user_id,
user_id=user_id,
id=api_session.name.split('/')[-1],
state=getattr(api_session, 'session_state', None) or {},
last_update_time=api_session.update_time.timestamp(),
@@ -116,70 +116,9 @@ async def test_create_and_list_sessions(service_type):
app_name=app_name, user_id=user_id
)
sessions = list_sessions_response.sessions
assert len(sessions) == len(session_ids)
assert {s.id for s in sessions} == set(session_ids)
for session in sessions:
assert session.state == {'key': 'value' + session.id}
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_list_sessions_all_users(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id_1 = 'user1'
user_id_2 = 'user2'
await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
session_id='session1a',
state={'key': 'value1a'},
)
await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
session_id='session1b',
state={'key': 'value1b'},
)
await session_service.create_session(
app_name=app_name,
user_id=user_id_2,
session_id='session2a',
state={'key': 'value2a'},
)
# List sessions for user1
list_sessions_response_1 = await session_service.list_sessions(
app_name=app_name, user_id=user_id_1
)
sessions_1 = list_sessions_response_1.sessions
assert len(sessions_1) == 2
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
for session in sessions_1:
if session.id == 'session1a':
assert session.state == {'key': 'value1a'}
else:
assert session.state == {'key': 'value1b'}
# List sessions for user2
list_sessions_response_2 = await session_service.list_sessions(
app_name=app_name, user_id=user_id_2
)
sessions_2 = list_sessions_response_2.sessions
assert len(sessions_2) == 1
assert sessions_2[0].id == 'session2a'
assert sessions_2[0].state == {'key': 'value2a'}
# List sessions for all users
list_sessions_response_all = await session_service.list_sessions(
app_name=app_name, user_id=None
)
sessions_all = list_sessions_response_all.sessions
assert len(sessions_all) == 3
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
assert sessions[i].state == {'key': 'value' + session_ids[i]}
@pytest.mark.asyncio
@@ -252,22 +252,19 @@ class MockApiClient:
def _list_sessions(self, name: str, config: dict[str, Any]):
filter_val = config.get('filter', '')
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
if user_id_match:
user_id = user_id_match.group(1)
if user_id == 'user_with_pages':
return [
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
]
return [
_convert_to_object(session)
for session in self.session_dict.values()
if session['user_id'] == user_id
]
if not user_id_match:
raise ValueError(f'Could not find user_id in filter: {filter_val}')
user_id = user_id_match.group(1)
# No user filter, return all sessions
if user_id == 'user_with_pages':
return [
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
]
return [
_convert_to_object(session) for session in self.session_dict.values()
_convert_to_object(session)
for session in self.session_dict.values()
if session['user_id'] == user_id
]
def _delete_session(self, name: str):
@@ -478,15 +475,6 @@ async def test_list_sessions_with_pagination():
assert sessions.sessions[1].id == 'page2'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions_all_users():
session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(app_name='123', user_id=None)
assert len(sessions.sessions) == 5
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session():