From fbe9eccd05e628daa67059ba2e6a0d03966b240d Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Fri, 13 Feb 2026 11:06:36 -0800 Subject: [PATCH] fix: race condition in table creation for `DatabaseSessionService` Using one lock and checking for tables creation instead of schema version. Closes issue #4445 Co-authored-by: Liang Wu PiperOrigin-RevId: 869808097 --- .../adk/sessions/database_session_service.py | 94 +++++++--------- .../sessions/test_session_service.py | 103 ++++++++++++++++++ 2 files changed, 145 insertions(+), 52 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index c7c86e6e..24f525ba 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -187,9 +187,6 @@ class DatabaseSessionService(BaseSessionService): # The current database schema version in use, "None" if not yet checked self._db_schema_version: Optional[str] = None - # Lock to ensure thread-safe schema version check - self._db_schema_lock = asyncio.Lock() - # Per-session locks used to serialize append_event calls in this process. self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} self._session_lock_ref_count: dict[_SessionLockKey, int] = {} @@ -261,62 +258,55 @@ class DatabaseSessionService(BaseSessionService): DB schema version to use and creates the tables (including setting the schema version metadata) if needed. """ - # Check the database schema version and set the _db_schema_version if - # needed - if self._db_schema_version is not None: - return - - async with self._db_schema_lock: - # Double-check after acquiring the lock - if self._db_schema_version is not None: - return - try: - async with self.db_engine.connect() as conn: - self._db_schema_version = await conn.run_sync( - _schema_check_utils.get_db_schema_version_from_connection - ) - except Exception as e: - logger.error("Failed to inspect database tables: %s", e) - raise - - # Check if tables are created and create them if not + # Early return if tables are already created if self._tables_created: return async with self._table_creation_lock: # Double-check after acquiring the lock - if not self._tables_created: - async with self.db_engine.begin() as conn: - if ( - self._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ): - # Uncomment to recreate DB every time - # await conn.run_sync(BaseV1.metadata.drop_all) - logger.debug("Using V1 schema tables...") - await conn.run_sync(BaseV1.metadata.create_all) - else: - # await conn.run_sync(BaseV0.metadata.drop_all) - logger.debug("Using V0 schema tables...") - await conn.run_sync(BaseV0.metadata.create_all) - self._tables_created = True + if self._tables_created: + return - if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION: - async with self._rollback_on_exception_session() as sql_session: - # Check if schema version is set, if not, set it to the latest - # version - stmt = select(StorageMetadata).where( - StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY + # Check the database schema version and set the _db_schema_version + if self._db_schema_version is None: + try: + async with self.db_engine.connect() as conn: + self._db_schema_version = await conn.run_sync( + _schema_check_utils.get_db_schema_version_from_connection ) - result = await sql_session.execute(stmt) - metadata = result.scalars().first() - if not metadata: - metadata = StorageMetadata( - key=_schema_check_utils.SCHEMA_VERSION_KEY, - value=_schema_check_utils.LATEST_SCHEMA_VERSION, - ) - sql_session.add(metadata) - await sql_session.commit() + except Exception as e: + logger.error("Failed to inspect database tables: %s", e) + raise + + async with self.db_engine.begin() as conn: + if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION: + # Uncomment to recreate DB every time + # await conn.run_sync(BaseV1.metadata.drop_all) + logger.debug("Using V1 schema tables...") + await conn.run_sync(BaseV1.metadata.create_all) + else: + # await conn.run_sync(BaseV0.metadata.drop_all) + logger.debug("Using V0 schema tables...") + await conn.run_sync(BaseV0.metadata.create_all) + + if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION: + async with self._rollback_on_exception_session() as sql_session: + # Check if schema version is set, if not, set it to the latest + # version + stmt = select(StorageMetadata).where( + StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY + ) + result = await sql_session.execute(stmt) + metadata = result.scalars().first() + if not metadata: + metadata = StorageMetadata( + key=_schema_check_utils.SCHEMA_VERSION_KEY, + value=_schema_check_utils.LATEST_SCHEMA_VERSION, + ) + sql_session.add(metadata) + await sql_session.commit() + + self._tables_created = True @override async def create_session( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 29530d2e..25530bed 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1050,3 +1050,106 @@ async def test_service_recovers_after_multiple_failures(): assert session.id == 'recovered' finally: await service.close() + + +@pytest.mark.asyncio +async def test_concurrent_prepare_tables_no_race_condition(): + """Verifies that concurrent calls to _prepare_tables wait for table creation. + Reproduces the race condition from + https://github.com/google/adk-python/issues/4445: when concurrent requests + arrive at startup, _prepare_tables must not return before tables exist. + Previously, the early-return guard checked _db_schema_version (set during + schema detection) instead of _tables_created, so a second request could + slip through after schema detection but before table creation finished. + """ + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + # Tables haven't been created yet. + assert not service._tables_created + assert service._db_schema_version is None + + # Launch several concurrent create_session calls, each with a unique + # app_name to avoid IntegrityError on the shared app_states row. + # Each will call _prepare_tables internally. If the race condition + # exists, some of these will fail because the "sessions" table doesn't + # exist yet. + num_concurrent = 5 + results = await asyncio.gather( + *[ + service.create_session( + app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}' + ) + for i in range(num_concurrent) + ], + return_exceptions=True, + ) + + # Every call must succeed – no exceptions allowed. + for i, result in enumerate(results): + assert not isinstance(result, BaseException), ( + f'Concurrent create_session #{i} raised {result!r}; tables were' + ' likely not ready due to the _prepare_tables race condition.' + ) + + # All sessions should be retrievable. + for i in range(num_concurrent): + session = await service.get_session( + app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}' + ) + assert session is not None, f'Session sess_{i} not found after creation.' + + assert service._tables_created + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_prepare_tables_serializes_schema_detection_and_creation(): + """Verifies schema detection and table creation happen atomically under one + lock, so concurrent callers cannot observe a partially-initialized state. + After _prepare_tables completes, both _db_schema_version and _tables_created + must be set. + """ + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + assert not service._tables_created + assert service._db_schema_version is None + + await service._prepare_tables() + + # Both must be set after a single _prepare_tables call. + assert service._tables_created + assert service._db_schema_version is not None + + # Verify tables actually exist by performing a real operation. + session = await service.create_session( + app_name='app', user_id='user', session_id='s1' + ) + assert session is not None + assert session.id == 's1' + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_prepare_tables_idempotent_after_creation(): + """Calling _prepare_tables multiple times is safe and idempotent. + After tables are created, subsequent calls should return immediately via + the fast path without errors. + """ + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + await service._prepare_tables() + assert service._tables_created + + # Call again — should be a no-op via the fast path. + await service._prepare_tables() + assert service._tables_created + + # Service should still work. + session = await service.create_session( + app_name='app', user_id='user', session_id='s1' + ) + assert session.id == 's1' + finally: + await service.close()