You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: race condition in table creation for DatabaseSessionService
Using one lock and checking for tables creation instead of schema version. Closes issue #4445 Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 869808097
This commit is contained in:
committed by
Copybara-Service
parent
186371f01e
commit
fbe9eccd05
@@ -187,9 +187,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# The current database schema version in use, "None" if not yet checked
|
||||
self._db_schema_version: Optional[str] = None
|
||||
|
||||
# Lock to ensure thread-safe schema version check
|
||||
self._db_schema_lock = asyncio.Lock()
|
||||
|
||||
# Per-session locks used to serialize append_event calls in this process.
|
||||
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
|
||||
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
|
||||
@@ -261,62 +258,55 @@ class DatabaseSessionService(BaseSessionService):
|
||||
DB schema version to use and creates the tables (including setting the
|
||||
schema version metadata) if needed.
|
||||
"""
|
||||
# Check the database schema version and set the _db_schema_version if
|
||||
# needed
|
||||
if self._db_schema_version is not None:
|
||||
return
|
||||
|
||||
async with self._db_schema_lock:
|
||||
# Double-check after acquiring the lock
|
||||
if self._db_schema_version is not None:
|
||||
return
|
||||
try:
|
||||
async with self.db_engine.connect() as conn:
|
||||
self._db_schema_version = await conn.run_sync(
|
||||
_schema_check_utils.get_db_schema_version_from_connection
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to inspect database tables: %s", e)
|
||||
raise
|
||||
|
||||
# Check if tables are created and create them if not
|
||||
# Early return if tables are already created
|
||||
if self._tables_created:
|
||||
return
|
||||
|
||||
async with self._table_creation_lock:
|
||||
# Double-check after acquiring the lock
|
||||
if not self._tables_created:
|
||||
async with self.db_engine.begin() as conn:
|
||||
if (
|
||||
self._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
):
|
||||
# Uncomment to recreate DB every time
|
||||
# await conn.run_sync(BaseV1.metadata.drop_all)
|
||||
logger.debug("Using V1 schema tables...")
|
||||
await conn.run_sync(BaseV1.metadata.create_all)
|
||||
else:
|
||||
# await conn.run_sync(BaseV0.metadata.drop_all)
|
||||
logger.debug("Using V0 schema tables...")
|
||||
await conn.run_sync(BaseV0.metadata.create_all)
|
||||
self._tables_created = True
|
||||
if self._tables_created:
|
||||
return
|
||||
|
||||
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
|
||||
async with self._rollback_on_exception_session() as sql_session:
|
||||
# Check if schema version is set, if not, set it to the latest
|
||||
# version
|
||||
stmt = select(StorageMetadata).where(
|
||||
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
|
||||
# Check the database schema version and set the _db_schema_version
|
||||
if self._db_schema_version is None:
|
||||
try:
|
||||
async with self.db_engine.connect() as conn:
|
||||
self._db_schema_version = await conn.run_sync(
|
||||
_schema_check_utils.get_db_schema_version_from_connection
|
||||
)
|
||||
result = await sql_session.execute(stmt)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
metadata = StorageMetadata(
|
||||
key=_schema_check_utils.SCHEMA_VERSION_KEY,
|
||||
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
|
||||
)
|
||||
sql_session.add(metadata)
|
||||
await sql_session.commit()
|
||||
except Exception as e:
|
||||
logger.error("Failed to inspect database tables: %s", e)
|
||||
raise
|
||||
|
||||
async with self.db_engine.begin() as conn:
|
||||
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
|
||||
# Uncomment to recreate DB every time
|
||||
# await conn.run_sync(BaseV1.metadata.drop_all)
|
||||
logger.debug("Using V1 schema tables...")
|
||||
await conn.run_sync(BaseV1.metadata.create_all)
|
||||
else:
|
||||
# await conn.run_sync(BaseV0.metadata.drop_all)
|
||||
logger.debug("Using V0 schema tables...")
|
||||
await conn.run_sync(BaseV0.metadata.create_all)
|
||||
|
||||
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
|
||||
async with self._rollback_on_exception_session() as sql_session:
|
||||
# Check if schema version is set, if not, set it to the latest
|
||||
# version
|
||||
stmt = select(StorageMetadata).where(
|
||||
StorageMetadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
|
||||
)
|
||||
result = await sql_session.execute(stmt)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
metadata = StorageMetadata(
|
||||
key=_schema_check_utils.SCHEMA_VERSION_KEY,
|
||||
value=_schema_check_utils.LATEST_SCHEMA_VERSION,
|
||||
)
|
||||
sql_session.add(metadata)
|
||||
await sql_session.commit()
|
||||
|
||||
self._tables_created = True
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
|
||||
@@ -1050,3 +1050,106 @@ async def test_service_recovers_after_multiple_failures():
|
||||
assert session.id == 'recovered'
|
||||
finally:
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_prepare_tables_no_race_condition():
|
||||
"""Verifies that concurrent calls to _prepare_tables wait for table creation.
|
||||
Reproduces the race condition from
|
||||
https://github.com/google/adk-python/issues/4445: when concurrent requests
|
||||
arrive at startup, _prepare_tables must not return before tables exist.
|
||||
Previously, the early-return guard checked _db_schema_version (set during
|
||||
schema detection) instead of _tables_created, so a second request could
|
||||
slip through after schema detection but before table creation finished.
|
||||
"""
|
||||
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
|
||||
try:
|
||||
# Tables haven't been created yet.
|
||||
assert not service._tables_created
|
||||
assert service._db_schema_version is None
|
||||
|
||||
# Launch several concurrent create_session calls, each with a unique
|
||||
# app_name to avoid IntegrityError on the shared app_states row.
|
||||
# Each will call _prepare_tables internally. If the race condition
|
||||
# exists, some of these will fail because the "sessions" table doesn't
|
||||
# exist yet.
|
||||
num_concurrent = 5
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
service.create_session(
|
||||
app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
|
||||
)
|
||||
for i in range(num_concurrent)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Every call must succeed – no exceptions allowed.
|
||||
for i, result in enumerate(results):
|
||||
assert not isinstance(result, BaseException), (
|
||||
f'Concurrent create_session #{i} raised {result!r}; tables were'
|
||||
' likely not ready due to the _prepare_tables race condition.'
|
||||
)
|
||||
|
||||
# All sessions should be retrievable.
|
||||
for i in range(num_concurrent):
|
||||
session = await service.get_session(
|
||||
app_name=f'app_{i}', user_id='user', session_id=f'sess_{i}'
|
||||
)
|
||||
assert session is not None, f'Session sess_{i} not found after creation.'
|
||||
|
||||
assert service._tables_created
|
||||
finally:
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_tables_serializes_schema_detection_and_creation():
|
||||
"""Verifies schema detection and table creation happen atomically under one
|
||||
lock, so concurrent callers cannot observe a partially-initialized state.
|
||||
After _prepare_tables completes, both _db_schema_version and _tables_created
|
||||
must be set.
|
||||
"""
|
||||
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
|
||||
try:
|
||||
assert not service._tables_created
|
||||
assert service._db_schema_version is None
|
||||
|
||||
await service._prepare_tables()
|
||||
|
||||
# Both must be set after a single _prepare_tables call.
|
||||
assert service._tables_created
|
||||
assert service._db_schema_version is not None
|
||||
|
||||
# Verify tables actually exist by performing a real operation.
|
||||
session = await service.create_session(
|
||||
app_name='app', user_id='user', session_id='s1'
|
||||
)
|
||||
assert session is not None
|
||||
assert session.id == 's1'
|
||||
finally:
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_tables_idempotent_after_creation():
|
||||
"""Calling _prepare_tables multiple times is safe and idempotent.
|
||||
After tables are created, subsequent calls should return immediately via
|
||||
the fast path without errors.
|
||||
"""
|
||||
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
|
||||
try:
|
||||
await service._prepare_tables()
|
||||
assert service._tables_created
|
||||
|
||||
# Call again — should be a no-op via the fast path.
|
||||
await service._prepare_tables()
|
||||
assert service._tables_created
|
||||
|
||||
# Service should still work.
|
||||
session = await service.create_session(
|
||||
app_name='app', user_id='user', session_id='s1'
|
||||
)
|
||||
assert session.id == 's1'
|
||||
finally:
|
||||
await service.close()
|
||||
|
||||
Reference in New Issue
Block a user