chore: Refactor and fix state management in the session service

Also refactor the test cases to focus on the expected behaviors.

PiperOrigin-RevId: 820734484
This commit is contained in:
Shangjie Chen
2025-10-17 10:03:58 -07:00
committed by Copybara-Service
parent cf3403231d
commit 8b3ed059c2
4 changed files with 235 additions and 262 deletions
+18
View File
@@ -19,6 +19,8 @@ from typing import Optional
from typing import Type
from typing import TypeVar
from .state import State
M = TypeVar("M")
@@ -29,3 +31,19 @@ def decode_model(
if data is None:
return None
return model_cls.model_validate(data)
def extract_state_delta(
    state: Optional[dict[str, Any]],
) -> dict[str, dict[str, Any]]:
  """Splits a state dictionary into app, user, and session state deltas.

  Each key is routed by its prefix:
    * keys starting with `State.APP_PREFIX` go to the "app" delta, with the
      prefix removed;
    * keys starting with `State.USER_PREFIX` go to the "user" delta, with the
      prefix removed;
    * keys starting with `State.TEMP_PREFIX` are dropped;
    * every other key goes to the "session" delta unchanged.

  Args:
    state: The state (or state delta) to split. `None` or an empty dict
      yields three empty deltas, since callers may forward an optional state.

  Returns:
    A newly created dict with exactly the keys "app", "user", and "session",
    each mapping to the delta for that scope.
  """
  deltas: dict[str, dict[str, Any]] = {"app": {}, "user": {}, "session": {}}
  if not state:
    return deltas

  for key, value in state.items():
    if key.startswith(State.APP_PREFIX):
      deltas["app"][key.removeprefix(State.APP_PREFIX)] = value
    elif key.startswith(State.USER_PREFIX):
      deltas["user"][key.removeprefix(State.USER_PREFIX)] = value
    elif not key.startswith(State.TEMP_PREFIX):
      deltas["session"][key] = value
  return deltas
@@ -465,20 +465,14 @@ class DatabaseSessionService(BaseSessionService):
# 5. Return the session
with self.database_session_factory() as sql_session:
# Fetch app and user states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
# Create state tables if not exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=app_name, state={})
sql_session.add(storage_app_state)
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=app_name, user_id=user_id, state={}
@@ -486,19 +480,16 @@ class DatabaseSessionService(BaseSessionService):
sql_session.add(storage_user_state)
# Extract state deltas
app_state_delta, user_state_delta, session_state = _extract_state_delta(
state
)
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
# Apply state delta
app_state.update(app_state_delta)
user_state.update(user_state_delta)
# Store app and user state
if app_state_delta:
storage_app_state.state = app_state
storage_app_state.state = storage_app_state.state | app_state_delta
if user_state_delta:
storage_user_state.state = user_state
storage_user_state.state = storage_user_state.state | user_state_delta
# Store the session
storage_session = StorageSession(
@@ -513,7 +504,9 @@ class DatabaseSessionService(BaseSessionService):
sql_session.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(app_state, user_state, session_state)
merged_state = _merge_state(
storage_app_state.state, storage_user_state.state, session_state
)
session = storage_session.to_session(state=merged_state)
return session
@@ -536,19 +529,18 @@ class DatabaseSessionService(BaseSessionService):
if storage_session is None:
return None
query = sql_session.query(StorageEvent).filter(
StorageEvent.app_name == app_name,
StorageEvent.user_id == user_id,
StorageEvent.session_id == storage_session.id,
)
if config and config.after_timestamp:
after_dt = datetime.fromtimestamp(config.after_timestamp)
timestamp_filter = StorageEvent.timestamp >= after_dt
else:
timestamp_filter = True
query = query.filter(StorageEvent.timestamp >= after_dt)
storage_events = (
sql_session.query(StorageEvent)
.filter(StorageEvent.app_name == app_name)
.filter(StorageEvent.session_id == storage_session.id)
.filter(StorageEvent.user_id == user_id)
.filter(timestamp_filter)
.order_by(StorageEvent.timestamp.desc())
query.order_by(StorageEvent.timestamp.desc())
.limit(
config.num_recent_events
if config and config.num_recent_events
@@ -660,30 +652,21 @@ class DatabaseSessionService(BaseSessionService):
StorageUserState, (session.app_name, session.user_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
# Extract state delta
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if event.actions:
if event.actions.state_delta:
app_state_delta, user_state_delta, session_state_delta = (
_extract_state_delta(event.actions.state_delta)
)
# Merge state and update storage
if app_state_delta:
app_state.update(app_state_delta)
storage_app_state.state = app_state
if user_state_delta:
user_state.update(user_state_delta)
storage_user_state.state = user_state
if session_state_delta:
session_state.update(session_state_delta)
storage_session.state = session_state
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state_delta = state_deltas["session"]
# Merge state and update storage
if app_state_delta:
storage_app_state.state = storage_app_state.state | app_state_delta
if user_state_delta:
storage_user_state.state = storage_user_state.state | user_state_delta
if session_state_delta:
storage_session.state = storage_session.state | session_state_delta
sql_session.add(StorageEvent.from_event(session, event))
@@ -698,21 +681,6 @@ class DatabaseSessionService(BaseSessionService):
return event
def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta
def _merge_state(app_state, user_state, session_state):
# Merge states for response
merged_state = copy.deepcopy(session_state)
@@ -22,6 +22,7 @@ import uuid
from typing_extensions import override
from . import _session_util
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
@@ -88,6 +89,17 @@ class InMemorySessionService(BaseSessionService):
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas['app']
user_state_delta = state_deltas['user']
session_state = state_deltas['session']
if app_state_delta:
self.app_state.setdefault(app_name, {}).update(app_state_delta)
if user_state_delta:
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
user_state_delta
)
session_id = (
session_id.strip()
if session_id and session_id.strip()
@@ -97,7 +109,7 @@ class InMemorySessionService(BaseSessionService):
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
state=session_state or {},
last_update_time=time.time(),
)
@@ -174,11 +186,13 @@ class InMemorySessionService(BaseSessionService):
if i >= 0:
copied_session.events = copied_session.events[i + 1 :]
# Return a copy of the session object with merged state.
return self._merge_state(app_name, user_id, copied_session)
def _merge_state(
self, app_name: str, user_id: str, copied_session: Session
) -> Session:
"""Merges app and user state into session state."""
# Merge app state
if app_name in self.app_state:
for key in self.app_state[app_name].keys():
@@ -269,11 +283,9 @@ class InMemorySessionService(BaseSessionService):
@override
async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
if event.partial:
return event
# Update the storage session
app_name = session.app_name
user_id = session.user_id
session_id = session.id
@@ -293,21 +305,29 @@ class InMemorySessionService(BaseSessionService):
_warning(f'session_id {session_id} not in sessions[app_name][user_id]')
return event
if event.actions and event.actions.state_delta:
for key in event.actions.state_delta:
if key.startswith(State.APP_PREFIX):
self.app_state.setdefault(app_name, {})[
key.removeprefix(State.APP_PREFIX)
] = event.actions.state_delta[key]
if key.startswith(State.USER_PREFIX):
self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
key.removeprefix(State.USER_PREFIX)
] = event.actions.state_delta[key]
# Update the in-memory session.
await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Update the storage session
storage_session = self.sessions[app_name][user_id].get(session_id)
await super().append_event(session=storage_session, event=event)
storage_session.events.append(event)
storage_session.last_update_time = event.timestamp
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
app_state_delta = state_deltas['app']
user_state_delta = state_deltas['user']
session_state_delta = state_deltas['session']
if app_state_delta:
self.app_state.setdefault(app_name, {}).update(app_state_delta)
if user_state_delta:
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
user_state_delta
)
if session_state_delta:
storage_session.state.update(session_state_delta)
return event
+145 -178
View File
@@ -90,7 +90,7 @@ async def test_create_get_session(service_type):
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
!= session
is None
)
@@ -151,20 +151,17 @@ async def test_list_sessions_all_users(service_type):
state={'key': 'value2a'},
)
# List sessions for user1
# List sessions for user1 - should contain merged state
list_sessions_response_1 = await session_service.list_sessions(
app_name=app_name, user_id=user_id_1
)
sessions_1 = list_sessions_response_1.sessions
assert len(sessions_1) == 2
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
for session in sessions_1:
if session.id == 'session1a':
assert session.state == {'key': 'value1a'}
else:
assert session.state == {'key': 'value1b'}
sessions_1_map = {s.id: s for s in sessions_1}
assert sessions_1_map['session1a'].state == {'key': 'value1a'}
assert sessions_1_map['session1b'].state == {'key': 'value1b'}
# List sessions for user2
# List sessions for user2 - should contain merged state
list_sessions_response_2 = await session_service.list_sessions(
app_name=app_name, user_id=user_id_2
)
@@ -173,151 +170,170 @@ async def test_list_sessions_all_users(service_type):
assert sessions_2[0].id == 'session2a'
assert sessions_2[0].state == {'key': 'value2a'}
# List sessions for all users
# List sessions for all users - should contain merged state
list_sessions_response_all = await session_service.list_sessions(
app_name=app_name, user_id=None
)
sessions_all = list_sessions_response_all.sessions
assert len(sessions_all) == 3
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
sessions_all_map = {s.id: s for s in sessions_all}
assert sessions_all_map['session1a'].state == {'key': 'value1a'}
assert sessions_all_map['session1b'].state == {'key': 'value1b'}
assert sessions_all_map['session2a'].state == {'key': 'value2a'}
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_session_state(service_type):
async def test_app_state_is_shared_by_all_users_of_app(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id_1 = 'user1'
user_id_2 = 'user2'
user_id_malicious = 'malicious'
session_id_11 = 'session11'
session_id_12 = 'session12'
session_id_2 = 'session2'
state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'}
session_11 = await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_11,
session_id=session_id_11,
# User 1 creates a session, establishing app:k1
session1 = await session_service.create_session(
app_name=app_name, user_id='u1', session_id='s1', state={'app:k1': 'v1'}
)
await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_12,
session_id=session_id_12,
)
await session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
await session_service.create_session(
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
)
assert session_11.state.get('key11') == 'value11'
# User 1 appends an event to session1, establishing app:k2
event = Event(
invocation_id='invocation',
invocation_id='inv1',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
actions=EventActions(
state_delta={
'app:key': 'value',
'user:key1': 'value1',
'temp:key': 'temp',
'key11': 'value11_new',
}
),
actions=EventActions(state_delta={'app:k2': 'v2'}),
)
await session_service.append_event(session=session_11, event=event)
await session_service.append_event(session=session1, event=event)
# User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value'
assert session_11.state.get('key11') == 'value11_new'
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
session_12 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12
# User 2 creates a new session session2, it should see app:k1 and app:k2
session2 = await session_service.create_session(
app_name=app_name, user_id='u2', session_id='s2'
)
# After getting a new instance, the session_12 got the user and app state,
# even append_event is not applied to it, temp state has no effect
assert session_12.state.get('key12') == 'value12'
assert not session_12.state.get('temp:key')
assert session2.state == {'app:k1': 'v1', 'app:k2': 'v2'}
# The user1's state is not visible to user2, app state is visible
session_2 = await session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
# If we get session session1 again, it should also see both
session1_got = await session_service.get_session(
app_name=app_name, user_id='u1', session_id='s1'
)
assert session_2.state.get('app:key') == 'value'
assert not session_2.state.get('user:key1')
assert not session_2.state.get('user:key1')
# The change to session_11 is persisted
session_11 = await session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11
)
assert session_11.state.get('key11') == 'value11_new'
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
# Make sure a malicious user cannot obtain a session and events not belonging to them
session_mismatch = await session_service.get_session(
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
)
assert len(session_mismatch.events) == 0
assert session1_got.state.get('app:k1') == 'v1'
assert session1_got.state.get('app:k2') == 'v2'
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_new_session_will_merge_states(service_type):
async def test_user_state_is_shared_only_by_user_sessions(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session_id_1 = 'session1'
session_id_2 = 'session2'
state_1 = {'key1': 'value1'}
session_1 = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
# User 1 creates a session, establishing user:k1 for user 1
session1 = await session_service.create_session(
app_name=app_name, user_id='u1', session_id='s1', state={'user:k1': 'v1'}
)
# User 1 appends an event to session1, establishing user:k2 for user 1
event = Event(
invocation_id='invocation',
invocation_id='inv1',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
actions=EventActions(
state_delta={
'app:key': 'value',
'user:key1': 'value1',
'temp:key': 'temp',
}
),
actions=EventActions(state_delta={'user:k2': 'v2'}),
)
await session_service.append_event(session=session_1, event=event)
await session_service.append_event(session=session1, event=event)
# User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value'
assert session_1.state.get('key1') == 'value1'
assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key')
session_2 = await session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
# Another session for User 1 should see user:k1 and user:k2
session1b = await session_service.create_session(
app_name=app_name, user_id='u1', session_id='s1b'
)
# Session 2 has the persisted states
assert session_2.state.get('app:key') == 'value'
assert session_2.state.get('user:key1') == 'value1'
assert not session_2.state.get('key1')
assert not session_2.state.get('temp:key')
assert session1b.state == {'user:k1': 'v1', 'user:k2': 'v2'}
# A session for User 2 should NOT see user:k1 or user:k2
session2 = await session_service.create_session(
app_name=app_name, user_id='u2', session_id='s2'
)
assert session2.state == {}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_session_state_is_not_shared(service_type):
  """Unprefixed state keys stay with the session that set them."""
  service = get_session_service(service_type)
  app_name = 'my_app'

  # Session s1 of user u1 starts with sk1 in its initial state.
  original = await service.create_session(
      app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'}
  )

  # An event appended to s1 adds sk2 through its state delta.
  await service.append_event(
      session=original,
      event=Event(
          invocation_id='inv1',
          author='user',
          actions=EventActions(state_delta={'sk2': 'v2'}),
      ),
  )

  # Re-reading s1 must expose both keys.
  refetched = await service.get_session(
      app_name=app_name, user_id='u1', session_id='s1'
  )
  assert refetched.state.get('sk1') == 'v1'
  assert refetched.state.get('sk2') == 'v2'

  # A second session for the same user must start with neither key.
  sibling = await service.create_session(
      app_name=app_name, user_id='u1', session_id='s1b'
  )
  assert sibling.state == {}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_temp_state_is_not_persisted_in_state_or_events(service_type):
  """A 'temp:' key in a state delta is absent after the session is re-read."""
  service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'u1'

  created = await service.create_session(
      app_name=app_name, user_id=user_id, session_id='s1'
  )
  # The delta mixes a temp-prefixed key with a regular session key.
  await service.append_event(
      session=created,
      event=Event(
          invocation_id='inv1',
          author='user',
          actions=EventActions(state_delta={'temp:k1': 'v1', 'sk': 'v2'}),
      ),
  )

  # Re-read the session so the assertions run against what was stored.
  refetched = await service.get_session(
      app_name=app_name, user_id=user_id, session_id='s1'
  )

  # The session state keeps the regular key and omits the temp key.
  assert refetched.state.get('sk') == 'v2'
  assert 'temp:k1' not in refetched.state

  # The stored event's state_delta likewise omits the temp key.
  stored_delta = refetched.events[0].actions.state_delta
  assert 'temp:k1' not in stored_delta
  assert stored_delta.get('sk') == 'v2'
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_get_session_respects_user_id(service_type):
  """Two users sharing a session id each get their own session back."""
  service = get_session_service(service_type)
  app_name = 'my_app'

  # u1 owns session 's1' and appends one event to it.
  owner_session = await service.create_session(
      app_name=app_name, user_id='u1', session_id='s1'
  )
  await service.append_event(
      owner_session, Event(invocation_id='inv1', author='user')
  )

  # u2 creates a session that reuses the same session id.
  await service.create_session(
      app_name=app_name, user_id='u2', session_id='s1'
  )

  # Looking up 's1' as u2 must return u2's session, which has no events,
  # rather than u1's session.
  fetched_for_u2 = await service.get_session(
      app_name=app_name, user_id='u2', session_id='s1'
  )
  assert fetched_for_u2.user_id == 'u2'
  assert len(fetched_for_u2.events) == 0
@pytest.mark.asyncio
@@ -390,6 +406,9 @@ async def test_append_event_complete(service_type):
error_code='error_code',
error_message='error_message',
interrupted=True,
grounding_metadata=types.GroundingMetadata(
web_search_queries=['query1'],
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=1, candidates_token_count=1, total_token_count=2
),
@@ -474,72 +493,20 @@ async def test_get_session_with_config(service_type):
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_with_fields(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
session = await session_service.create_session(
app_name=app_name, user_id=user_id, state={}
)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
long_running_tool_ids={'tool1', 'tool2'},
partial=False,
turn_complete=True,
error_code='ERROR_CODE',
error_message='error message',
interrupted=True,
grounding_metadata=types.GroundingMetadata(
web_search_queries=['query1'],
),
custom_metadata={'custom_key': 'custom_value'},
)
await session_service.append_event(session, event)
retrieved_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert retrieved_session
assert len(retrieved_session.events) == 1
retrieved_event = retrieved_session.events[0]
assert retrieved_event == event
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_should_trim_temp_delta_state(service_type):
async def test_partial_events_are_not_persisted(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
actions=EventActions(
state_delta={
'app:key': 'app_value',
'temp:key': 'temp_value',
}
),
)
event = Event(author='user', partial=True)
await session_service.append_event(session, event)
updated_session = await session_service.get_session(
# Check in-memory session
assert len(session.events) == 0
# Check persisted session
session_got = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
last_event = updated_session.events[-1]
assert 'temp:key' not in last_event.actions.state_delta
assert last_event.actions.state_delta['app:key'] == 'app_value'
assert len(session_got.events) == 0