fix: Per-session locking and row-level locking in DatabaseSessionService.append_event

This change introduces an in-process `asyncio.Lock` per session to serialize `append_event` calls for the same session ID within a single process. For supported database dialects (MySQL, PostgreSQL, MariaDB), it also uses `SELECT ... FOR UPDATE` to acquire row-level locks on the session, app state, and user state records, preventing race conditions across different processes or database connections. A new test case verifies that concurrent updates to stale session objects correctly merge all state changes.

Closes #1049

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 867752676
This commit is contained in:
George Weale
2026-02-09 13:37:47 -08:00
committed by Copybara-Service
parent 32ee07df01
commit f50847460f
2 changed files with 292 additions and 66 deletions
@@ -22,11 +22,12 @@ import logging
from typing import Any
from typing import AsyncIterator
from typing import Optional
from typing import TypeAlias
from typing import TypeVar
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
@@ -59,6 +60,40 @@ from .state import State
logger = logging.getLogger("google_adk." + __name__)
_SQLITE_DIALECT = "sqlite"
_MARIADB_DIALECT = "mariadb"
_MYSQL_DIALECT = "mysql"
_POSTGRESQL_DIALECT = "postgresql"
# Tuple key order for in-process per-session lock maps:
# (app_name, user_id, session_id).
_SessionLockKey: TypeAlias = tuple[str, str, str]
_StorageStateT = TypeVar(
"_StorageStateT",
StorageAppStateV0,
StorageAppStateV1,
StorageUserStateV0,
StorageUserStateV1,
)
async def _select_required_state(
    *,
    sql_session: DatabaseSessionFactory,
    state_model: type[_StorageStateT],
    predicates: tuple[Any, ...],
    use_row_level_locking: bool,
    missing_message: str,
) -> _StorageStateT:
  """Fetches exactly one state row of `state_model`, optionally row-locked.

  Args:
    sql_session: Active database session used to run the query.
    state_model: Storage state ORM class to select from.
    predicates: Filter expressions identifying the single expected row.
    use_row_level_locking: When True, appends SELECT ... FOR UPDATE so the
      row stays locked until the surrounding transaction ends.
    missing_message: Error text used when no matching row exists.

  Returns:
    The matching state row.

  Raises:
    ValueError: If no row matches `predicates`.
  """
  query = select(state_model).filter(*predicates)
  query = query.with_for_update() if use_row_level_locking else query
  row = (await sql_session.execute(query)).scalars().one_or_none()
  if row is None:
    raise ValueError(missing_message)
  return row
def _set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
@@ -107,16 +142,19 @@ class DatabaseSessionService(BaseSessionService):
try:
engine_kwargs = dict(kwargs)
url = make_url(db_url)
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
if (
url.get_backend_name() == _SQLITE_DIALECT
and url.database == ":memory:"
):
engine_kwargs.setdefault("poolclass", StaticPool)
connect_args = dict(engine_kwargs.get("connect_args", {}))
connect_args.setdefault("check_same_thread", False)
engine_kwargs["connect_args"] = connect_args
elif url.get_backend_name() != "sqlite":
elif url.get_backend_name() != _SQLITE_DIALECT:
engine_kwargs.setdefault("pool_pre_ping", True)
db_engine = create_async_engine(db_url, **engine_kwargs)
if db_engine.dialect.name == "sqlite":
if db_engine.dialect.name == _SQLITE_DIALECT:
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
@@ -152,6 +190,11 @@ class DatabaseSessionService(BaseSessionService):
# Lock to ensure thread-safe schema version check
self._db_schema_lock = asyncio.Lock()
# Per-session locks used to serialize append_event calls in this process.
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
self._session_locks_guard = asyncio.Lock()
def _get_schema_classes(self) -> _SchemaClasses:
return _SchemaClasses(self._db_schema_version)
@@ -172,6 +215,45 @@ class DatabaseSessionService(BaseSessionService):
await sql_session.rollback()
raise
def _supports_row_level_locking(self) -> bool:
  """Returns True when the backend dialect supports SELECT ... FOR UPDATE.

  Only MariaDB, MySQL, and PostgreSQL are treated as supporting row-level
  locks; other dialects (e.g. SQLite) fall back to no row locking.
  """
  row_locking_dialects = (
      _MARIADB_DIALECT,
      _MYSQL_DIALECT,
      _POSTGRESQL_DIALECT,
  )
  return self.db_engine.dialect.name in row_locking_dialects
@asynccontextmanager
async def _with_session_lock(
    self, *, app_name: str, user_id: str, session_id: str
) -> AsyncIterator[None]:
  """Serializes event appends for the same session within this process.

  Lazily creates one `asyncio.Lock` per (app_name, user_id, session_id)
  key and holds it while the caller's body runs. Lock objects are
  reference-counted so they stay alive while any task is using or waiting
  on them, and are removed from the maps once the last waiter finishes,
  keeping the lock dictionaries from growing without bound.

  Yields:
    None, while the per-session lock is held.
  """
  # Use one lock per logical ADK session to prevent concurrent append_event
  # writes from racing in the same process.
  lock_key = (app_name, user_id, session_id)
  # The guard serializes all reads/writes of the shared lock and
  # ref-count maps themselves.
  async with self._session_locks_guard:
    lock = self._session_locks.get(lock_key)
    if lock is None:
      lock = asyncio.Lock()
      self._session_locks[lock_key] = lock
    # Reference counting keeps lock objects alive while they are in use by
    # concurrent tasks and allows cleanup once all waiters complete.
    self._session_lock_ref_count[lock_key] = (
        self._session_lock_ref_count.get(lock_key, 0) + 1
    )
  try:
    async with lock:
      yield
  finally:
    # Runs even if acquiring `lock` was cancelled, so the count the
    # increment above added is always balanced by this decrement.
    async with self._session_locks_guard:
      remaining = self._session_lock_ref_count.get(lock_key, 0) - 1
      # Remove lock bookkeeping after the last waiter exits.
      if remaining <= 0 and not lock.locked():
        self._session_lock_ref_count.pop(lock_key, None)
        self._session_locks.pop(lock_key, None)
      else:
        self._session_lock_ref_count[lock_key] = remaining
async def _prepare_tables(self):
"""Ensure database tables are ready for use.
@@ -291,7 +373,7 @@ class DatabaseSessionService(BaseSessionService):
# Store the session
now = datetime.now(timezone.utc)
is_sqlite = self.db_engine.dialect.name == "sqlite"
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
if is_sqlite:
now = now.replace(tzinfo=None)
@@ -372,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
# Convert storage session to session
events = [e.to_event() for e in reversed(storage_events)]
is_sqlite = self.db_engine.dialect.name == "sqlite"
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
)
@@ -418,7 +500,7 @@ class DatabaseSessionService(BaseSessionService):
user_states_map[storage_user_state.user_id] = storage_user_state.state
sessions = []
is_sqlite = self.db_engine.dialect.name == "sqlite"
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
for storage_session in results:
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
@@ -456,71 +538,109 @@ class DatabaseSessionService(BaseSessionService):
# 2. Update session attributes based on event config
# 3. Store event to table
schema = self._get_schema_classes()
async with self._rollback_on_exception_session() as sql_session:
storage_session = await sql_session.get(
schema.StorageSession, (session.app_name, session.user_id, session.id)
)
# Fetch states from storage
storage_app_state = await sql_session.get(
schema.StorageAppState, (session.app_name)
)
storage_user_state = await sql_session.get(
schema.StorageUserState, (session.app_name, session.user_id)
)
is_sqlite = self.db_engine.dialect.name == "sqlite"
if (
storage_session.get_update_timestamp(is_sqlite)
> session.last_update_time
):
# Reload the session from storage if it has been updated since it was
# loaded.
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
session.state = _merge_state(app_state, user_state, session_state)
stmt = (
select(schema.StorageEvent)
.filter(schema.StorageEvent.app_name == session.app_name)
.filter(schema.StorageEvent.session_id == session.id)
.filter(schema.StorageEvent.user_id == session.user_id)
.order_by(schema.StorageEvent.timestamp.asc())
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
use_row_level_locking = self._supports_row_level_locking()
async with self._with_session_lock(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
):
async with self._rollback_on_exception_session() as sql_session:
storage_session_stmt = (
select(schema.StorageSession)
.filter(schema.StorageSession.app_name == session.app_name)
.filter(schema.StorageSession.user_id == session.user_id)
.filter(schema.StorageSession.id == session.id)
)
result = await sql_session.stream_scalars(stmt)
storage_events = [e async for e in result]
session.events = [e.to_event() for e in storage_events]
if use_row_level_locking:
storage_session_stmt = storage_session_stmt.with_for_update()
storage_session_result = await sql_session.execute(storage_session_stmt)
storage_session = storage_session_result.scalars().one_or_none()
if storage_session is None:
raise ValueError(f"Session {session.id} not found.")
# Extract state delta
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
storage_app_state = await _select_required_state(
sql_session=sql_session,
state_model=schema.StorageAppState,
predicates=(schema.StorageAppState.app_name == session.app_name,),
use_row_level_locking=use_row_level_locking,
missing_message=(
"App state missing for app_name="
f"{session.app_name!r}. Session state tables should be "
"initialized by create_session."
),
)
storage_user_state = await _select_required_state(
sql_session=sql_session,
state_model=schema.StorageUserState,
predicates=(
schema.StorageUserState.app_name == session.app_name,
schema.StorageUserState.user_id == session.user_id,
),
use_row_level_locking=use_row_level_locking,
missing_message=(
"User state missing for app_name="
f"{session.app_name!r}, user_id={session.user_id!r}. "
"Session state tables should be initialized by "
"create_session."
),
)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state_delta = state_deltas["session"]
# Merge state and update storage
if app_state_delta:
storage_app_state.state = storage_app_state.state | app_state_delta
if user_state_delta:
storage_user_state.state = storage_user_state.state | user_state_delta
if session_state_delta:
storage_session.state = storage_session.state | session_state_delta
if is_sqlite:
update_time = datetime.fromtimestamp(
event.timestamp, timezone.utc
).replace(tzinfo=None)
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
sql_session.add(schema.StorageEvent.from_event(session, event))
if (
storage_session.get_update_timestamp(is_sqlite)
> session.last_update_time
):
# Reload the session from storage if it has been updated since it was
# loaded.
app_state = storage_app_state.state
user_state = storage_user_state.state
session_state = storage_session.state
session.state = _merge_state(app_state, user_state, session_state)
await sql_session.commit()
stmt = (
select(schema.StorageEvent)
.filter(schema.StorageEvent.app_name == session.app_name)
.filter(schema.StorageEvent.session_id == session.id)
.filter(schema.StorageEvent.user_id == session.user_id)
.order_by(schema.StorageEvent.timestamp.asc())
)
result = await sql_session.stream_scalars(stmt)
storage_events = [e async for e in result]
session.events = [e.to_event() for e in storage_events]
# Update timestamp with commit time
session.last_update_time = storage_session.get_update_timestamp(is_sqlite)
# Extract state delta
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state_delta = state_deltas["session"]
# Merge state and update storage
if app_state_delta:
storage_app_state.state = storage_app_state.state | app_state_delta
if user_state_delta:
storage_user_state.state = (
storage_user_state.state | user_state_delta
)
if session_state_delta:
storage_session.state = storage_session.state | session_state_delta
if is_sqlite:
update_time = datetime.fromtimestamp(
event.timestamp, timezone.utc
).replace(tzinfo=None)
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
sql_session.add(schema.StorageEvent.from_event(session, event))
await sql_session.commit()
# Update timestamp with commit time
session.last_update_time = storage_session.get_update_timestamp(
is_sqlite
)
# Also update the in-memory session
await super().append_event(session=session, event=event)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from datetime import timezone
import enum
@@ -28,6 +29,7 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.sqlite_session_service import SqliteSessionService
from google.genai import types
import pytest
from sqlalchemy import delete
class SessionServiceType(enum.Enum):
@@ -613,6 +615,110 @@ async def test_append_event_to_stale_session():
]
@pytest.mark.asyncio
async def test_append_event_raises_if_app_state_row_missing():
  """append_event surfaces a ValueError when the app-state row is absent."""
  service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    session = await service.create_session(
        app_name='my_app', user_id='user', session_id='s1'
    )
    stale_event = Event(
        invocation_id='inv1',
        author='user',
        actions=EventActions(state_delta={'k': 'v'}),
    )
    # Simulate missing bookkeeping by deleting the app-state row directly.
    schema = service._get_schema_classes()
    async with service.database_session_factory() as db_session:
      delete_stmt = delete(schema.StorageAppState).where(
          schema.StorageAppState.app_name == session.app_name
      )
      await db_session.execute(delete_stmt)
      await db_session.commit()
    with pytest.raises(ValueError, match='App state missing'):
      await service.append_event(session, stale_event)
  finally:
    await service.close()
@pytest.mark.asyncio
async def test_append_event_raises_if_user_state_row_missing():
  """append_event surfaces a ValueError when the user-state row is absent."""
  service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    session = await service.create_session(
        app_name='my_app', user_id='user', session_id='s1'
    )
    stale_event = Event(
        invocation_id='inv1',
        author='user',
        actions=EventActions(state_delta={'k': 'v'}),
    )
    # Simulate missing bookkeeping by deleting the user-state row directly.
    schema = service._get_schema_classes()
    async with service.database_session_factory() as db_session:
      delete_stmt = delete(schema.StorageUserState).where(
          schema.StorageUserState.app_name == session.app_name,
          schema.StorageUserState.user_id == session.user_id,
      )
      await db_session.execute(delete_stmt)
      await db_session.commit()
    with pytest.raises(ValueError, match='User state missing'):
      await service.append_event(session, stale_event)
  finally:
    await service.close()
@pytest.mark.asyncio
async def test_append_event_concurrent_stale_sessions_preserve_all_state():
  """Concurrent appends from stale session copies must merge every delta."""
  session_service = get_session_service(
      service_type=SessionServiceType.DATABASE
  )
  async with session_service:
    app_name = 'my_app'
    user_id = 'user'
    session = await session_service.create_session(
        app_name=app_name, user_id=user_id
    )
    iteration_count = 8
    for i in range(iteration_count):
      latest_session = await session_service.get_session(
          app_name=app_name, user_id=user_id, session_id=session.id
      )
      # Two independent stale snapshots of the same session race to append.
      stale_copies = [
          latest_session.model_copy(deep=True),
          latest_session.model_copy(deep=True),
      ]
      base_timestamp = latest_session.last_update_time + 10.0
      racing_events = [
          Event(
              invocation_id=f'inv{i}-{suffix}',
              author='user',
              timestamp=base_timestamp + offset,
              actions=EventActions(
                  state_delta={f'sk{i}-{suffix}': f'v{i}-{suffix}'}
              ),
          )
          for suffix, offset in ((1, 1.0), (2, 2.0))
      ]
      await asyncio.gather(
          session_service.append_event(stale_copies[0], racing_events[0]),
          session_service.append_event(stale_copies[1], racing_events[1]),
      )
    session_final = await session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session.id
    )
    # Every state delta written by every racing append must survive.
    for i in range(iteration_count):
      assert session_final.state.get(f'sk{i}-1') == f'v{i}-1'
      assert session_final.state.get(f'sk{i}-2') == f'v{i}-2'
    assert len(session_final.events) == iteration_count * 2
@pytest.mark.asyncio
async def test_get_session_with_config(session_service):
app_name = 'my_app'