diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 18dd999a..da26dd25 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -22,11 +22,12 @@ import logging from typing import Any from typing import AsyncIterator from typing import Optional +from typing import TypeAlias +from typing import TypeVar from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import select -from sqlalchemy import text from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import async_sessionmaker @@ -59,6 +60,40 @@ from .state import State logger = logging.getLogger("google_adk." + __name__) +_SQLITE_DIALECT = "sqlite" +_MARIADB_DIALECT = "mariadb" +_MYSQL_DIALECT = "mysql" +_POSTGRESQL_DIALECT = "postgresql" +# Tuple key order for in-process per-session lock maps: +# (app_name, user_id, session_id). +_SessionLockKey: TypeAlias = tuple[str, str, str] +_StorageStateT = TypeVar( + "_StorageStateT", + StorageAppStateV0, + StorageAppStateV1, + StorageUserStateV0, + StorageUserStateV1, +) + + +async def _select_required_state( + *, + sql_session: DatabaseSessionFactory, + state_model: type[_StorageStateT], + predicates: tuple[Any, ...], + use_row_level_locking: bool, + missing_message: str, +) -> _StorageStateT: + """Returns a state row, raising if the row is missing.""" + stmt = select(state_model).filter(*predicates) + if use_row_level_locking: + stmt = stmt.with_for_update() + result = await sql_session.execute(stmt) + state_row = result.scalars().one_or_none() + if state_row is None: + raise ValueError(missing_message) + return state_row + def _set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() @@ -107,16 +142,19 @@ class DatabaseSessionService(BaseSessionService): try: engine_kwargs = dict(kwargs) url = make_url(db_url) - if url.get_backend_name() == "sqlite" and url.database == ":memory:": + if ( + url.get_backend_name() == _SQLITE_DIALECT + and url.database == ":memory:" + ): engine_kwargs.setdefault("poolclass", StaticPool) connect_args = dict(engine_kwargs.get("connect_args", {})) connect_args.setdefault("check_same_thread", False) engine_kwargs["connect_args"] = connect_args - elif url.get_backend_name() != "sqlite": + elif url.get_backend_name() != _SQLITE_DIALECT: engine_kwargs.setdefault("pool_pre_ping", True) db_engine = create_async_engine(db_url, **engine_kwargs) - if db_engine.dialect.name == "sqlite": + if db_engine.dialect.name == _SQLITE_DIALECT: # Set sqlite pragma to enable foreign keys constraints event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma) @@ -152,6 +190,11 @@ class DatabaseSessionService(BaseSessionService): # Lock to ensure thread-safe schema version check self._db_schema_lock = asyncio.Lock() + # Per-session locks used to serialize append_event calls in this process. + self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} + self._session_lock_ref_count: dict[_SessionLockKey, int] = {} + self._session_locks_guard = asyncio.Lock() + def _get_schema_classes(self) -> _SchemaClasses: return _SchemaClasses(self._db_schema_version) @@ -172,6 +215,45 @@ class DatabaseSessionService(BaseSessionService): await sql_session.rollback() raise + def _supports_row_level_locking(self) -> bool: + return self.db_engine.dialect.name in ( + _MARIADB_DIALECT, + _MYSQL_DIALECT, + _POSTGRESQL_DIALECT, + ) + + @asynccontextmanager + async def _with_session_lock( + self, *, app_name: str, user_id: str, session_id: str + ) -> AsyncIterator[None]: + """Serializes event appends for the same session within this process.""" + # Use one lock per logical ADK session to prevent concurrent append_event + # writes from racing in the same process. + lock_key = (app_name, user_id, session_id) + async with self._session_locks_guard: + lock = self._session_locks.get(lock_key) + if lock is None: + lock = asyncio.Lock() + self._session_locks[lock_key] = lock + # Reference counting keeps lock objects alive while they are in use by + # concurrent tasks and allows cleanup once all waiters complete. + self._session_lock_ref_count[lock_key] = ( + self._session_lock_ref_count.get(lock_key, 0) + 1 + ) + + try: + async with lock: + yield + finally: + async with self._session_locks_guard: + remaining = self._session_lock_ref_count.get(lock_key, 0) - 1 + # Remove lock bookkeeping after the last waiter exits. + if remaining <= 0 and not lock.locked(): + self._session_lock_ref_count.pop(lock_key, None) + self._session_locks.pop(lock_key, None) + else: + self._session_lock_ref_count[lock_key] = remaining + async def _prepare_tables(self): """Ensure database tables are ready for use. @@ -291,7 +373,7 @@ class DatabaseSessionService(BaseSessionService): # Store the session now = datetime.now(timezone.utc) - is_sqlite = self.db_engine.dialect.name == "sqlite" + is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT if is_sqlite: now = now.replace(tzinfo=None) @@ -372,7 +454,7 @@ class DatabaseSessionService(BaseSessionService): # Convert storage session to session events = [e.to_event() for e in reversed(storage_events)] - is_sqlite = self.db_engine.dialect.name == "sqlite" + is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT session = storage_session.to_session( state=merged_state, events=events, is_sqlite=is_sqlite ) @@ -418,7 +500,7 @@ class DatabaseSessionService(BaseSessionService): user_states_map[storage_user_state.user_id] = storage_user_state.state sessions = [] - is_sqlite = self.db_engine.dialect.name == "sqlite" + is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT for storage_session in results: session_state = storage_session.state user_state = user_states_map.get(storage_session.user_id, {}) @@ -456,71 +538,109 @@ class DatabaseSessionService(BaseSessionService): # 2. Update session attributes based on event config # 3. Store event to table schema = self._get_schema_classes() - async with self._rollback_on_exception_session() as sql_session: - storage_session = await sql_session.get( - schema.StorageSession, (session.app_name, session.user_id, session.id) - ) - - # Fetch states from storage - storage_app_state = await sql_session.get( - schema.StorageAppState, (session.app_name) - ) - storage_user_state = await sql_session.get( - schema.StorageUserState, (session.app_name, session.user_id) - ) - - is_sqlite = self.db_engine.dialect.name == "sqlite" - if ( - storage_session.get_update_timestamp(is_sqlite) - > session.last_update_time - ): - # Reload the session from storage if it has been updated since it was - # loaded. - app_state = storage_app_state.state if storage_app_state else {} - user_state = storage_user_state.state if storage_user_state else {} - session_state = storage_session.state - session.state = _merge_state(app_state, user_state, session_state) - - stmt = ( - select(schema.StorageEvent) - .filter(schema.StorageEvent.app_name == session.app_name) - .filter(schema.StorageEvent.session_id == session.id) - .filter(schema.StorageEvent.user_id == session.user_id) - .order_by(schema.StorageEvent.timestamp.asc()) + is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + use_row_level_locking = self._supports_row_level_locking() + async with self._with_session_lock( + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ): + async with self._rollback_on_exception_session() as sql_session: + storage_session_stmt = ( + select(schema.StorageSession) + .filter(schema.StorageSession.app_name == session.app_name) + .filter(schema.StorageSession.user_id == session.user_id) + .filter(schema.StorageSession.id == session.id) ) - result = await sql_session.stream_scalars(stmt) - storage_events = [e async for e in result] - session.events = [e.to_event() for e in storage_events] + if use_row_level_locking: + storage_session_stmt = storage_session_stmt.with_for_update() + storage_session_result = await sql_session.execute(storage_session_stmt) + storage_session = storage_session_result.scalars().one_or_none() + if storage_session is None: + raise ValueError(f"Session {session.id} not found.") - # Extract state delta - if event.actions and event.actions.state_delta: - state_deltas = _session_util.extract_state_delta( - event.actions.state_delta + storage_app_state = await _select_required_state( + sql_session=sql_session, + state_model=schema.StorageAppState, + predicates=(schema.StorageAppState.app_name == session.app_name,), + use_row_level_locking=use_row_level_locking, + missing_message=( + "App state missing for app_name=" + f"{session.app_name!r}. Session state tables should be " + "initialized by create_session." + ), + ) + storage_user_state = await _select_required_state( + sql_session=sql_session, + state_model=schema.StorageUserState, + predicates=( + schema.StorageUserState.app_name == session.app_name, + schema.StorageUserState.user_id == session.user_id, + ), + use_row_level_locking=use_row_level_locking, + missing_message=( + "User state missing for app_name=" + f"{session.app_name!r}, user_id={session.user_id!r}. " + "Session state tables should be initialized by " + "create_session." + ), ) - app_state_delta = state_deltas["app"] - user_state_delta = state_deltas["user"] - session_state_delta = state_deltas["session"] - # Merge state and update storage - if app_state_delta: - storage_app_state.state = storage_app_state.state | app_state_delta - if user_state_delta: - storage_user_state.state = storage_user_state.state | user_state_delta - if session_state_delta: - storage_session.state = storage_session.state | session_state_delta - if is_sqlite: - update_time = datetime.fromtimestamp( - event.timestamp, timezone.utc - ).replace(tzinfo=None) - else: - update_time = datetime.fromtimestamp(event.timestamp) - storage_session.update_time = update_time - sql_session.add(schema.StorageEvent.from_event(session, event)) + if ( + storage_session.get_update_timestamp(is_sqlite) + > session.last_update_time + ): + # Reload the session from storage if it has been updated since it was + # loaded. + app_state = storage_app_state.state + user_state = storage_user_state.state + session_state = storage_session.state + session.state = _merge_state(app_state, user_state, session_state) - await sql_session.commit() + stmt = ( + select(schema.StorageEvent) + .filter(schema.StorageEvent.app_name == session.app_name) + .filter(schema.StorageEvent.session_id == session.id) + .filter(schema.StorageEvent.user_id == session.user_id) + .order_by(schema.StorageEvent.timestamp.asc()) + ) + result = await sql_session.stream_scalars(stmt) + storage_events = [e async for e in result] + session.events = [e.to_event() for e in storage_events] - # Update timestamp with commit time - session.last_update_time = storage_session.get_update_timestamp(is_sqlite) + # Extract state delta + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + # Merge state and update storage + if app_state_delta: + storage_app_state.state = storage_app_state.state | app_state_delta + if user_state_delta: + storage_user_state.state = ( + storage_user_state.state | user_state_delta + ) + if session_state_delta: + storage_session.state = storage_session.state | session_state_delta + + if is_sqlite: + update_time = datetime.fromtimestamp( + event.timestamp, timezone.utc + ).replace(tzinfo=None) + else: + update_time = datetime.fromtimestamp(event.timestamp) + storage_session.update_time = update_time + sql_session.add(schema.StorageEvent.from_event(session, event)) + + await sql_session.commit() + + # Update timestamp with commit time + session.last_update_time = storage_session.get_update_timestamp( + is_sqlite + ) # Also update the in-memory session await super().append_event(session=session, event=event) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index f6445934..e2b03f2d 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from datetime import datetime from datetime import timezone import enum @@ -28,6 +29,7 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.sqlite_session_service import SqliteSessionService from google.genai import types import pytest +from sqlalchemy import delete class SessionServiceType(enum.Enum): @@ -613,6 +615,110 @@ async def test_append_event_to_stale_session(): ] +@pytest.mark.asyncio +async def test_append_event_raises_if_app_state_row_missing(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + await sql_session.execute( + delete(schema.StorageAppState).where( + schema.StorageAppState.app_name == session.app_name + ) + ) + await sql_session.commit() + + event = Event( + invocation_id='inv1', + author='user', + actions=EventActions(state_delta={'k': 'v'}), + ) + with pytest.raises(ValueError, match='App state missing'): + await service.append_event(session, event) + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_append_event_raises_if_user_state_row_missing(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + await sql_session.execute( + delete(schema.StorageUserState).where( + schema.StorageUserState.app_name == session.app_name, + schema.StorageUserState.user_id == session.user_id, + ) + ) + await sql_session.commit() + + event = Event( + invocation_id='inv1', + author='user', + actions=EventActions(state_delta={'k': 'v'}), + ) + with pytest.raises(ValueError, match='User state missing'): + await service.append_event(session, event) + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_append_event_concurrent_stale_sessions_preserve_all_state(): + session_service = get_session_service( + service_type=SessionServiceType.DATABASE + ) + + async with session_service: + app_name = 'my_app' + user_id = 'user' + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + + iteration_count = 8 + for i in range(iteration_count): + latest_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + stale_session_1 = latest_session.model_copy(deep=True) + stale_session_2 = latest_session.model_copy(deep=True) + base_timestamp = latest_session.last_update_time + 10.0 + event_1 = Event( + invocation_id=f'inv{i}-1', + author='user', + timestamp=base_timestamp + 1.0, + actions=EventActions(state_delta={f'sk{i}-1': f'v{i}-1'}), + ) + event_2 = Event( + invocation_id=f'inv{i}-2', + author='user', + timestamp=base_timestamp + 2.0, + actions=EventActions(state_delta={f'sk{i}-2': f'v{i}-2'}), + ) + + await asyncio.gather( + session_service.append_event(stale_session_1, event_1), + session_service.append_event(stale_session_2, event_2), + ) + + session_final = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + for i in range(iteration_count): + assert session_final.state.get(f'sk{i}-1') == f'v{i}-1' + assert session_final.state.get(f'sk{i}-2') == f'v{i}-2' + assert len(session_final.events) == iteration_count * 2 + + @pytest.mark.asyncio async def test_get_session_with_config(session_service): app_name = 'my_app'