You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support returning all sessions when user_id is none in the request
resolves https://github.com/google/adk-python/issues/3154 PiperOrigin-RevId: 819417330
This commit is contained in:
committed by
Copybara-Service
parent
141318f775
commit
f9c09ef075
@@ -83,18 +83,9 @@ class BaseSessionService(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
"""Lists all the sessions for a user.
|
||||
|
||||
Args:
|
||||
app_name: The name of the app.
|
||||
user_id: The ID of the user. If not provided, lists all sessions for all
|
||||
users.
|
||||
|
||||
Returns:
|
||||
A ListSessionsResponse containing the sessions.
|
||||
"""
|
||||
"""Lists all the sessions."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_session(
|
||||
|
||||
@@ -554,42 +554,30 @@ class DatabaseSessionService(BaseSessionService):
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
with self.database_session_factory() as sql_session:
|
||||
query = sql_session.query(StorageSession).filter(
|
||||
StorageSession.app_name == app_name
|
||||
results = (
|
||||
sql_session.query(StorageSession)
|
||||
.filter(StorageSession.app_name == app_name)
|
||||
.filter(StorageSession.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
if user_id is not None:
|
||||
query = query.filter(StorageSession.user_id == user_id)
|
||||
results = query.all()
|
||||
|
||||
# Fetch app state from storage
|
||||
# Fetch states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
storage_user_state = sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
# Fetch user state(s) from storage
|
||||
user_states_map = {}
|
||||
if user_id is not None:
|
||||
storage_user_state = sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
if storage_user_state:
|
||||
user_states_map[user_id] = storage_user_state.state
|
||||
else:
|
||||
all_user_states_for_app = (
|
||||
sql_session.query(StorageUserState)
|
||||
.filter(StorageUserState.app_name == app_name)
|
||||
.all()
|
||||
)
|
||||
for storage_user_state in all_user_states_for_app:
|
||||
user_states_map[storage_user_state.user_id] = storage_user_state.state
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
|
||||
sessions = []
|
||||
for storage_session in results:
|
||||
session_state = storage_session.state
|
||||
user_state = user_states_map.get(storage_session.user_id, {})
|
||||
merged_state = _merge_state(app_state, user_state, session_state)
|
||||
|
||||
sessions.append(storage_session.to_session(state=merged_state))
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
|
||||
@@ -201,41 +201,31 @@ class InMemorySessionService(BaseSessionService):
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
|
||||
|
||||
def list_sessions_sync(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
logger.warning('Deprecated. Please migrate to the async method.')
|
||||
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
|
||||
|
||||
def _list_sessions_impl(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
empty_response = ListSessionsResponse()
|
||||
if app_name not in self.sessions:
|
||||
return empty_response
|
||||
if user_id is not None and user_id not in self.sessions[app_name]:
|
||||
if user_id not in self.sessions[app_name]:
|
||||
return empty_response
|
||||
|
||||
sessions_without_events = []
|
||||
|
||||
if user_id is None:
|
||||
for user_id in self.sessions[app_name]:
|
||||
for session_id in self.sessions[app_name][user_id]:
|
||||
session = self.sessions[app_name][user_id][session_id]
|
||||
copied_session = copy.deepcopy(session)
|
||||
copied_session.events = []
|
||||
copied_session = self._merge_state(app_name, user_id, copied_session)
|
||||
sessions_without_events.append(copied_session)
|
||||
else:
|
||||
for session in self.sessions[app_name][user_id].values():
|
||||
copied_session = copy.deepcopy(session)
|
||||
copied_session.events = []
|
||||
copied_session = self._merge_state(app_name, user_id, copied_session)
|
||||
sessions_without_events.append(copied_session)
|
||||
for session in self.sessions[app_name][user_id].values():
|
||||
copied_session = copy.deepcopy(session)
|
||||
copied_session.events = []
|
||||
copied_session = self._merge_state(app_name, user_id, copied_session)
|
||||
sessions_without_events.append(copied_session)
|
||||
return ListSessionsResponse(sessions=sessions_without_events)
|
||||
|
||||
@override
|
||||
|
||||
@@ -200,25 +200,22 @@ class VertexAiSessionService(BaseSessionService):
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
sessions = []
|
||||
config = {}
|
||||
if user_id is not None:
|
||||
config['filter'] = f'user_id="{user_id}"'
|
||||
sessions_iterator = api_client.agent_engines.sessions.list(
|
||||
name=f'reasoningEngines/{reasoning_engine_id}',
|
||||
config=config,
|
||||
config={'filter': f'user_id="{user_id}"'},
|
||||
)
|
||||
|
||||
for api_session in sessions_iterator:
|
||||
sessions.append(
|
||||
Session(
|
||||
app_name=app_name,
|
||||
user_id=api_session.user_id,
|
||||
user_id=user_id,
|
||||
id=api_session.name.split('/')[-1],
|
||||
state=getattr(api_session, 'session_state', None) or {},
|
||||
last_update_time=api_session.update_time.timestamp(),
|
||||
|
||||
@@ -116,70 +116,9 @@ async def test_create_and_list_sessions(service_type):
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
sessions = list_sessions_response.sessions
|
||||
assert len(sessions) == len(session_ids)
|
||||
assert {s.id for s in sessions} == set(session_ids)
|
||||
for session in sessions:
|
||||
assert session.state == {'key': 'value' + session.id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_list_sessions_all_users(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
user_id_2 = 'user2'
|
||||
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
session_id='session1a',
|
||||
state={'key': 'value1a'},
|
||||
)
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
session_id='session1b',
|
||||
state={'key': 'value1b'},
|
||||
)
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_2,
|
||||
session_id='session2a',
|
||||
state={'key': 'value2a'},
|
||||
)
|
||||
|
||||
# List sessions for user1
|
||||
list_sessions_response_1 = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id_1
|
||||
)
|
||||
sessions_1 = list_sessions_response_1.sessions
|
||||
assert len(sessions_1) == 2
|
||||
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
|
||||
for session in sessions_1:
|
||||
if session.id == 'session1a':
|
||||
assert session.state == {'key': 'value1a'}
|
||||
else:
|
||||
assert session.state == {'key': 'value1b'}
|
||||
|
||||
# List sessions for user2
|
||||
list_sessions_response_2 = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id_2
|
||||
)
|
||||
sessions_2 = list_sessions_response_2.sessions
|
||||
assert len(sessions_2) == 1
|
||||
assert sessions_2[0].id == 'session2a'
|
||||
assert sessions_2[0].state == {'key': 'value2a'}
|
||||
|
||||
# List sessions for all users
|
||||
list_sessions_response_all = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=None
|
||||
)
|
||||
sessions_all = list_sessions_response_all.sessions
|
||||
assert len(sessions_all) == 3
|
||||
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
|
||||
for i in range(len(sessions)):
|
||||
assert sessions[i].id == session_ids[i]
|
||||
assert sessions[i].state == {'key': 'value' + session_ids[i]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -252,22 +252,19 @@ class MockApiClient:
|
||||
def _list_sessions(self, name: str, config: dict[str, Any]):
|
||||
filter_val = config.get('filter', '')
|
||||
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
|
||||
if user_id_match:
|
||||
user_id = user_id_match.group(1)
|
||||
if user_id == 'user_with_pages':
|
||||
return [
|
||||
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
|
||||
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
|
||||
]
|
||||
return [
|
||||
_convert_to_object(session)
|
||||
for session in self.session_dict.values()
|
||||
if session['user_id'] == user_id
|
||||
]
|
||||
if not user_id_match:
|
||||
raise ValueError(f'Could not find user_id in filter: {filter_val}')
|
||||
user_id = user_id_match.group(1)
|
||||
|
||||
# No user filter, return all sessions
|
||||
if user_id == 'user_with_pages':
|
||||
return [
|
||||
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
|
||||
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
|
||||
]
|
||||
return [
|
||||
_convert_to_object(session) for session in self.session_dict.values()
|
||||
_convert_to_object(session)
|
||||
for session in self.session_dict.values()
|
||||
if session['user_id'] == user_id
|
||||
]
|
||||
|
||||
def _delete_session(self, name: str):
|
||||
@@ -478,15 +475,6 @@ async def test_list_sessions_with_pagination():
|
||||
assert sessions.sessions[1].id == 'page2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_list_sessions_all_users():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = await session_service.list_sessions(app_name='123', user_id=None)
|
||||
assert len(sessions.sessions) == 5
|
||||
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_create_session():
|
||||
|
||||
Reference in New Issue
Block a user