feat: full async implementation of DatabaseSessionService

Merge https://github.com/google/adk-python/pull/2889

# Implement Full async DatabaseSessionService

**Target Issue:** #1005

## Overview

This PR introduces an asynchronous implementation of the `DatabaseSessionService` with minimal breaking changes. The primary goal is to enable effective use of ADK in fully async environments and API endpoints while avoiding event loop blocking during database I/O operations.

## Changes

- Converted `DatabaseSessionService` to use async/await patterns throughout

## Testing Plan

The implementation has been tested following the project's contribution guidelines:

### Unit Tests
- All existing unit tests pass successfully
- Minor update to test requirements added to support `aiosqlite`

### Manual End-to-End Testing
- E2E tests performed using:
  - **LLM Provider:** LiteLLM
  - **Database:** PostgreSQL with `asyncpg` driver

 ```python
from google.adk.sessions.database_session_service import DatabaseSessionService

connection_string: str = (
    "postgresql+asyncpg://PG_USER:PG_PSWD@PG_HOST:5432/PG_DB"
)
session_service: DatabaseSessionService = DatabaseSessionService(
    db_url=connection_string
)

session = await session_service.create_session(
    app_name="test_app", session_id="test_session", user_id="test_user"
)
assert session is not None

sessions = await session_service.list_sessions(app_name="test_app", user_id="test_user")
assert len(sessions.sessions) > 0

session = await session_service.get_session(
    app_name="test_app", session_id="test_session", user_id="test_user"
)
assert session is not None

await session_service.delete_session(
    app_name="test_app", session_id="test_session", user_id="test_user"
)
assert (
    await session_service.get_session(
        app_name="test_app", session_id="test_session", user_id="test_user"
    )
    is None
)
```

The implementation have been also tested using the following configurations for llm provider and Runner:

```python
def get_azure_openai_model(deployment_id: str | None = None) -> LiteLlm:
    ...

    if not deployment_id:
        deployment_id = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID")

    logger.info(f"Using Azure OpenAI deployment ID: {deployment_id}")

    return LiteLlm(
        model=f"azure/{os.getenv('AZURE_OPENAI_DEPLOYMENT_ID')}",
        stream=True,
    )

...

    @staticmethod
    def _get_runner(agent: Agent) -> Runner:
        storage=DatabaseSessionService(db_url=get_pg_connection_string())
        return Runner(
            agent=agent,
            app_name=APP_NAME,
            session_service=storage,
        )

...

async for event in self.runner.run_async(
    user_id=user_id,
    session_id=session_id,
    new_message=content,
    run_config=(
        RunConfig(
            streaming_mode=StreamingMode.SSE, response_modalities=["TEXT"]
        )
        if stream
        else RunConfig()
    ),
):
    last_event = event
    if stream:
        yield event

...

```

## Breaking Changes

- Database connection string format may need updates for async drivers

Co-authored-by: Shangjie Chen <deanchen@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2889 from GitMarco27:feature/async_database_session_service e1b1b14934c1fb7975a6832cdd1549e94acab985
PiperOrigin-RevId: 830525148
This commit is contained in:
GitMarco27
2025-11-10 11:17:25 -08:00
committed by Copybara-Service
parent 2443a1b74f
commit 74959414d8
2 changed files with 92 additions and 64 deletions
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations
import asyncio
import copy
from datetime import datetime
from datetime import timezone
@@ -30,20 +31,21 @@ from sqlalchemy import Dialect
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session as DatabaseSessionFactory
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
@@ -417,11 +419,11 @@ class DatabaseSessionService(BaseSessionService):
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
db_engine = create_engine(db_url, **kwargs)
db_engine = create_async_engine(db_url, **kwargs)
if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine, "connect", set_sqlite_pragma)
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
except Exception as e:
if isinstance(e, ArgumentError):
@@ -440,18 +442,32 @@ class DatabaseSessionService(BaseSessionService):
local_timezone = get_localzone()
logger.info("Local timezone: %s", local_timezone)
self.db_engine: Engine = db_engine
self.db_engine: AsyncEngine = db_engine
self.metadata: MetaData = MetaData()
self.inspector = inspect(self.db_engine)
# DB session factory method
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
sessionmaker(bind=self.db_engine)
)
self.database_session_factory: async_sessionmaker[
DatabaseSessionFactory
] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False)
# Uncomment to recreate DB every time
# Base.metadata.drop_all(self.db_engine)
Base.metadata.create_all(self.db_engine)
# Flag to indicate if tables are created
self._tables_created = False
# Lock to ensure thread-safe table creation
self._table_creation_lock = asyncio.Lock()
async def _ensure_tables_created(self):
"""Ensure database tables are created. This is called lazily."""
if self._tables_created:
return
async with self._table_creation_lock:
# Double-check after acquiring the lock
if not self._tables_created:
async with self.db_engine.begin() as conn:
# Uncomment to recreate DB every time
# await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
self._tables_created = True
@override
async def create_session(
@@ -467,22 +483,25 @@ class DatabaseSessionService(BaseSessionService):
# 3. Add the object to the table
# 4. Build the session object with generated id
# 5. Return the session
await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
with self.database_session_factory() as sql_session:
if session_id and sql_session.get(
if session_id and await sql_session.get(
StorageSession, (app_name, user_id, session_id)
):
raise AlreadyExistsError(
f"Session with id {session_id} already exists."
)
# Fetch app and user states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_app_state = await sql_session.get(StorageAppState, (app_name))
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)
# Create state tables if not exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=app_name, state={})
sql_session.add(storage_app_state)
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=app_name, user_id=user_id, state={}
@@ -509,9 +528,9 @@ class DatabaseSessionService(BaseSessionService):
state=session_state,
)
sql_session.add(storage_session)
sql_session.commit()
await sql_session.commit()
sql_session.refresh(storage_session)
await sql_session.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(
@@ -529,39 +548,39 @@ class DatabaseSessionService(BaseSessionService):
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
await self._ensure_tables_created()
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
with self.database_session_factory() as sql_session:
storage_session = sql_session.get(
async with self.database_session_factory() as sql_session:
storage_session = await sql_session.get(
StorageSession, (app_name, user_id, session_id)
)
if storage_session is None:
return None
query = sql_session.query(StorageEvent).filter(
StorageEvent.app_name == app_name,
StorageEvent.user_id == user_id,
StorageEvent.session_id == storage_session.id,
stmt = (
select(StorageEvent)
.filter(StorageEvent.app_name == app_name)
.filter(StorageEvent.session_id == storage_session.id)
.filter(StorageEvent.user_id == user_id)
)
if config and config.after_timestamp:
after_dt = datetime.fromtimestamp(config.after_timestamp)
query = query.filter(StorageEvent.timestamp >= after_dt)
stmt = stmt.filter(StorageEvent.timestamp >= after_dt)
storage_events = (
query.order_by(StorageEvent.timestamp.desc())
.limit(
config.num_recent_events
if config and config.num_recent_events
else None
)
.all()
)
stmt = stmt.order_by(StorageEvent.timestamp.desc())
if config and config.num_recent_events:
stmt = stmt.limit(config.num_recent_events)
result = await sql_session.execute(stmt)
storage_events = result.scalars().all()
# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(StorageAppState, (app_name))
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)
@@ -581,32 +600,33 @@ class DatabaseSessionService(BaseSessionService):
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
) -> ListSessionsResponse:
with self.database_session_factory() as sql_session:
query = sql_session.query(StorageSession).filter(
StorageSession.app_name == app_name
)
await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
stmt = select(StorageSession).filter(StorageSession.app_name == app_name)
if user_id is not None:
query = query.filter(StorageSession.user_id == user_id)
results = query.all()
stmt = stmt.filter(StorageSession.user_id == user_id)
result = await sql_session.execute(stmt)
results = result.scalars().all()
# Fetch app state from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_app_state = await sql_session.get(StorageAppState, (app_name))
app_state = storage_app_state.state if storage_app_state else {}
# Fetch user state(s) from storage
user_states_map = {}
if user_id is not None:
storage_user_state = sql_session.get(
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)
if storage_user_state:
user_states_map[user_id] = storage_user_state.state
else:
all_user_states_for_app = (
sql_session.query(StorageUserState)
.filter(StorageUserState.app_name == app_name)
.all()
user_state_stmt = select(StorageUserState).filter(
StorageUserState.app_name == app_name
)
user_state_result = await sql_session.execute(user_state_stmt)
all_user_states_for_app = user_state_result.scalars().all()
for storage_user_state in all_user_states_for_app:
user_states_map[storage_user_state.user_id] = storage_user_state.state
@@ -622,17 +642,19 @@ class DatabaseSessionService(BaseSessionService):
async def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.database_session_factory() as sql_session:
await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
stmt = delete(StorageSession).where(
StorageSession.app_name == app_name,
StorageSession.user_id == user_id,
StorageSession.id == session_id,
)
sql_session.execute(stmt)
sql_session.commit()
await sql_session.execute(stmt)
await sql_session.commit()
@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._ensure_tables_created()
if event.partial:
return event
@@ -642,8 +664,8 @@ class DatabaseSessionService(BaseSessionService):
# 1. Check if timestamp is stale
# 2. Update session attributes based on event config
# 3. Store event to table
with self.database_session_factory() as sql_session:
storage_session = sql_session.get(
async with self.database_session_factory() as sql_session:
storage_session = await sql_session.get(
StorageSession, (session.app_name, session.user_id, session.id)
)
@@ -657,8 +679,10 @@ class DatabaseSessionService(BaseSessionService):
)
# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (session.app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(
StorageAppState, (session.app_name)
)
storage_user_state = await sql_session.get(
StorageUserState, (session.app_name, session.user_id)
)
@@ -680,8 +704,8 @@ class DatabaseSessionService(BaseSessionService):
sql_session.add(StorageEvent.from_event(session, event))
sql_session.commit()
sql_session.refresh(storage_session)
await sql_session.commit()
await sql_session.refresh(storage_session)
# Update timestamp with commit time
session.last_update_time = storage_session.update_timestamp_tz
@@ -691,8 +715,12 @@ class DatabaseSessionService(BaseSessionService):
return event
def _merge_state(app_state, user_state, session_state):
# Merge states for response
def _merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
"""Merge app, user, and session states into a single state dictionary."""
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
@@ -39,7 +39,7 @@ def get_session_service(
):
"""Creates a session service for testing."""
if service_type == SessionServiceType.DATABASE:
return DatabaseSessionService('sqlite:///:memory:')
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
if service_type == SessionServiceType.SQLITE:
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
return InMemorySessionService()