fix: race condition in table creation for DatabaseSessionService

Use a single lock, and gate the early return on whether the tables have been created rather than on whether the schema version has been detected.

Closes issue #4445

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 869808097
This commit is contained in:
Liang Wu
2026-02-13 11:06:36 -08:00
committed by Copybara-Service
parent 186371f01e
commit fbe9eccd05
2 changed files with 145 additions and 52 deletions
@@ -187,9 +187,6 @@ class DatabaseSessionService(BaseSessionService):
# The current database schema version in use, "None" if not yet checked
self._db_schema_version: Optional[str] = None
# Lock to ensure thread-safe schema version check
self._db_schema_lock = asyncio.Lock()
# Per-session locks used to serialize append_event calls in this process.
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
@@ -261,62 +258,55 @@ class DatabaseSessionService(BaseSessionService):
DB schema version to use and creates the tables (including setting the
schema version metadata) if needed.
"""
# Check the database schema version and set the _db_schema_version if
# needed
if self._db_schema_version is not None:
return
async with self._db_schema_lock:
# Double-check after acquiring the lock
if self._db_schema_version is not None:
return
try:
async with self.db_engine.connect() as conn:
self._db_schema_version = await conn.run_sync(
_schema_check_utils.get_db_schema_version_from_connection
)
except Exception as e:
logger.error("Failed to inspect database tables: %s", e)
raise
# Check if tables are created and create them if not
# Early return if tables are already created
if self._tables_created:
return
async with self._table_creation_lock:
# Double-check after acquiring the lock
if not self._tables_created:
async with self.db_engine.begin() as conn:
if (
self._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
):
# Uncomment to recreate DB every time
# await conn.run_sync(BaseV1.metadata.drop_all)
logger.debug("Using V1 schema tables...")
await conn.run_sync(BaseV1.metadata.create_all)
else:
# await conn.run_sync(BaseV0.metadata.drop_all)
logger.debug("Using V0 schema tables...")
await conn.run_sync(BaseV0.metadata.create_all)
self._tables_created = True
if self._tables_created:
return
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
async with self._rollback_on_exception_session() as sql_session:
# Check if schema version is set, if not, set it to the latest
# version
stmt = select(StorageMetadata).where(
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
# Check the database schema version and set the _db_schema_version
if self._db_schema_version is None:
try:
async with self.db_engine.connect() as conn:
self._db_schema_version = await conn.run_sync(
_schema_check_utils.get_db_schema_version_from_connection
)
result = await sql_session.execute(stmt)
metadata = result.scalars().first()
if not metadata:
metadata = StorageMetadata(
key=_schema_check_utils.SCHEMA_VERSION_KEY,
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
)
sql_session.add(metadata)
await sql_session.commit()
except Exception as e:
logger.error("Failed to inspect database tables: %s", e)
raise
async with self.db_engine.begin() as conn:
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
# Uncomment to recreate DB every time
# await conn.run_sync(BaseV1.metadata.drop_all)
logger.debug("Using V1 schema tables...")
await conn.run_sync(BaseV1.metadata.create_all)
else:
# await conn.run_sync(BaseV0.metadata.drop_all)
logger.debug("Using V0 schema tables...")
await conn.run_sync(BaseV0.metadata.create_all)
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
async with self._rollback_on_exception_session() as sql_session:
# Check if schema version is set, if not, set it to the latest
# version
stmt = select(StorageMetadata).where(
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
)
result = await sql_session.execute(stmt)
metadata = result.scalars().first()
if not metadata:
metadata = StorageMetadata(
key=_schema_check_utils.SCHEMA_VERSION_KEY,
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
)
sql_session.add(metadata)
await sql_session.commit()
self._tables_created = True
@override
async def create_session(
@@ -1050,3 +1050,106 @@ async def test_service_recovers_after_multiple_failures():
assert session.id == 'recovered'
finally:
await service.close()
@pytest.mark.asyncio
async def test_concurrent_prepare_tables_no_race_condition():
  """Concurrent create_session calls on a fresh service must all succeed.

  Regression test for https://github.com/google/adk-python/issues/4445.
  Several create_session calls are started together against a service whose
  tables have not been created yet. Each of them goes through
  _prepare_tables, so none may proceed before the tables exist.

  The earlier implementation returned early once _db_schema_version was set
  (which happens during schema detection) rather than once _tables_created
  was set, leaving a window in which a second caller could run against a
  database that had no tables yet.
  """
  service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    # Precondition: a fresh service has neither detected a schema version nor
    # created any tables.
    assert not service._tables_created
    assert service._db_schema_version is None

    num_concurrent = 5

    # Each request gets its own app_name so the calls do not collide with an
    # IntegrityError on a shared app_states row; the only thing they contend
    # on is table preparation.
    pending_calls = [
        service.create_session(
            app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
        )
        for i in range(num_concurrent)
    ]
    # return_exceptions=True collects failures as values so that every call
    # can be inspected, instead of aborting on the first one.
    outcomes = await asyncio.gather(*pending_calls, return_exceptions=True)

    # Every call must succeed; no exceptions are allowed.
    for i, result in enumerate(outcomes):
      assert not isinstance(result, BaseException), (
          f'Concurrent create_session #{i} raised {result!r}; tables were'
          ' likely not ready due to the _prepare_tables race condition.'
      )

    # Each session created above must be readable back from the service.
    for i in range(num_concurrent):
      session = await service.get_session(
          app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
      )
      assert session is not None, f'Session sess_{i} not found after creation.'

    assert service._tables_created
  finally:
    await service.close()
@pytest.mark.asyncio
async def test_prepare_tables_serializes_schema_detection_and_creation():
  """A single _prepare_tables call completes detection and table creation.

  Starting from a fresh service, one call to _prepare_tables must leave both
  _db_schema_version and _tables_created set, so that a caller never sees the
  schema version detected while the tables are still missing. A follow-up
  create_session confirms that the tables are really usable.
  """
  service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    # Precondition: nothing has been initialized yet.
    assert not service._tables_created
    assert service._db_schema_version is None

    await service._prepare_tables()

    # Post-condition: one call is enough to set both pieces of state.
    assert service._tables_created
    assert service._db_schema_version is not None

    # Exercise the database for real to show that the tables exist.
    created = await service.create_session(
        app_name='app', user_id='user', session_id='s1'
    )
    assert created is not None
    assert created.id == 's1'
  finally:
    await service.close()
@pytest.mark.asyncio
async def test_prepare_tables_idempotent_after_creation():
  """Repeated _prepare_tables calls are harmless once tables exist.

  The first call creates the tables. The second call must complete without
  raising and must leave _tables_created set, and the service must remain
  usable afterwards.
  """
  service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    # First call performs the actual initialization.
    await service._prepare_tables()
    assert service._tables_created

    # Second call is expected to take the already-created fast path.
    await service._prepare_tables()
    assert service._tables_created

    # The service must still be able to create sessions after both calls.
    created = await service.create_session(
        app_name='app', user_id='user', session_id='s1'
    )
    assert created.id == 's1'
  finally:
    await service.close()