From 74959414d8ded733d584875a49fb4638a12d3ce5 Mon Sep 17 00:00:00 2001 From: GitMarco27 Date: Mon, 10 Nov 2025 11:17:25 -0800 Subject: [PATCH] feat: full async implementation of DatabaseSessionService Merge https://github.com/google/adk-python/pull/2889 # Implement Full async DatabaseSessionService **Target Issue:** #1005 ## Overview This PR introduces an asynchronous implementation of the `DatabaseSessionService` with minimal breaking changes. The primary goal is to enable effective use of ADK in fully async environments and API endpoints while avoiding event loop blocking during database I/O operations. ## Changes - Converted `DatabaseSessionService` to use async/await patterns throughout ## Testing Plan The implementation has been tested following the project's contribution guidelines: ### Unit Tests - All existing unit tests pass successfully - Minor update to test requirements added to support `aiosqlite` ### Manual End-to-End Testing - E2E tests performed using: - **LLM Provider:** LiteLLM - **Database:** PostgreSQL with `asyncpg` driver ```python from google.adk.sessions.database_session_service import DatabaseSessionService connection_string: str = ( "postgresql+asyncpg://PG_USER:PG_PSWD@PG_HOST:5432/PG_DB" ) session_service: DatabaseSessionService = DatabaseSessionService( db_url=connection_string ) session = await session_service.create_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None sessions = await session_service.list_sessions(app_name="test_app", user_id="test_user") assert len(sessions.sessions) > 0 session = await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None await session_service.delete_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert ( await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) is None ) ``` The implementation have been also tested using the following configurations for llm provider and Runner: ```python def get_azure_openai_model(deployment_id: str | None = None) -> LiteLlm: ... if not deployment_id: deployment_id = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID") logger.info(f"Using Azure OpenAI deployment ID: {deployment_id}") return LiteLlm( model=f"azure/{os.getenv('AZURE_OPENAI_DEPLOYMENT_ID')}", stream=True, ) ... @staticmethod def _get_runner(agent: Agent) -> Runner: storage=DatabaseSessionService(db_url=get_pg_connection_string()) return Runner( agent=agent, app_name=APP_NAME, session_service=storage, ) ... async for event in self.runner.run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=( RunConfig( streaming_mode=StreamingMode.SSE, response_modalities=["TEXT"] ) if stream else RunConfig() ), ): last_event = event if stream: yield event ... ``` ## Breaking Changes - Database connection string format may need updates for async drivers Co-authored-by: Shangjie Chen COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2889 from GitMarco27:feature/async_database_session_service e1b1b14934c1fb7975a6832cdd1549e94acab985 PiperOrigin-RevId: 830525148 --- .../adk/sessions/database_session_service.py | 154 +++++++++++------- .../sessions/test_session_service.py | 2 +- 2 files changed, 92 insertions(+), 64 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 5d987aab..91c22fd2 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import asyncio import copy from datetime import datetime from datetime import timezone @@ -30,20 +31,21 @@ from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql -from sqlalchemy.engine import create_engine -from sqlalchemy.engine import Engine from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session as DatabaseSessionFactory -from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime from sqlalchemy.types import PickleType @@ -417,11 +419,11 @@ class DatabaseSessionService(BaseSessionService): # 2. Create all tables based on schema # 3. Initialize all properties try: - db_engine = create_engine(db_url, **kwargs) + db_engine = create_async_engine(db_url, **kwargs) if db_engine.dialect.name == "sqlite": # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine, "connect", set_sqlite_pragma) + event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma) except Exception as e: if isinstance(e, ArgumentError): @@ -440,18 +442,32 @@ class DatabaseSessionService(BaseSessionService): local_timezone = get_localzone() logger.info("Local timezone: %s", local_timezone) - self.db_engine: Engine = db_engine + self.db_engine: AsyncEngine = db_engine self.metadata: MetaData = MetaData() - self.inspector = inspect(self.db_engine) # DB session factory method - self.database_session_factory: sessionmaker[DatabaseSessionFactory] = ( - sessionmaker(bind=self.db_engine) - ) + self.database_session_factory: async_sessionmaker[ + DatabaseSessionFactory + ] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False) - # Uncomment to recreate DB every time - # Base.metadata.drop_all(self.db_engine) - Base.metadata.create_all(self.db_engine) + # Flag to indicate if tables are created + self._tables_created = False + # Lock to ensure thread-safe table creation + self._table_creation_lock = asyncio.Lock() + + async def _ensure_tables_created(self): + """Ensure database tables are created. This is called lazily.""" + if self._tables_created: + return + + async with self._table_creation_lock: + # Double-check after acquiring the lock + if not self._tables_created: + async with self.db_engine.begin() as conn: + # Uncomment to recreate DB every time + # await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True @override async def create_session( @@ -467,22 +483,25 @@ class DatabaseSessionService(BaseSessionService): # 3. Add the object to the table # 4. Build the session object with generated id # 5. Return the session + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: - with self.database_session_factory() as sql_session: - if session_id and sql_session.get( + if session_id and await sql_session.get( StorageSession, (app_name, user_id, session_id) ): raise AlreadyExistsError( f"Session with id {session_id} already exists." ) # Fetch app and user states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( + StorageUserState, (app_name, user_id) + ) + + # Create state tables if not exist if not storage_app_state: storage_app_state = StorageAppState(app_name=app_name, state={}) sql_session.add(storage_app_state) - storage_user_state = sql_session.get( - StorageUserState, (app_name, user_id) - ) if not storage_user_state: storage_user_state = StorageUserState( app_name=app_name, user_id=user_id, state={} @@ -509,9 +528,9 @@ class DatabaseSessionService(BaseSessionService): state=session_state, ) sql_session.add(storage_session) - sql_session.commit() + await sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.refresh(storage_session) # Merge states for response merged_state = _merge_state( @@ -529,39 +548,39 @@ class DatabaseSessionService(BaseSessionService): session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: + await self._ensure_tables_created() # 1. Get the storage session entry from session table # 2. Get all the events based on session id and filtering config # 3. Convert and return the session - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (app_name, user_id, session_id) ) if storage_session is None: return None - query = sql_session.query(StorageEvent).filter( - StorageEvent.app_name == app_name, - StorageEvent.user_id == user_id, - StorageEvent.session_id == storage_session.id, + stmt = ( + select(StorageEvent) + .filter(StorageEvent.app_name == app_name) + .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) ) if config and config.after_timestamp: after_dt = datetime.fromtimestamp(config.after_timestamp) - query = query.filter(StorageEvent.timestamp >= after_dt) + stmt = stmt.filter(StorageEvent.timestamp >= after_dt) - storage_events = ( - query.order_by(StorageEvent.timestamp.desc()) - .limit( - config.num_recent_events - if config and config.num_recent_events - else None - ) - .all() - ) + stmt = stmt.order_by(StorageEvent.timestamp.desc()) + + if config and config.num_recent_events: + stmt = stmt.limit(config.num_recent_events) + + result = await sql_session.execute(stmt) + storage_events = result.scalars().all() # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) @@ -581,32 +600,33 @@ class DatabaseSessionService(BaseSessionService): async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None ) -> ListSessionsResponse: - with self.database_session_factory() as sql_session: - query = sql_session.query(StorageSession).filter( - StorageSession.app_name == app_name - ) + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: + stmt = select(StorageSession).filter(StorageSession.app_name == app_name) if user_id is not None: - query = query.filter(StorageSession.user_id == user_id) - results = query.all() + stmt = stmt.filter(StorageSession.user_id == user_id) + + result = await sql_session.execute(stmt) + results = result.scalars().all() # Fetch app state from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) + storage_app_state = await sql_session.get(StorageAppState, (app_name)) app_state = storage_app_state.state if storage_app_state else {} # Fetch user state(s) from storage user_states_map = {} if user_id is not None: - storage_user_state = sql_session.get( + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) if storage_user_state: user_states_map[user_id] = storage_user_state.state else: - all_user_states_for_app = ( - sql_session.query(StorageUserState) - .filter(StorageUserState.app_name == app_name) - .all() + user_state_stmt = select(StorageUserState).filter( + StorageUserState.app_name == app_name ) + user_state_result = await sql_session.execute(user_state_stmt) + all_user_states_for_app = user_state_result.scalars().all() for storage_user_state in all_user_states_for_app: user_states_map[storage_user_state.user_id] = storage_user_state.state @@ -622,17 +642,19 @@ class DatabaseSessionService(BaseSessionService): async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: - with self.database_session_factory() as sql_session: + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: stmt = delete(StorageSession).where( StorageSession.app_name == app_name, StorageSession.user_id == user_id, StorageSession.id == session_id, ) - sql_session.execute(stmt) - sql_session.commit() + await sql_session.execute(stmt) + await sql_session.commit() @override async def append_event(self, session: Session, event: Event) -> Event: + await self._ensure_tables_created() if event.partial: return event @@ -642,8 +664,8 @@ class DatabaseSessionService(BaseSessionService): # 1. Check if timestamp is stale # 2. Update session attributes based on event config # 3. Store event to table - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (session.app_name, session.user_id, session.id) ) @@ -657,8 +679,10 @@ class DatabaseSessionService(BaseSessionService): ) # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (session.app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get( + StorageAppState, (session.app_name) + ) + storage_user_state = await sql_session.get( StorageUserState, (session.app_name, session.user_id) ) @@ -680,8 +704,8 @@ class DatabaseSessionService(BaseSessionService): sql_session.add(StorageEvent.from_event(session, event)) - sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.commit() + await sql_session.refresh(storage_session) # Update timestamp with commit time session.last_update_time = storage_session.update_timestamp_tz @@ -691,8 +715,12 @@ class DatabaseSessionService(BaseSessionService): return event -def _merge_state(app_state, user_state, session_state): - # Merge states for response +def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], +) -> dict[str, Any]: + """Merge app, user, and session states into a single state dictionary.""" merged_state = copy.deepcopy(session_state) for key in app_state.keys(): merged_state[State.APP_PREFIX + key] = app_state[key] diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 3501eaa0..7fb91c9d 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -39,7 +39,7 @@ def get_session_service( ): """Creates a session service for testing.""" if service_type == SessionServiceType.DATABASE: - return DatabaseSessionService('sqlite:///:memory:') + return DatabaseSessionService('sqlite+aiosqlite:///:memory:') if service_type == SessionServiceType.SQLITE: return SqliteSessionService(str(tmp_path / 'sqlite.db')) return InMemorySessionService()