You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Refactor and fix state management in the session service
Also refactoring the test cases to focus on the expected behaviors PiperOrigin-RevId: 820734484
This commit is contained in:
committed by
Copybara-Service
parent
cf3403231d
commit
8b3ed059c2
@@ -19,6 +19,8 @@ from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from .state import State
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
@@ -29,3 +31,19 @@ def decode_model(
|
||||
if data is None:
|
||||
return None
|
||||
return model_cls.model_validate(data)
|
||||
|
||||
|
||||
def extract_state_delta(
|
||||
state: dict[str, Any],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Extracts app, user, and session state deltas from a state dictionary."""
|
||||
deltas = {"app": {}, "user": {}, "session": {}}
|
||||
if state:
|
||||
for key in state.keys():
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
deltas["app"][key.removeprefix(State.APP_PREFIX)] = state[key]
|
||||
elif key.startswith(State.USER_PREFIX):
|
||||
deltas["user"][key.removeprefix(State.USER_PREFIX)] = state[key]
|
||||
elif not key.startswith(State.TEMP_PREFIX):
|
||||
deltas["session"][key] = state[key]
|
||||
return deltas
|
||||
|
||||
@@ -465,20 +465,14 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 5. Return the session
|
||||
|
||||
with self.database_session_factory() as sql_session:
|
||||
|
||||
# Fetch app and user states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
storage_user_state = sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
|
||||
# Create state tables if not exist
|
||||
if not storage_app_state:
|
||||
storage_app_state = StorageAppState(app_name=app_name, state={})
|
||||
sql_session.add(storage_app_state)
|
||||
storage_user_state = sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
if not storage_user_state:
|
||||
storage_user_state = StorageUserState(
|
||||
app_name=app_name, user_id=user_id, state={}
|
||||
@@ -486,19 +480,16 @@ class DatabaseSessionService(BaseSessionService):
|
||||
sql_session.add(storage_user_state)
|
||||
|
||||
# Extract state deltas
|
||||
app_state_delta, user_state_delta, session_state = _extract_state_delta(
|
||||
state
|
||||
)
|
||||
state_deltas = _session_util.extract_state_delta(state)
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state = state_deltas["session"]
|
||||
|
||||
# Apply state delta
|
||||
app_state.update(app_state_delta)
|
||||
user_state.update(user_state_delta)
|
||||
|
||||
# Store app and user state
|
||||
if app_state_delta:
|
||||
storage_app_state.state = app_state
|
||||
storage_app_state.state = storage_app_state.state | app_state_delta
|
||||
if user_state_delta:
|
||||
storage_user_state.state = user_state
|
||||
storage_user_state.state = storage_user_state.state | user_state_delta
|
||||
|
||||
# Store the session
|
||||
storage_session = StorageSession(
|
||||
@@ -513,7 +504,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
sql_session.refresh(storage_session)
|
||||
|
||||
# Merge states for response
|
||||
merged_state = _merge_state(app_state, user_state, session_state)
|
||||
merged_state = _merge_state(
|
||||
storage_app_state.state, storage_user_state.state, session_state
|
||||
)
|
||||
session = storage_session.to_session(state=merged_state)
|
||||
return session
|
||||
|
||||
@@ -536,19 +529,18 @@ class DatabaseSessionService(BaseSessionService):
|
||||
if storage_session is None:
|
||||
return None
|
||||
|
||||
query = sql_session.query(StorageEvent).filter(
|
||||
StorageEvent.app_name == app_name,
|
||||
StorageEvent.user_id == user_id,
|
||||
StorageEvent.session_id == storage_session.id,
|
||||
)
|
||||
|
||||
if config and config.after_timestamp:
|
||||
after_dt = datetime.fromtimestamp(config.after_timestamp)
|
||||
timestamp_filter = StorageEvent.timestamp >= after_dt
|
||||
else:
|
||||
timestamp_filter = True
|
||||
query = query.filter(StorageEvent.timestamp >= after_dt)
|
||||
|
||||
storage_events = (
|
||||
sql_session.query(StorageEvent)
|
||||
.filter(StorageEvent.app_name == app_name)
|
||||
.filter(StorageEvent.session_id == storage_session.id)
|
||||
.filter(StorageEvent.user_id == user_id)
|
||||
.filter(timestamp_filter)
|
||||
.order_by(StorageEvent.timestamp.desc())
|
||||
query.order_by(StorageEvent.timestamp.desc())
|
||||
.limit(
|
||||
config.num_recent_events
|
||||
if config and config.num_recent_events
|
||||
@@ -660,30 +652,21 @@ class DatabaseSessionService(BaseSessionService):
|
||||
StorageUserState, (session.app_name, session.user_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
session_state = storage_session.state
|
||||
|
||||
# Extract state delta
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
if event.actions:
|
||||
if event.actions.state_delta:
|
||||
app_state_delta, user_state_delta, session_state_delta = (
|
||||
_extract_state_delta(event.actions.state_delta)
|
||||
)
|
||||
|
||||
# Merge state and update storage
|
||||
if app_state_delta:
|
||||
app_state.update(app_state_delta)
|
||||
storage_app_state.state = app_state
|
||||
if user_state_delta:
|
||||
user_state.update(user_state_delta)
|
||||
storage_user_state.state = user_state
|
||||
if session_state_delta:
|
||||
session_state.update(session_state_delta)
|
||||
storage_session.state = session_state
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(
|
||||
event.actions.state_delta
|
||||
)
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state_delta = state_deltas["session"]
|
||||
# Merge state and update storage
|
||||
if app_state_delta:
|
||||
storage_app_state.state = storage_app_state.state | app_state_delta
|
||||
if user_state_delta:
|
||||
storage_user_state.state = storage_user_state.state | user_state_delta
|
||||
if session_state_delta:
|
||||
storage_session.state = storage_session.state | session_state_delta
|
||||
|
||||
sql_session.add(StorageEvent.from_event(session, event))
|
||||
|
||||
@@ -698,21 +681,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
return event
|
||||
|
||||
|
||||
def _extract_state_delta(state: dict[str, Any]):
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
if state:
|
||||
for key in state.keys():
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
|
||||
elif key.startswith(State.USER_PREFIX):
|
||||
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
|
||||
elif not key.startswith(State.TEMP_PREFIX):
|
||||
session_state_delta[key] = state[key]
|
||||
return app_state_delta, user_state_delta, session_state_delta
|
||||
|
||||
|
||||
def _merge_state(app_state, user_state, session_state):
|
||||
# Merge states for response
|
||||
merged_state = copy.deepcopy(session_state)
|
||||
|
||||
@@ -22,6 +22,7 @@ import uuid
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _session_util
|
||||
from ..events.event import Event
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
@@ -88,6 +89,17 @@ class InMemorySessionService(BaseSessionService):
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
state_deltas = _session_util.extract_state_delta(state)
|
||||
app_state_delta = state_deltas['app']
|
||||
user_state_delta = state_deltas['user']
|
||||
session_state = state_deltas['session']
|
||||
if app_state_delta:
|
||||
self.app_state.setdefault(app_name, {}).update(app_state_delta)
|
||||
if user_state_delta:
|
||||
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
|
||||
user_state_delta
|
||||
)
|
||||
|
||||
session_id = (
|
||||
session_id.strip()
|
||||
if session_id and session_id.strip()
|
||||
@@ -97,7 +109,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=state or {},
|
||||
state=session_state or {},
|
||||
last_update_time=time.time(),
|
||||
)
|
||||
|
||||
@@ -174,11 +186,13 @@ class InMemorySessionService(BaseSessionService):
|
||||
if i >= 0:
|
||||
copied_session.events = copied_session.events[i + 1 :]
|
||||
|
||||
# Return a copy of the session object with merged state.
|
||||
return self._merge_state(app_name, user_id, copied_session)
|
||||
|
||||
def _merge_state(
|
||||
self, app_name: str, user_id: str, copied_session: Session
|
||||
) -> Session:
|
||||
"""Merges app and user state into session state."""
|
||||
# Merge app state
|
||||
if app_name in self.app_state:
|
||||
for key in self.app_state[app_name].keys():
|
||||
@@ -269,11 +283,9 @@ class InMemorySessionService(BaseSessionService):
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
# Update the in-memory session.
|
||||
await super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
# Update the storage session
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
@@ -293,21 +305,29 @@ class InMemorySessionService(BaseSessionService):
|
||||
_warning(f'session_id {session_id} not in sessions[app_name][user_id]')
|
||||
return event
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
for key in event.actions.state_delta:
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
self.app_state.setdefault(app_name, {})[
|
||||
key.removeprefix(State.APP_PREFIX)
|
||||
] = event.actions.state_delta[key]
|
||||
|
||||
if key.startswith(State.USER_PREFIX):
|
||||
self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
|
||||
key.removeprefix(State.USER_PREFIX)
|
||||
] = event.actions.state_delta[key]
|
||||
# Update the in-memory session.
|
||||
await super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Update the storage session
|
||||
storage_session = self.sessions[app_name][user_id].get(session_id)
|
||||
await super().append_event(session=storage_session, event=event)
|
||||
|
||||
storage_session.events.append(event)
|
||||
storage_session.last_update_time = event.timestamp
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(
|
||||
event.actions.state_delta
|
||||
)
|
||||
app_state_delta = state_deltas['app']
|
||||
user_state_delta = state_deltas['user']
|
||||
session_state_delta = state_deltas['session']
|
||||
if app_state_delta:
|
||||
self.app_state.setdefault(app_name, {}).update(app_state_delta)
|
||||
if user_state_delta:
|
||||
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
|
||||
user_state_delta
|
||||
)
|
||||
if session_state_delta:
|
||||
storage_session.state.update(session_state_delta)
|
||||
|
||||
return event
|
||||
|
||||
@@ -90,7 +90,7 @@ async def test_create_get_session(service_type):
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
!= session
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@@ -151,20 +151,17 @@ async def test_list_sessions_all_users(service_type):
|
||||
state={'key': 'value2a'},
|
||||
)
|
||||
|
||||
# List sessions for user1
|
||||
# List sessions for user1 - should contain merged state
|
||||
list_sessions_response_1 = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id_1
|
||||
)
|
||||
sessions_1 = list_sessions_response_1.sessions
|
||||
assert len(sessions_1) == 2
|
||||
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
|
||||
for session in sessions_1:
|
||||
if session.id == 'session1a':
|
||||
assert session.state == {'key': 'value1a'}
|
||||
else:
|
||||
assert session.state == {'key': 'value1b'}
|
||||
sessions_1_map = {s.id: s for s in sessions_1}
|
||||
assert sessions_1_map['session1a'].state == {'key': 'value1a'}
|
||||
assert sessions_1_map['session1b'].state == {'key': 'value1b'}
|
||||
|
||||
# List sessions for user2
|
||||
# List sessions for user2 - should contain merged state
|
||||
list_sessions_response_2 = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id_2
|
||||
)
|
||||
@@ -173,151 +170,170 @@ async def test_list_sessions_all_users(service_type):
|
||||
assert sessions_2[0].id == 'session2a'
|
||||
assert sessions_2[0].state == {'key': 'value2a'}
|
||||
|
||||
# List sessions for all users
|
||||
# List sessions for all users - should contain merged state
|
||||
list_sessions_response_all = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=None
|
||||
)
|
||||
sessions_all = list_sessions_response_all.sessions
|
||||
assert len(sessions_all) == 3
|
||||
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
|
||||
sessions_all_map = {s.id: s for s in sessions_all}
|
||||
assert sessions_all_map['session1a'].state == {'key': 'value1a'}
|
||||
assert sessions_all_map['session1b'].state == {'key': 'value1b'}
|
||||
assert sessions_all_map['session2a'].state == {'key': 'value2a'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_session_state(service_type):
|
||||
async def test_app_state_is_shared_by_all_users_of_app(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
user_id_2 = 'user2'
|
||||
user_id_malicious = 'malicious'
|
||||
session_id_11 = 'session11'
|
||||
session_id_12 = 'session12'
|
||||
session_id_2 = 'session2'
|
||||
state_11 = {'key11': 'value11'}
|
||||
state_12 = {'key12': 'value12'}
|
||||
|
||||
session_11 = await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_11,
|
||||
session_id=session_id_11,
|
||||
# User 1 creates a session, establishing app:k1
|
||||
session1 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1', state={'app:k1': 'v1'}
|
||||
)
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_12,
|
||||
session_id=session_id_12,
|
||||
)
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
|
||||
)
|
||||
|
||||
assert session_11.state.get('key11') == 'value11'
|
||||
|
||||
# User 1 appends an event to session1, establishing app:k2
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
actions=EventActions(
|
||||
state_delta={
|
||||
'app:key': 'value',
|
||||
'user:key1': 'value1',
|
||||
'temp:key': 'temp',
|
||||
'key11': 'value11_new',
|
||||
}
|
||||
),
|
||||
actions=EventActions(state_delta={'app:k2': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(session=session_11, event=event)
|
||||
await session_service.append_event(session=session1, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_11.state.get('app:key') == 'value'
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
session_12 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_12
|
||||
# User 2 creates a new session session2, it should see app:k1 and app:k2
|
||||
session2 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u2', session_id='s2'
|
||||
)
|
||||
# After getting a new instance, the session_12 got the user and app state,
|
||||
# even append_event is not applied to it, temp state has no effect
|
||||
assert session_12.state.get('key12') == 'value12'
|
||||
assert not session_12.state.get('temp:key')
|
||||
assert session2.state == {'app:k1': 'v1', 'app:k2': 'v2'}
|
||||
|
||||
# The user1's state is not visible to user2, app state is visible
|
||||
session_2 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
# If we get session session1 again, it should also see both
|
||||
session1_got = await session_service.get_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1'
|
||||
)
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
# The change to session_11 is persisted
|
||||
session_11 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_11
|
||||
)
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
# Make sure a malicious user cannot obtain a session and events not belonging to them
|
||||
session_mismatch = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
|
||||
)
|
||||
|
||||
assert len(session_mismatch.events) == 0
|
||||
assert session1_got.state.get('app:k1') == 'v1'
|
||||
assert session1_got.state.get('app:k2') == 'v2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_create_new_session_will_merge_states(service_type):
|
||||
async def test_user_state_is_shared_only_by_user_sessions(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
session_id_1 = 'session1'
|
||||
session_id_2 = 'session2'
|
||||
state_1 = {'key1': 'value1'}
|
||||
|
||||
session_1 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
|
||||
# User 1 creates a session, establishing user:k1 for user 1
|
||||
session1 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1', state={'user:k1': 'v1'}
|
||||
)
|
||||
|
||||
# User 1 appends an event to session1, establishing user:k2 for user 1
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
actions=EventActions(
|
||||
state_delta={
|
||||
'app:key': 'value',
|
||||
'user:key1': 'value1',
|
||||
'temp:key': 'temp',
|
||||
}
|
||||
),
|
||||
actions=EventActions(state_delta={'user:k2': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(session=session_1, event=event)
|
||||
await session_service.append_event(session=session1, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_1.state.get('app:key') == 'value'
|
||||
assert session_1.state.get('key1') == 'value1'
|
||||
assert session_1.state.get('user:key1') == 'value1'
|
||||
assert not session_1.state.get('temp:key')
|
||||
|
||||
session_2 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
|
||||
# Another session for User 1 should see user:k1 and user:k2
|
||||
session1b = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1b'
|
||||
)
|
||||
# Session 2 has the persisted states
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
assert session_2.state.get('user:key1') == 'value1'
|
||||
assert not session_2.state.get('key1')
|
||||
assert not session_2.state.get('temp:key')
|
||||
assert session1b.state == {'user:k1': 'v1', 'user:k2': 'v2'}
|
||||
|
||||
# A session for User 2 should NOT see user:k1 or user:k2
|
||||
session2 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u2', session_id='s2'
|
||||
)
|
||||
assert session2.state == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_session_state_is_not_shared(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
# User 1 creates a session session1, establishing sk1 only for session1
|
||||
session1 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'}
|
||||
)
|
||||
# User 1 appends an event to session1, establishing sk2 only for session1
|
||||
event = Event(
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
actions=EventActions(state_delta={'sk2': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(session=session1, event=event)
|
||||
|
||||
# Getting session1 should show sk1 and sk2
|
||||
session1_got = await session_service.get_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1'
|
||||
)
|
||||
assert session1_got.state.get('sk1') == 'v1'
|
||||
assert session1_got.state.get('sk2') == 'v2'
|
||||
|
||||
# Creating another session session1b for User 1 should NOT see sk1 or sk2
|
||||
session1b = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1b'
|
||||
)
|
||||
assert session1b.state == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_temp_state_is_not_persisted_in_state_or_events(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'u1'
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id='s1'
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
actions=EventActions(state_delta={'temp:k1': 'v1', 'sk': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
# Refetch session and check state and event
|
||||
session_got = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id='s1'
|
||||
)
|
||||
# Check session state does not contain temp keys
|
||||
assert session_got.state.get('sk') == 'v2'
|
||||
assert 'temp:k1' not in session_got.state
|
||||
# Check event as stored in session does not contain temp keys in state_delta
|
||||
assert 'temp:k1' not in session_got.events[0].actions.state_delta
|
||||
assert session_got.events[0].actions.state_delta.get('sk') == 'v2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_get_session_respects_user_id(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
# u1 creates session 's1' and adds an event
|
||||
session1 = await session_service.create_session(
|
||||
app_name=app_name, user_id='u1', session_id='s1'
|
||||
)
|
||||
event = Event(invocation_id='inv1', author='user')
|
||||
await session_service.append_event(session1, event)
|
||||
# u2 creates a session with the same session_id 's1'
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id='u2', session_id='s1'
|
||||
)
|
||||
# Check that getting s1 for u2 returns u2's session (with no events)
|
||||
# not u1's session.
|
||||
session2_got = await session_service.get_session(
|
||||
app_name=app_name, user_id='u2', session_id='s1'
|
||||
)
|
||||
assert session2_got.user_id == 'u2'
|
||||
assert len(session2_got.events) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -390,6 +406,9 @@ async def test_append_event_complete(service_type):
|
||||
error_code='error_code',
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
grounding_metadata=types.GroundingMetadata(
|
||||
web_search_queries=['query1'],
|
||||
),
|
||||
usage_metadata=types.GenerateContentResponseUsageMetadata(
|
||||
prompt_token_count=1, candidates_token_count=1, total_token_count=2
|
||||
),
|
||||
@@ -474,72 +493,20 @@ async def test_get_session_with_config(service_type):
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_append_event_with_fields(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state={}
|
||||
)
|
||||
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
long_running_tool_ids={'tool1', 'tool2'},
|
||||
partial=False,
|
||||
turn_complete=True,
|
||||
error_code='ERROR_CODE',
|
||||
error_message='error message',
|
||||
interrupted=True,
|
||||
grounding_metadata=types.GroundingMetadata(
|
||||
web_search_queries=['query1'],
|
||||
),
|
||||
custom_metadata={'custom_key': 'custom_value'},
|
||||
)
|
||||
await session_service.append_event(session, event)
|
||||
|
||||
retrieved_session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert retrieved_session
|
||||
assert len(retrieved_session.events) == 1
|
||||
retrieved_event = retrieved_session.events[0]
|
||||
|
||||
assert retrieved_event == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_append_event_should_trim_temp_delta_state(service_type):
|
||||
async def test_partial_events_are_not_persisted(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
actions=EventActions(
|
||||
state_delta={
|
||||
'app:key': 'app_value',
|
||||
'temp:key': 'temp_value',
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
event = Event(author='user', partial=True)
|
||||
await session_service.append_event(session, event)
|
||||
|
||||
updated_session = await session_service.get_session(
|
||||
# Check in-memory session
|
||||
assert len(session.events) == 0
|
||||
# Check persisted session
|
||||
session_got = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
|
||||
last_event = updated_session.events[-1]
|
||||
assert 'temp:key' not in last_event.actions.state_delta
|
||||
assert last_event.actions.state_delta['app:key'] == 'app_value'
|
||||
assert len(session_got.events) == 0
|
||||
|
||||
Reference in New Issue
Block a user