diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f2f6f9f2..a76a6b9d 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -83,18 +83,9 @@ class BaseSessionService(abc.ABC): @abc.abstractmethod async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - """Lists all the sessions for a user. - - Args: - app_name: The name of the app. - user_id: The ID of the user. If not provided, lists all sessions for all - users. - - Returns: - A ListSessionsResponse containing the sessions. - """ + """Lists all the sessions.""" @abc.abstractmethod async def delete_session( diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 7ed3e3be..3ed87847 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -554,42 +554,30 @@ class DatabaseSessionService(BaseSessionService): @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: with self.database_session_factory() as sql_session: - query = sql_session.query(StorageSession).filter( - StorageSession.app_name == app_name + results = ( + sql_session.query(StorageSession) + .filter(StorageSession.app_name == app_name) + .filter(StorageSession.user_id == user_id) + .all() ) - if user_id is not None: - query = query.filter(StorageSession.user_id == user_id) - results = query.all() - # Fetch app state from storage + # Fetch states from storage storage_app_state = sql_session.get(StorageAppState, (app_name)) - app_state = storage_app_state.state if storage_app_state else {} + storage_user_state = sql_session.get( + StorageUserState, (app_name, user_id) + ) - # Fetch user state(s) from storage - user_states_map = {} - if user_id is not None: - storage_user_state = sql_session.get( - StorageUserState, (app_name, user_id) - ) - if storage_user_state: - user_states_map[user_id] = storage_user_state.state - else: - all_user_states_for_app = ( - sql_session.query(StorageUserState) - .filter(StorageUserState.app_name == app_name) - .all() - ) - for storage_user_state in all_user_states_for_app: - user_states_map[storage_user_state.user_id] = storage_user_state.state + app_state = storage_app_state.state if storage_app_state else {} + user_state = storage_user_state.state if storage_user_state else {} sessions = [] for storage_session in results: session_state = storage_session.state - user_state = user_states_map.get(storage_session.user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) + sessions.append(storage_session.to_session(state=merged_state)) return ListSessionsResponse(sessions=sessions) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index adc09230..bbb480ae 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -201,41 +201,31 @@ class InMemorySessionService(BaseSessionService): @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: return self._list_sessions_impl(app_name=app_name, user_id=user_id) def list_sessions_sync( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: logger.warning('Deprecated. Please migrate to the async method.') return self._list_sessions_impl(app_name=app_name, user_id=user_id) def _list_sessions_impl( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: empty_response = ListSessionsResponse() if app_name not in self.sessions: return empty_response - if user_id is not None and user_id not in self.sessions[app_name]: + if user_id not in self.sessions[app_name]: return empty_response sessions_without_events = [] - - if user_id is None: - for user_id in self.sessions[app_name]: - for session_id in self.sessions[app_name][user_id]: - session = self.sessions[app_name][user_id][session_id] - copied_session = copy.deepcopy(session) - copied_session.events = [] - copied_session = self._merge_state(app_name, user_id, copied_session) - sessions_without_events.append(copied_session) - else: - for session in self.sessions[app_name][user_id].values(): - copied_session = copy.deepcopy(session) - copied_session.events = [] - copied_session = self._merge_state(app_name, user_id, copied_session) - sessions_without_events.append(copied_session) + for session in self.sessions[app_name][user_id].values(): + copied_session = copy.deepcopy(session) + copied_session.events = [] + copied_session = self._merge_state(app_name, user_id, copied_session) + sessions_without_events.append(copied_session) return ListSessionsResponse(sessions=sessions_without_events) @override diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 72ff0d6c..3cb1d61e 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -200,25 +200,22 @@ class VertexAiSessionService(BaseSessionService): @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str ) -> ListSessionsResponse: reasoning_engine_id = self._get_reasoning_engine_id(app_name) api_client = self._get_api_client() sessions = [] - config = {} - if user_id is not None: - config['filter'] = f'user_id="{user_id}"' sessions_iterator = api_client.agent_engines.sessions.list( name=f'reasoningEngines/{reasoning_engine_id}', - config=config, + config={'filter': f'user_id="{user_id}"'}, ) for api_session in sessions_iterator: sessions.append( Session( app_name=app_name, - user_id=api_session.user_id, + user_id=user_id, id=api_session.name.split('/')[-1], state=getattr(api_session, 'session_state', None) or {}, last_update_time=api_session.update_time.timestamp(), diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 0c005d68..2ca265cb 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -116,70 +116,9 @@ async def test_create_and_list_sessions(service_type): app_name=app_name, user_id=user_id ) sessions = list_sessions_response.sessions - assert len(sessions) == len(session_ids) - assert {s.id for s in sessions} == set(session_ids) - for session in sessions: - assert session.state == {'key': 'value' + session.id} - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] -) -async def test_list_sessions_all_users(service_type): - session_service = get_session_service(service_type) - app_name = 'my_app' - user_id_1 = 'user1' - user_id_2 = 'user2' - - await session_service.create_session( - app_name=app_name, - user_id=user_id_1, - session_id='session1a', - state={'key': 'value1a'}, - ) - await session_service.create_session( - app_name=app_name, - user_id=user_id_1, - session_id='session1b', - state={'key': 'value1b'}, - ) - await session_service.create_session( - app_name=app_name, - user_id=user_id_2, - session_id='session2a', - state={'key': 'value2a'}, - ) - - # List sessions for user1 - list_sessions_response_1 = await session_service.list_sessions( - app_name=app_name, user_id=user_id_1 - ) - sessions_1 = list_sessions_response_1.sessions - assert len(sessions_1) == 2 - assert {s.id for s in sessions_1} == {'session1a', 'session1b'} - for session in sessions_1: - if session.id == 'session1a': - assert session.state == {'key': 'value1a'} - else: - assert session.state == {'key': 'value1b'} - - # List sessions for user2 - list_sessions_response_2 = await session_service.list_sessions( - app_name=app_name, user_id=user_id_2 - ) - sessions_2 = list_sessions_response_2.sessions - assert len(sessions_2) == 1 - assert sessions_2[0].id == 'session2a' - assert sessions_2[0].state == {'key': 'value2a'} - - # List sessions for all users - list_sessions_response_all = await session_service.list_sessions( - app_name=app_name, user_id=None - ) - sessions_all = list_sessions_response_all.sessions - assert len(sessions_all) == 3 - assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'} + for i in range(len(sessions)): + assert sessions[i].id == session_ids[i] + assert sessions[i].state == {'key': 'value' + session_ids[i]} @pytest.mark.asyncio diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index bf06fc01..92a180e6 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -252,22 +252,19 @@ class MockApiClient: def _list_sessions(self, name: str, config: dict[str, Any]): filter_val = config.get('filter', '') user_id_match = re.search(r'user_id="([^"]+)"', filter_val) - if user_id_match: - user_id = user_id_match.group(1) - if user_id == 'user_with_pages': - return [ - _convert_to_object(MOCK_SESSION_JSON_PAGE1), - _convert_to_object(MOCK_SESSION_JSON_PAGE2), - ] - return [ - _convert_to_object(session) - for session in self.session_dict.values() - if session['user_id'] == user_id - ] + if not user_id_match: + raise ValueError(f'Could not find user_id in filter: {filter_val}') + user_id = user_id_match.group(1) - # No user filter, return all sessions + if user_id == 'user_with_pages': + return [ + _convert_to_object(MOCK_SESSION_JSON_PAGE1), + _convert_to_object(MOCK_SESSION_JSON_PAGE2), + ] return [ - _convert_to_object(session) for session in self.session_dict.values() + _convert_to_object(session) + for session in self.session_dict.values() + if session['user_id'] == user_id ] def _delete_session(self, name: str): @@ -478,15 +475,6 @@ async def test_list_sessions_with_pagination(): assert sessions.sessions[1].id == 'page2' -@pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_list_sessions_all_users(): - session_service = mock_vertex_ai_session_service() - sessions = await session_service.list_sessions(app_name='123', user_id=None) - assert len(sessions.sessions) == 5 - assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'} - - @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_create_session():