You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Per-session locking and row-level locking in DatabaseSessionService.append_event
This change introduces an in-process `asyncio.Lock` per session to serialize `append_event` calls for the same session ID within a single process. For supported database dialects (MySQL, PostgreSQL, MariaDB), it also uses `SELECT ... FOR UPDATE` to acquire row-level locks on the session, app state, and user state records, preventing race conditions across different processes or database connections. A new test case verifies that concurrent updates to stale session objects correctly merge all state changes. Close #1049 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 867752676
This commit is contained in:
committed by
Copybara-Service
parent
32ee07df01
commit
f50847460f
@@ -22,11 +22,12 @@ import logging
|
||||
from typing import Any
|
||||
from typing import AsyncIterator
|
||||
from typing import Optional
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
@@ -59,6 +60,40 @@ from .state import State
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
_SQLITE_DIALECT = "sqlite"
|
||||
_MARIADB_DIALECT = "mariadb"
|
||||
_MYSQL_DIALECT = "mysql"
|
||||
_POSTGRESQL_DIALECT = "postgresql"
|
||||
# Tuple key order for in-process per-session lock maps:
|
||||
# (app_name, user_id, session_id).
|
||||
_SessionLockKey: TypeAlias = tuple[str, str, str]
|
||||
_StorageStateT = TypeVar(
|
||||
"_StorageStateT",
|
||||
StorageAppStateV0,
|
||||
StorageAppStateV1,
|
||||
StorageUserStateV0,
|
||||
StorageUserStateV1,
|
||||
)
|
||||
|
||||
|
||||
async def _select_required_state(
    *,
    sql_session: DatabaseSessionFactory,
    state_model: type[_StorageStateT],
    predicates: tuple[Any, ...],
    use_row_level_locking: bool,
    missing_message: str,
) -> _StorageStateT:
  """Loads the state row matching `predicates`, failing if there is none.

  Args:
    sql_session: Object whose awaitable `execute` runs the SELECT.
    state_model: Storage model class to select from.
    predicates: Filter expressions applied to the SELECT.
    use_row_level_locking: When true, the statement is issued through
      `with_for_update()` so the selected row is locked for the enclosing
      transaction.
    missing_message: Text of the `ValueError` raised when no row matches.

  Returns:
    The matching state row.

  Raises:
    ValueError: If no row matches `predicates`.
  """
  query = select(state_model).filter(*predicates)
  query = query.with_for_update() if use_row_level_locking else query
  matches = (await sql_session.execute(query)).scalars()
  found = matches.one_or_none()
  if found is not None:
    return found
  raise ValueError(missing_message)
|
||||
|
||||
|
||||
def _set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
@@ -107,16 +142,19 @@ class DatabaseSessionService(BaseSessionService):
|
||||
try:
|
||||
engine_kwargs = dict(kwargs)
|
||||
url = make_url(db_url)
|
||||
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
|
||||
if (
|
||||
url.get_backend_name() == _SQLITE_DIALECT
|
||||
and url.database == ":memory:"
|
||||
):
|
||||
engine_kwargs.setdefault("poolclass", StaticPool)
|
||||
connect_args = dict(engine_kwargs.get("connect_args", {}))
|
||||
connect_args.setdefault("check_same_thread", False)
|
||||
engine_kwargs["connect_args"] = connect_args
|
||||
elif url.get_backend_name() != "sqlite":
|
||||
elif url.get_backend_name() != _SQLITE_DIALECT:
|
||||
engine_kwargs.setdefault("pool_pre_ping", True)
|
||||
|
||||
db_engine = create_async_engine(db_url, **engine_kwargs)
|
||||
if db_engine.dialect.name == "sqlite":
|
||||
if db_engine.dialect.name == _SQLITE_DIALECT:
|
||||
# Set sqlite pragma to enable foreign keys constraints
|
||||
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
|
||||
|
||||
@@ -152,6 +190,11 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# Lock to ensure thread-safe schema version check
|
||||
self._db_schema_lock = asyncio.Lock()
|
||||
|
||||
# Per-session locks used to serialize append_event calls in this process.
|
||||
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
|
||||
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
|
||||
self._session_locks_guard = asyncio.Lock()
|
||||
|
||||
def _get_schema_classes(self) -> _SchemaClasses:
  """Returns a `_SchemaClasses` built from the current `_db_schema_version`."""
  schema_version = self._db_schema_version
  return _SchemaClasses(schema_version)
|
||||
|
||||
@@ -172,6 +215,45 @@ class DatabaseSessionService(BaseSessionService):
|
||||
await sql_session.rollback()
|
||||
raise
|
||||
|
||||
def _supports_row_level_locking(self) -> bool:
  """Tells whether the engine's dialect is one we lock rows on.

  Returns:
    True when the engine dialect name is MariaDB, MySQL or PostgreSQL;
    False for every other dialect.
  """
  dialect_name = self.db_engine.dialect.name
  lockable_dialects = (
      _MARIADB_DIALECT,
      _MYSQL_DIALECT,
      _POSTGRESQL_DIALECT,
  )
  return dialect_name in lockable_dialects
|
||||
|
||||
@asynccontextmanager
async def _with_session_lock(
    self, *, app_name: str, user_id: str, session_id: str
) -> AsyncIterator[None]:
  """Holds the in-process lock of one session for the duration of the block.

  Every (app_name, user_id, session_id) triple gets its own `asyncio.Lock`,
  created on first use. A per-key reference count tracks how many tasks are
  holding or waiting on that lock, so the bookkeeping entries can be dropped
  once the last task leaves.

  Args:
    app_name: Application name part of the lock key.
    user_id: User id part of the lock key.
    session_id: Session id part of the lock key.

  Yields:
    None, while the per-session lock is held.
  """
  key = (app_name, user_id, session_id)

  async with self._session_locks_guard:
    try:
      session_lock = self._session_locks[key]
    except KeyError:
      session_lock = asyncio.Lock()
      self._session_locks[key] = session_lock
    # Register interest before waiting on the lock so that a task leaving
    # concurrently does not discard a lock object we are about to use.
    previous_count = self._session_lock_ref_count.get(key, 0)
    self._session_lock_ref_count[key] = previous_count + 1

  try:
    async with session_lock:
      yield
  finally:
    async with self._session_locks_guard:
      outstanding = self._session_lock_ref_count.get(key, 0) - 1
      if outstanding > 0 or session_lock.locked():
        # Someone else still holds or waits on this lock; keep the entry.
        self._session_lock_ref_count[key] = outstanding
      else:
        # Last task out: forget both the counter and the lock object.
        self._session_lock_ref_count.pop(key, None)
        self._session_locks.pop(key, None)
|
||||
|
||||
async def _prepare_tables(self):
|
||||
"""Ensure database tables are ready for use.
|
||||
|
||||
@@ -291,7 +373,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
|
||||
# Store the session
|
||||
now = datetime.now(timezone.utc)
|
||||
is_sqlite = self.db_engine.dialect.name == "sqlite"
|
||||
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
|
||||
if is_sqlite:
|
||||
now = now.replace(tzinfo=None)
|
||||
|
||||
@@ -372,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
|
||||
# Convert storage session to session
|
||||
events = [e.to_event() for e in reversed(storage_events)]
|
||||
is_sqlite = self.db_engine.dialect.name == "sqlite"
|
||||
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
|
||||
session = storage_session.to_session(
|
||||
state=merged_state, events=events, is_sqlite=is_sqlite
|
||||
)
|
||||
@@ -418,7 +500,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
user_states_map[storage_user_state.user_id] = storage_user_state.state
|
||||
|
||||
sessions = []
|
||||
is_sqlite = self.db_engine.dialect.name == "sqlite"
|
||||
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
|
||||
for storage_session in results:
|
||||
session_state = storage_session.state
|
||||
user_state = user_states_map.get(storage_session.user_id, {})
|
||||
@@ -456,71 +538,109 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 2. Update session attributes based on event config
|
||||
# 3. Store event to table
|
||||
schema = self._get_schema_classes()
|
||||
async with self._rollback_on_exception_session() as sql_session:
|
||||
storage_session = await sql_session.get(
|
||||
schema.StorageSession, (session.app_name, session.user_id, session.id)
|
||||
)
|
||||
|
||||
# Fetch states from storage
|
||||
storage_app_state = await sql_session.get(
|
||||
schema.StorageAppState, (session.app_name)
|
||||
)
|
||||
storage_user_state = await sql_session.get(
|
||||
schema.StorageUserState, (session.app_name, session.user_id)
|
||||
)
|
||||
|
||||
is_sqlite = self.db_engine.dialect.name == "sqlite"
|
||||
if (
|
||||
storage_session.get_update_timestamp(is_sqlite)
|
||||
> session.last_update_time
|
||||
):
|
||||
# Reload the session from storage if it has been updated since it was
|
||||
# loaded.
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
session_state = storage_session.state
|
||||
session.state = _merge_state(app_state, user_state, session_state)
|
||||
|
||||
stmt = (
|
||||
select(schema.StorageEvent)
|
||||
.filter(schema.StorageEvent.app_name == session.app_name)
|
||||
.filter(schema.StorageEvent.session_id == session.id)
|
||||
.filter(schema.StorageEvent.user_id == session.user_id)
|
||||
.order_by(schema.StorageEvent.timestamp.asc())
|
||||
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
|
||||
use_row_level_locking = self._supports_row_level_locking()
|
||||
async with self._with_session_lock(
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
):
|
||||
async with self._rollback_on_exception_session() as sql_session:
|
||||
storage_session_stmt = (
|
||||
select(schema.StorageSession)
|
||||
.filter(schema.StorageSession.app_name == session.app_name)
|
||||
.filter(schema.StorageSession.user_id == session.user_id)
|
||||
.filter(schema.StorageSession.id == session.id)
|
||||
)
|
||||
result = await sql_session.stream_scalars(stmt)
|
||||
storage_events = [e async for e in result]
|
||||
session.events = [e.to_event() for e in storage_events]
|
||||
if use_row_level_locking:
|
||||
storage_session_stmt = storage_session_stmt.with_for_update()
|
||||
storage_session_result = await sql_session.execute(storage_session_stmt)
|
||||
storage_session = storage_session_result.scalars().one_or_none()
|
||||
if storage_session is None:
|
||||
raise ValueError(f"Session {session.id} not found.")
|
||||
|
||||
# Extract state delta
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(
|
||||
event.actions.state_delta
|
||||
storage_app_state = await _select_required_state(
|
||||
sql_session=sql_session,
|
||||
state_model=schema.StorageAppState,
|
||||
predicates=(schema.StorageAppState.app_name == session.app_name,),
|
||||
use_row_level_locking=use_row_level_locking,
|
||||
missing_message=(
|
||||
"App state missing for app_name="
|
||||
f"{session.app_name!r}. Session state tables should be "
|
||||
"initialized by create_session."
|
||||
),
|
||||
)
|
||||
storage_user_state = await _select_required_state(
|
||||
sql_session=sql_session,
|
||||
state_model=schema.StorageUserState,
|
||||
predicates=(
|
||||
schema.StorageUserState.app_name == session.app_name,
|
||||
schema.StorageUserState.user_id == session.user_id,
|
||||
),
|
||||
use_row_level_locking=use_row_level_locking,
|
||||
missing_message=(
|
||||
"User state missing for app_name="
|
||||
f"{session.app_name!r}, user_id={session.user_id!r}. "
|
||||
"Session state tables should be initialized by "
|
||||
"create_session."
|
||||
),
|
||||
)
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state_delta = state_deltas["session"]
|
||||
# Merge state and update storage
|
||||
if app_state_delta:
|
||||
storage_app_state.state = storage_app_state.state | app_state_delta
|
||||
if user_state_delta:
|
||||
storage_user_state.state = storage_user_state.state | user_state_delta
|
||||
if session_state_delta:
|
||||
storage_session.state = storage_session.state | session_state_delta
|
||||
|
||||
if is_sqlite:
|
||||
update_time = datetime.fromtimestamp(
|
||||
event.timestamp, timezone.utc
|
||||
).replace(tzinfo=None)
|
||||
else:
|
||||
update_time = datetime.fromtimestamp(event.timestamp)
|
||||
storage_session.update_time = update_time
|
||||
sql_session.add(schema.StorageEvent.from_event(session, event))
|
||||
if (
|
||||
storage_session.get_update_timestamp(is_sqlite)
|
||||
> session.last_update_time
|
||||
):
|
||||
# Reload the session from storage if it has been updated since it was
|
||||
# loaded.
|
||||
app_state = storage_app_state.state
|
||||
user_state = storage_user_state.state
|
||||
session_state = storage_session.state
|
||||
session.state = _merge_state(app_state, user_state, session_state)
|
||||
|
||||
await sql_session.commit()
|
||||
stmt = (
|
||||
select(schema.StorageEvent)
|
||||
.filter(schema.StorageEvent.app_name == session.app_name)
|
||||
.filter(schema.StorageEvent.session_id == session.id)
|
||||
.filter(schema.StorageEvent.user_id == session.user_id)
|
||||
.order_by(schema.StorageEvent.timestamp.asc())
|
||||
)
|
||||
result = await sql_session.stream_scalars(stmt)
|
||||
storage_events = [e async for e in result]
|
||||
session.events = [e.to_event() for e in storage_events]
|
||||
|
||||
# Update timestamp with commit time
|
||||
session.last_update_time = storage_session.get_update_timestamp(is_sqlite)
|
||||
# Extract state delta
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(
|
||||
event.actions.state_delta
|
||||
)
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state_delta = state_deltas["session"]
|
||||
# Merge state and update storage
|
||||
if app_state_delta:
|
||||
storage_app_state.state = storage_app_state.state | app_state_delta
|
||||
if user_state_delta:
|
||||
storage_user_state.state = (
|
||||
storage_user_state.state | user_state_delta
|
||||
)
|
||||
if session_state_delta:
|
||||
storage_session.state = storage_session.state | session_state_delta
|
||||
|
||||
if is_sqlite:
|
||||
update_time = datetime.fromtimestamp(
|
||||
event.timestamp, timezone.utc
|
||||
).replace(tzinfo=None)
|
||||
else:
|
||||
update_time = datetime.fromtimestamp(event.timestamp)
|
||||
storage_session.update_time = update_time
|
||||
sql_session.add(schema.StorageEvent.from_event(session, event))
|
||||
|
||||
await sql_session.commit()
|
||||
|
||||
# Update timestamp with commit time
|
||||
session.last_update_time = storage_session.get_update_timestamp(
|
||||
is_sqlite
|
||||
)
|
||||
|
||||
# Also update the in-memory session
|
||||
await super().append_event(session=session, event=event)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import enum
|
||||
@@ -28,6 +29,7 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.sqlite_session_service import SqliteSessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
|
||||
class SessionServiceType(enum.Enum):
|
||||
@@ -613,6 +615,110 @@ async def test_append_event_to_stale_session():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_append_event_raises_if_app_state_row_missing():
  """append_event raises ValueError once the app state row has been deleted."""
  db_service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    created = await db_service.create_session(
        app_name='my_app', user_id='user', session_id='s1'
    )
    app_state_model = db_service._get_schema_classes().StorageAppState
    # Delete the app state row directly in storage, bypassing the service API.
    drop_app_state = delete(app_state_model).where(
        app_state_model.app_name == created.app_name
    )
    async with db_service.database_session_factory() as db:
      await db.execute(drop_app_state)
      await db.commit()

    state_event = Event(
        invocation_id='inv1',
        author='user',
        actions=EventActions(state_delta={'k': 'v'}),
    )
    with pytest.raises(ValueError, match='App state missing'):
      await db_service.append_event(created, state_event)
  finally:
    await db_service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_append_event_raises_if_user_state_row_missing():
  """append_event raises ValueError once the user state row has been deleted."""
  db_service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
  try:
    created = await db_service.create_session(
        app_name='my_app', user_id='user', session_id='s1'
    )
    user_state_model = db_service._get_schema_classes().StorageUserState
    # Delete the user state row directly in storage, bypassing the service API.
    drop_user_state = delete(user_state_model).where(
        user_state_model.app_name == created.app_name,
        user_state_model.user_id == created.user_id,
    )
    async with db_service.database_session_factory() as db:
      await db.execute(drop_user_state)
      await db.commit()

    state_event = Event(
        invocation_id='inv1',
        author='user',
        actions=EventActions(state_delta={'k': 'v'}),
    )
    with pytest.raises(ValueError, match='User state missing'):
      await db_service.append_event(created, state_event)
  finally:
    await db_service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_append_event_concurrent_stale_sessions_preserve_all_state():
  """Concurrent appends through stale session copies keep every state key.

  Each round fetches the session, makes two deep copies of it, and appends
  one event per copy concurrently. Afterwards every state delta and every
  event must be present on the freshly fetched session.
  """
  svc = get_session_service(service_type=SessionServiceType.DATABASE)

  async with svc:
    app_name = 'my_app'
    user_id = 'user'
    created = await svc.create_session(app_name=app_name, user_id=user_id)

    rounds = 8
    writers = (1, 2)
    for i in range(rounds):
      fresh = await svc.get_session(
          app_name=app_name, user_id=user_id, session_id=created.id
      )
      # Two independent snapshots of the same stored session.
      stale_copies = [fresh.model_copy(deep=True) for _ in writers]
      base_timestamp = fresh.last_update_time + 10.0
      pending_events = [
          Event(
              invocation_id=f'inv{i}-{n}',
              author='user',
              timestamp=base_timestamp + float(n),
              actions=EventActions(state_delta={f'sk{i}-{n}': f'v{i}-{n}'}),
          )
          for n in writers
      ]

      await asyncio.gather(*(
          svc.append_event(stale, pending)
          for stale, pending in zip(stale_copies, pending_events)
      ))

    session_final = await svc.get_session(
        app_name=app_name, user_id=user_id, session_id=created.id
    )

    for i in range(rounds):
      assert session_final.state.get(f'sk{i}-1') == f'v{i}-1'
      assert session_final.state.get(f'sk{i}-2') == f'v{i}-2'
    assert len(session_final.events) == rounds * 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_with_config(session_service):
|
||||
app_name = 'my_app'
|
||||
|
||||
Reference in New Issue
Block a user