diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index 68340c8b..0b2f99ee 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -19,6 +19,8 @@ from typing import Optional from typing import Type from typing import TypeVar +from .state import State + M = TypeVar("M") @@ -29,3 +31,19 @@ def decode_model( if data is None: return None return model_cls.model_validate(data) + + +def extract_state_delta( + state: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Extracts app, user, and session state deltas from a state dictionary.""" + deltas = {"app": {}, "user": {}, "session": {}} + if state: + for key in state.keys(): + if key.startswith(State.APP_PREFIX): + deltas["app"][key.removeprefix(State.APP_PREFIX)] = state[key] + elif key.startswith(State.USER_PREFIX): + deltas["user"][key.removeprefix(State.USER_PREFIX)] = state[key] + elif not key.startswith(State.TEMP_PREFIX): + deltas["session"][key] = state[key] + return deltas diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 04af695b..6ce14961 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -465,20 +465,14 @@ class DatabaseSessionService(BaseSessionService): # 5. Return the session with self.database_session_factory() as sql_session: - # Fetch app and user states from storage storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( - StorageUserState, (app_name, user_id) - ) - - app_state = storage_app_state.state if storage_app_state else {} - user_state = storage_user_state.state if storage_user_state else {} - - # Create state tables if not exist if not storage_app_state: storage_app_state = StorageAppState(app_name=app_name, state={}) sql_session.add(storage_app_state) + storage_user_state = sql_session.get( + StorageUserState, (app_name, user_id) + ) if not storage_user_state: storage_user_state = StorageUserState( app_name=app_name, user_id=user_id, state={} @@ -486,19 +480,16 @@ class DatabaseSessionService(BaseSessionService): sql_session.add(storage_user_state) # Extract state deltas - app_state_delta, user_state_delta, session_state = _extract_state_delta( - state - ) + state_deltas = _session_util.extract_state_delta(state) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state = state_deltas["session"] # Apply state delta - app_state.update(app_state_delta) - user_state.update(user_state_delta) - - # Store app and user state if app_state_delta: - storage_app_state.state = app_state + storage_app_state.state = storage_app_state.state | app_state_delta if user_state_delta: - storage_user_state.state = user_state + storage_user_state.state = storage_user_state.state | user_state_delta # Store the session storage_session = StorageSession( @@ -513,7 +504,9 @@ class DatabaseSessionService(BaseSessionService): sql_session.refresh(storage_session) # Merge states for response - merged_state = _merge_state(app_state, user_state, session_state) + merged_state = _merge_state( + storage_app_state.state, storage_user_state.state, session_state + ) session = storage_session.to_session(state=merged_state) return session @@ -536,19 +529,18 @@ class DatabaseSessionService(BaseSessionService): if storage_session is None: return None + query = sql_session.query(StorageEvent).filter( + StorageEvent.app_name == app_name, + StorageEvent.user_id == user_id, + StorageEvent.session_id == storage_session.id, + ) + if config and config.after_timestamp: after_dt = datetime.fromtimestamp(config.after_timestamp) - timestamp_filter = StorageEvent.timestamp >= after_dt - else: - timestamp_filter = True + query = query.filter(StorageEvent.timestamp >= after_dt) storage_events = ( - sql_session.query(StorageEvent) - .filter(StorageEvent.app_name == app_name) - .filter(StorageEvent.session_id == storage_session.id) - .filter(StorageEvent.user_id == user_id) - .filter(timestamp_filter) - .order_by(StorageEvent.timestamp.desc()) + query.order_by(StorageEvent.timestamp.desc()) .limit( config.num_recent_events if config and config.num_recent_events @@ -660,30 +652,21 @@ class DatabaseSessionService(BaseSessionService): StorageUserState, (session.app_name, session.user_id) ) - app_state = storage_app_state.state if storage_app_state else {} - user_state = storage_user_state.state if storage_user_state else {} - session_state = storage_session.state - # Extract state delta - app_state_delta = {} - user_state_delta = {} - session_state_delta = {} - if event.actions: - if event.actions.state_delta: - app_state_delta, user_state_delta, session_state_delta = ( - _extract_state_delta(event.actions.state_delta) - ) - - # Merge state and update storage - if app_state_delta: - app_state.update(app_state_delta) - storage_app_state.state = app_state - if user_state_delta: - user_state.update(user_state_delta) - storage_user_state.state = user_state - if session_state_delta: - session_state.update(session_state_delta) - storage_session.state = session_state + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + # Merge state and update storage + if app_state_delta: + storage_app_state.state = storage_app_state.state | app_state_delta + if user_state_delta: + storage_user_state.state = storage_user_state.state | user_state_delta + if session_state_delta: + storage_session.state = storage_session.state | session_state_delta sql_session.add(StorageEvent.from_event(session, event)) @@ -698,21 +681,6 @@ class DatabaseSessionService(BaseSessionService): return event -def _extract_state_delta(state: dict[str, Any]): - app_state_delta = {} - user_state_delta = {} - session_state_delta = {} - if state: - for key in state.keys(): - if key.startswith(State.APP_PREFIX): - app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key] - elif key.startswith(State.USER_PREFIX): - user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key] - elif not key.startswith(State.TEMP_PREFIX): - session_state_delta[key] = state[key] - return app_state_delta, user_state_delta, session_state_delta - - def _merge_state(app_state, user_state, session_state): # Merge states for response merged_state = copy.deepcopy(session_state) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index adc09230..e45cf9d8 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -22,6 +22,7 @@ import uuid from typing_extensions import override +from . import _session_util from ..events.event import Event from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig @@ -88,6 +89,17 @@ class InMemorySessionService(BaseSessionService): state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, ) -> Session: + state_deltas = _session_util.extract_state_delta(state) + app_state_delta = state_deltas['app'] + user_state_delta = state_deltas['user'] + session_state = state_deltas['session'] + if app_state_delta: + self.app_state.setdefault(app_name, {}).update(app_state_delta) + if user_state_delta: + self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update( + user_state_delta + ) + session_id = ( session_id.strip() if session_id and session_id.strip() @@ -97,7 +109,7 @@ class InMemorySessionService(BaseSessionService): app_name=app_name, user_id=user_id, id=session_id, - state=state or {}, + state=session_state or {}, last_update_time=time.time(), ) @@ -174,11 +186,13 @@ class InMemorySessionService(BaseSessionService): if i >= 0: copied_session.events = copied_session.events[i + 1 :] + # Return a copy of the session object with merged state. return self._merge_state(app_name, user_id, copied_session) def _merge_state( self, app_name: str, user_id: str, copied_session: Session ) -> Session: + """Merges app and user state into session state.""" # Merge app state if app_name in self.app_state: for key in self.app_state[app_name].keys(): @@ -269,11 +283,9 @@ class InMemorySessionService(BaseSessionService): @override async def append_event(self, session: Session, event: Event) -> Event: - # Update the in-memory session. - await super().append_event(session=session, event=event) - session.last_update_time = event.timestamp + if event.partial: + return event - # Update the storage session app_name = session.app_name user_id = session.user_id session_id = session.id @@ -293,21 +305,29 @@ class InMemorySessionService(BaseSessionService): _warning(f'session_id {session_id} not in sessions[app_name][user_id]') return event - if event.actions and event.actions.state_delta: - for key in event.actions.state_delta: - if key.startswith(State.APP_PREFIX): - self.app_state.setdefault(app_name, {})[ - key.removeprefix(State.APP_PREFIX) - ] = event.actions.state_delta[key] - - if key.startswith(State.USER_PREFIX): - self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[ - key.removeprefix(State.USER_PREFIX) - ] = event.actions.state_delta[key] + # Update the in-memory session. + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + # Update the storage session storage_session = self.sessions[app_name][user_id].get(session_id) - await super().append_event(session=storage_session, event=event) - + storage_session.events.append(event) storage_session.last_update_time = event.timestamp + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + app_state_delta = state_deltas['app'] + user_state_delta = state_deltas['user'] + session_state_delta = state_deltas['session'] + if app_state_delta: + self.app_state.setdefault(app_name, {}).update(app_state_delta) + if user_state_delta: + self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update( + user_state_delta + ) + if session_state_delta: + storage_session.state.update(session_state_delta) + return event diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 0dd4162e..4d92dcea 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -90,7 +90,7 @@ async def test_create_get_session(service_type): await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) - != session + is None ) @@ -151,20 +151,17 @@ async def test_list_sessions_all_users(service_type): state={'key': 'value2a'}, ) - # List sessions for user1 + # List sessions for user1 - should contain merged state list_sessions_response_1 = await session_service.list_sessions( app_name=app_name, user_id=user_id_1 ) sessions_1 = list_sessions_response_1.sessions assert len(sessions_1) == 2 - assert {s.id for s in sessions_1} == {'session1a', 'session1b'} - for session in sessions_1: - if session.id == 'session1a': - assert session.state == {'key': 'value1a'} - else: - assert session.state == {'key': 'value1b'} + sessions_1_map = {s.id: s for s in sessions_1} + assert sessions_1_map['session1a'].state == {'key': 'value1a'} + assert sessions_1_map['session1b'].state == {'key': 'value1b'} - # List sessions for user2 + # List sessions for user2 - should contain merged state list_sessions_response_2 = await session_service.list_sessions( app_name=app_name, user_id=user_id_2 ) @@ -173,151 +170,170 @@ async def test_list_sessions_all_users(service_type): assert sessions_2[0].id == 'session2a' assert sessions_2[0].state == {'key': 'value2a'} - # List sessions for all users + # List sessions for all users - should contain merged state list_sessions_response_all = await session_service.list_sessions( app_name=app_name, user_id=None ) sessions_all = list_sessions_response_all.sessions assert len(sessions_all) == 3 - assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'} + sessions_all_map = {s.id: s for s in sessions_all} + assert sessions_all_map['session1a'].state == {'key': 'value1a'} + assert sessions_all_map['session1b'].state == {'key': 'value1b'} + assert sessions_all_map['session2a'].state == {'key': 'value2a'} @pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -async def test_session_state(service_type): +async def test_app_state_is_shared_by_all_users_of_app(service_type): session_service = get_session_service(service_type) app_name = 'my_app' - user_id_1 = 'user1' - user_id_2 = 'user2' - user_id_malicious = 'malicious' - session_id_11 = 'session11' - session_id_12 = 'session12' - session_id_2 = 'session2' - state_11 = {'key11': 'value11'} - state_12 = {'key12': 'value12'} - - session_11 = await session_service.create_session( - app_name=app_name, - user_id=user_id_1, - state=state_11, - session_id=session_id_11, + # User 1 creates a session, establishing app:k1 + session1 = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1', state={'app:k1': 'v1'} ) - await session_service.create_session( - app_name=app_name, - user_id=user_id_1, - state=state_12, - session_id=session_id_12, - ) - await session_service.create_session( - app_name=app_name, user_id=user_id_2, session_id=session_id_2 - ) - - await session_service.create_session( - app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 - ) - - assert session_11.state.get('key11') == 'value11' - + # User 1 appends an event to session1, establishing app:k2 event = Event( - invocation_id='invocation', + invocation_id='inv1', author='user', - content=types.Content(role='user', parts=[types.Part(text='text')]), - actions=EventActions( - state_delta={ - 'app:key': 'value', - 'user:key1': 'value1', - 'temp:key': 'temp', - 'key11': 'value11_new', - } - ), + actions=EventActions(state_delta={'app:k2': 'v2'}), ) - await session_service.append_event(session=session_11, event=event) + await session_service.append_event(session=session1, event=event) - # User and app state is stored, temp state is filtered. - assert session_11.state.get('app:key') == 'value' - assert session_11.state.get('key11') == 'value11_new' - assert session_11.state.get('user:key1') == 'value1' - assert not session_11.state.get('temp:key') - - session_12 = await session_service.get_session( - app_name=app_name, user_id=user_id_1, session_id=session_id_12 + # User 2 creates a new session session2, it should see app:k1 and app:k2 + session2 = await session_service.create_session( + app_name=app_name, user_id='u2', session_id='s2' ) - # After getting a new instance, the session_12 got the user and app state, - # even append_event is not applied to it, temp state has no effect - assert session_12.state.get('key12') == 'value12' - assert not session_12.state.get('temp:key') + assert session2.state == {'app:k1': 'v1', 'app:k2': 'v2'} - # The user1's state is not visible to user2, app state is visible - session_2 = await session_service.get_session( - app_name=app_name, user_id=user_id_2, session_id=session_id_2 + # If we get session session1 again, it should also see both + session1_got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' ) - assert session_2.state.get('app:key') == 'value' - assert not session_2.state.get('user:key1') - - assert not session_2.state.get('user:key1') - - # The change to session_11 is persisted - session_11 = await session_service.get_session( - app_name=app_name, user_id=user_id_1, session_id=session_id_11 - ) - assert session_11.state.get('key11') == 'value11_new' - assert session_11.state.get('user:key1') == 'value1' - assert not session_11.state.get('temp:key') - - # Make sure a malicious user cannot obtain a session and events not belonging to them - session_mismatch = await session_service.get_session( - app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 - ) - - assert len(session_mismatch.events) == 0 + assert session1_got.state.get('app:k1') == 'v1' + assert session1_got.state.get('app:k2') == 'v2' @pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -async def test_create_new_session_will_merge_states(service_type): +async def test_user_state_is_shared_only_by_user_sessions(service_type): session_service = get_session_service(service_type) app_name = 'my_app' - user_id = 'user' - session_id_1 = 'session1' - session_id_2 = 'session2' - state_1 = {'key1': 'value1'} - - session_1 = await session_service.create_session( - app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1 + # User 1 creates a session, establishing user:k1 for user 1 + session1 = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1', state={'user:k1': 'v1'} ) - + # User 1 appends an event to session1, establishing user:k2 for user 1 event = Event( - invocation_id='invocation', + invocation_id='inv1', author='user', - content=types.Content(role='user', parts=[types.Part(text='text')]), - actions=EventActions( - state_delta={ - 'app:key': 'value', - 'user:key1': 'value1', - 'temp:key': 'temp', - } - ), + actions=EventActions(state_delta={'user:k2': 'v2'}), ) - await session_service.append_event(session=session_1, event=event) + await session_service.append_event(session=session1, event=event) - # User and app state is stored, temp state is filtered. - assert session_1.state.get('app:key') == 'value' - assert session_1.state.get('key1') == 'value1' - assert session_1.state.get('user:key1') == 'value1' - assert not session_1.state.get('temp:key') - - session_2 = await session_service.create_session( - app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 + # Another session for User 1 should see user:k1 and user:k2 + session1b = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1b' ) - # Session 2 has the persisted states - assert session_2.state.get('app:key') == 'value' - assert session_2.state.get('user:key1') == 'value1' - assert not session_2.state.get('key1') - assert not session_2.state.get('temp:key') + assert session1b.state == {'user:k1': 'v1', 'user:k2': 'v2'} + + # A session for User 2 should NOT see user:k1 or user:k2 + session2 = await session_service.create_session( + app_name=app_name, user_id='u2', session_id='s2' + ) + assert session2.state == {} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_session_state_is_not_shared(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + # User 1 creates a session session1, establishing sk1 only for session1 + session1 = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'} + ) + # User 1 appends an event to session1, establishing sk2 only for session1 + event = Event( + invocation_id='inv1', + author='user', + actions=EventActions(state_delta={'sk2': 'v2'}), + ) + await session_service.append_event(session=session1, event=event) + + # Getting session1 should show sk1 and sk2 + session1_got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert session1_got.state.get('sk1') == 'v1' + assert session1_got.state.get('sk2') == 'v2' + + # Creating another session session1b for User 1 should NOT see sk1 or sk2 + session1b = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1b' + ) + assert session1b.state == {} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_temp_state_is_not_persisted_in_state_or_events(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s1' + ) + event = Event( + invocation_id='inv1', + author='user', + actions=EventActions(state_delta={'temp:k1': 'v1', 'sk': 'v2'}), + ) + await session_service.append_event(session=session, event=event) + + # Refetch session and check state and event + session_got = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id='s1' + ) + # Check session state does not contain temp keys + assert session_got.state.get('sk') == 'v2' + assert 'temp:k1' not in session_got.state + # Check event as stored in session does not contain temp keys in state_delta + assert 'temp:k1' not in session_got.events[0].actions.state_delta + assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_get_session_respects_user_id(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + # u1 creates session 's1' and adds an event + session1 = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + event = Event(invocation_id='inv1', author='user') + await session_service.append_event(session1, event) + # u2 creates a session with the same session_id 's1' + await session_service.create_session( + app_name=app_name, user_id='u2', session_id='s1' + ) + # Check that getting s1 for u2 returns u2's session (with no events) + # not u1's session. + session2_got = await session_service.get_session( + app_name=app_name, user_id='u2', session_id='s1' + ) + assert session2_got.user_id == 'u2' + assert len(session2_got.events) == 0 @pytest.mark.asyncio @@ -390,6 +406,9 @@ async def test_append_event_complete(service_type): error_code='error_code', error_message='error_message', interrupted=True, + grounding_metadata=types.GroundingMetadata( + web_search_queries=['query1'], + ), usage_metadata=types.GenerateContentResponseUsageMetadata( prompt_token_count=1, candidates_token_count=1, total_token_count=2 ), @@ -474,72 +493,20 @@ async def test_get_session_with_config(service_type): @pytest.mark.parametrize( 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) -async def test_append_event_with_fields(service_type): - session_service = get_session_service(service_type) - app_name = 'my_app' - user_id = 'test_user' - session = await session_service.create_session( - app_name=app_name, user_id=user_id, state={} - ) - - event = Event( - invocation_id='invocation', - author='user', - content=types.Content(role='user', parts=[types.Part(text='text')]), - long_running_tool_ids={'tool1', 'tool2'}, - partial=False, - turn_complete=True, - error_code='ERROR_CODE', - error_message='error message', - interrupted=True, - grounding_metadata=types.GroundingMetadata( - web_search_queries=['query1'], - ), - custom_metadata={'custom_key': 'custom_value'}, - ) - await session_service.append_event(session, event) - - retrieved_session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - assert retrieved_session - assert len(retrieved_session.events) == 1 - retrieved_event = retrieved_session.events[0] - - assert retrieved_event == event - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] -) -async def test_append_event_should_trim_temp_delta_state(service_type): +async def test_partial_events_are_not_persisted(service_type): session_service = get_session_service(service_type) app_name = 'my_app' user_id = 'user' - session = await session_service.create_session( app_name=app_name, user_id=user_id ) - - event = Event( - invocation_id='invocation', - author='user', - content=types.Content(role='user', parts=[types.Part(text='text')]), - actions=EventActions( - state_delta={ - 'app:key': 'app_value', - 'temp:key': 'temp_value', - } - ), - ) - + event = Event(author='user', partial=True) await session_service.append_event(session, event) - updated_session = await session_service.get_session( + # Check in-memory session + assert len(session.events) == 0 + # Check persisted session + session_got = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) - - last_event = updated_session.events[-1] - assert 'temp:key' not in last_event.actions.state_delta - assert last_event.actions.state_delta['app:key'] == 'app_value' + assert len(session_got.events) == 0