diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 5d987aab..91c22fd2 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import asyncio import copy from datetime import datetime from datetime import timezone @@ -30,20 +31,21 @@ from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql -from sqlalchemy.engine import create_engine -from sqlalchemy.engine import Engine from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session as DatabaseSessionFactory -from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime from sqlalchemy.types import PickleType @@ -417,11 +419,11 @@ class DatabaseSessionService(BaseSessionService): # 2. Create all tables based on schema # 3. Initialize all properties try: - db_engine = create_engine(db_url, **kwargs) + db_engine = create_async_engine(db_url, **kwargs) if db_engine.dialect.name == "sqlite": # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine, "connect", set_sqlite_pragma) + event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma) except Exception as e: if isinstance(e, ArgumentError): @@ -440,18 +442,32 @@ class DatabaseSessionService(BaseSessionService): local_timezone = get_localzone() logger.info("Local timezone: %s", local_timezone) - self.db_engine: Engine = db_engine + self.db_engine: AsyncEngine = db_engine self.metadata: MetaData = MetaData() - self.inspector = inspect(self.db_engine) # DB session factory method - self.database_session_factory: sessionmaker[DatabaseSessionFactory] = ( - sessionmaker(bind=self.db_engine) - ) + self.database_session_factory: async_sessionmaker[ + DatabaseSessionFactory + ] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False) - # Uncomment to recreate DB every time - # Base.metadata.drop_all(self.db_engine) - Base.metadata.create_all(self.db_engine) + # Flag to indicate if tables are created + self._tables_created = False + # Lock to ensure thread-safe table creation + self._table_creation_lock = asyncio.Lock() + + async def _ensure_tables_created(self): + """Ensure database tables are created. This is called lazily.""" + if self._tables_created: + return + + async with self._table_creation_lock: + # Double-check after acquiring the lock + if not self._tables_created: + async with self.db_engine.begin() as conn: + # Uncomment to recreate DB every time + # await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True @override async def create_session( @@ -467,22 +483,25 @@ class DatabaseSessionService(BaseSessionService): # 3. Add the object to the table # 4. Build the session object with generated id # 5. Return the session + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: - with self.database_session_factory() as sql_session: - if session_id and sql_session.get( + if session_id and await sql_session.get( StorageSession, (app_name, user_id, session_id) ): raise AlreadyExistsError( f"Session with id {session_id} already exists." ) # Fetch app and user states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( + StorageUserState, (app_name, user_id) + ) + + # Create state tables if not exist if not storage_app_state: storage_app_state = StorageAppState(app_name=app_name, state={}) sql_session.add(storage_app_state) - storage_user_state = sql_session.get( - StorageUserState, (app_name, user_id) - ) if not storage_user_state: storage_user_state = StorageUserState( app_name=app_name, user_id=user_id, state={} @@ -509,9 +528,9 @@ class DatabaseSessionService(BaseSessionService): state=session_state, ) sql_session.add(storage_session) - sql_session.commit() + await sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.refresh(storage_session) # Merge states for response merged_state = _merge_state( @@ -529,39 +548,39 @@ class DatabaseSessionService(BaseSessionService): session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: + await self._ensure_tables_created() # 1. Get the storage session entry from session table # 2. Get all the events based on session id and filtering config # 3. Convert and return the session - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (app_name, user_id, session_id) ) if storage_session is None: return None - query = sql_session.query(StorageEvent).filter( - StorageEvent.app_name == app_name, - StorageEvent.user_id == user_id, - StorageEvent.session_id == storage_session.id, + stmt = ( + select(StorageEvent) + .filter(StorageEvent.app_name == app_name) + .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) ) if config and config.after_timestamp: after_dt = datetime.fromtimestamp(config.after_timestamp) - query = query.filter(StorageEvent.timestamp >= after_dt) + stmt = stmt.filter(StorageEvent.timestamp >= after_dt) - storage_events = ( - query.order_by(StorageEvent.timestamp.desc()) - .limit( - config.num_recent_events - if config and config.num_recent_events - else None - ) - .all() - ) + stmt = stmt.order_by(StorageEvent.timestamp.desc()) + + if config and config.num_recent_events: + stmt = stmt.limit(config.num_recent_events) + + result = await sql_session.execute(stmt) + storage_events = result.scalars().all() # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) @@ -581,32 +600,33 @@ class DatabaseSessionService(BaseSessionService): async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None ) -> ListSessionsResponse: - with self.database_session_factory() as sql_session: - query = sql_session.query(StorageSession).filter( - StorageSession.app_name == app_name - ) + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: + stmt = select(StorageSession).filter(StorageSession.app_name == app_name) if user_id is not None: - query = query.filter(StorageSession.user_id == user_id) - results = query.all() + stmt = stmt.filter(StorageSession.user_id == user_id) + + result = await sql_session.execute(stmt) + results = result.scalars().all() # Fetch app state from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) + storage_app_state = await sql_session.get(StorageAppState, (app_name)) app_state = storage_app_state.state if storage_app_state else {} # Fetch user state(s) from storage user_states_map = {} if user_id is not None: - storage_user_state = sql_session.get( + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) if storage_user_state: user_states_map[user_id] = storage_user_state.state else: - all_user_states_for_app = ( - sql_session.query(StorageUserState) - .filter(StorageUserState.app_name == app_name) - .all() + user_state_stmt = select(StorageUserState).filter( + StorageUserState.app_name == app_name ) + user_state_result = await sql_session.execute(user_state_stmt) + all_user_states_for_app = user_state_result.scalars().all() for storage_user_state in all_user_states_for_app: user_states_map[storage_user_state.user_id] = storage_user_state.state @@ -622,17 +642,19 @@ class DatabaseSessionService(BaseSessionService): async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: - with self.database_session_factory() as sql_session: + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: stmt = delete(StorageSession).where( StorageSession.app_name == app_name, StorageSession.user_id == user_id, StorageSession.id == session_id, ) - sql_session.execute(stmt) - sql_session.commit() + await sql_session.execute(stmt) + await sql_session.commit() @override async def append_event(self, session: Session, event: Event) -> Event: + await self._ensure_tables_created() if event.partial: return event @@ -642,8 +664,8 @@ class DatabaseSessionService(BaseSessionService): # 1. Check if timestamp is stale # 2. Update session attributes based on event config # 3. Store event to table - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (session.app_name, session.user_id, session.id) ) @@ -657,8 +679,10 @@ class DatabaseSessionService(BaseSessionService): ) # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (session.app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get( + StorageAppState, (session.app_name) + ) + storage_user_state = await sql_session.get( StorageUserState, (session.app_name, session.user_id) ) @@ -680,8 +704,8 @@ class DatabaseSessionService(BaseSessionService): sql_session.add(StorageEvent.from_event(session, event)) - sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.commit() + await sql_session.refresh(storage_session) # Update timestamp with commit time session.last_update_time = storage_session.update_timestamp_tz @@ -691,8 +715,12 @@ class DatabaseSessionService(BaseSessionService): return event -def _merge_state(app_state, user_state, session_state): - # Merge states for response +def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], +) -> dict[str, Any]: + """Merge app, user, and session states into a single state dictionary.""" merged_state = copy.deepcopy(session_state) for key in app_state.keys(): merged_state[State.APP_PREFIX + key] = app_state[key] diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 3501eaa0..7fb91c9d 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -39,7 +39,7 @@ def get_session_service( ): """Creates a session service for testing.""" if service_type == SessionServiceType.DATABASE: - return DatabaseSessionService('sqlite:///:memory:') + return DatabaseSessionService('sqlite+aiosqlite:///:memory:') if service_type == SessionServiceType.SQLITE: return SqliteSessionService(str(tmp_path / 'sqlite.db')) return InMemorySessionService()