You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: full async implementation of DatabaseSessionService
Merge https://github.com/google/adk-python/pull/2889 # Implement Full async DatabaseSessionService **Target Issue:** #1005 ## Overview This PR introduces an asynchronous implementation of the `DatabaseSessionService` with minimal breaking changes. The primary goal is to enable effective use of ADK in fully async environments and API endpoints while avoiding event loop blocking during database I/O operations. ## Changes - Converted `DatabaseSessionService` to use async/await patterns throughout ## Testing Plan The implementation has been tested following the project's contribution guidelines: ### Unit Tests - All existing unit tests pass successfully - Minor update to test requirements added to support `aiosqlite` ### Manual End-to-End Testing - E2E tests performed using: - **LLM Provider:** LiteLLM - **Database:** PostgreSQL with `asyncpg` driver ```python from google.adk.sessions.database_session_service import DatabaseSessionService connection_string: str = ( "postgresql+asyncpg://PG_USER:PG_PSWD@PG_HOST:5432/PG_DB" ) session_service: DatabaseSessionService = DatabaseSessionService( db_url=connection_string ) session = await session_service.create_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None sessions = await session_service.list_sessions(app_name="test_app", user_id="test_user") assert len(sessions.sessions) > 0 session = await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None await session_service.delete_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert ( await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) is None ) ``` The implementation have been also tested using the following configurations for llm provider and Runner: ```python def get_azure_openai_model(deployment_id: str | None = None) -> LiteLlm: ... if not deployment_id: deployment_id = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID") logger.info(f"Using Azure OpenAI deployment ID: {deployment_id}") return LiteLlm( model=f"azure/{os.getenv('AZURE_OPENAI_DEPLOYMENT_ID')}", stream=True, ) ... @staticmethod def _get_runner(agent: Agent) -> Runner: storage=DatabaseSessionService(db_url=get_pg_connection_string()) return Runner( agent=agent, app_name=APP_NAME, session_service=storage, ) ... async for event in self.runner.run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=( RunConfig( streaming_mode=StreamingMode.SSE, response_modalities=["TEXT"] ) if stream else RunConfig() ), ): last_event = event if stream: yield event ... ``` ## Breaking Changes - Database connection string format may need updates for async drivers Co-authored-by: Shangjie Chen <deanchen@google.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2889 from GitMarco27:feature/async_database_session_service e1b1b14934c1fb7975a6832cdd1549e94acab985 PiperOrigin-RevId: 830525148
This commit is contained in:
committed by
Copybara-Service
parent
2443a1b74f
commit
74959414d8
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -30,20 +31,21 @@ from sqlalchemy import Dialect
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import mysql
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Session as DatabaseSessionFactory
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.schema import MetaData
|
||||
from sqlalchemy.types import DateTime
|
||||
from sqlalchemy.types import PickleType
|
||||
@@ -417,11 +419,11 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 2. Create all tables based on schema
|
||||
# 3. Initialize all properties
|
||||
try:
|
||||
db_engine = create_engine(db_url, **kwargs)
|
||||
db_engine = create_async_engine(db_url, **kwargs)
|
||||
|
||||
if db_engine.dialect.name == "sqlite":
|
||||
# Set sqlite pragma to enable foreign keys constraints
|
||||
event.listen(db_engine, "connect", set_sqlite_pragma)
|
||||
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ArgumentError):
|
||||
@@ -440,18 +442,32 @@ class DatabaseSessionService(BaseSessionService):
|
||||
local_timezone = get_localzone()
|
||||
logger.info("Local timezone: %s", local_timezone)
|
||||
|
||||
self.db_engine: Engine = db_engine
|
||||
self.db_engine: AsyncEngine = db_engine
|
||||
self.metadata: MetaData = MetaData()
|
||||
self.inspector = inspect(self.db_engine)
|
||||
|
||||
# DB session factory method
|
||||
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
|
||||
sessionmaker(bind=self.db_engine)
|
||||
)
|
||||
self.database_session_factory: async_sessionmaker[
|
||||
DatabaseSessionFactory
|
||||
] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False)
|
||||
|
||||
# Uncomment to recreate DB every time
|
||||
# Base.metadata.drop_all(self.db_engine)
|
||||
Base.metadata.create_all(self.db_engine)
|
||||
# Flag to indicate if tables are created
|
||||
self._tables_created = False
|
||||
# Lock to ensure thread-safe table creation
|
||||
self._table_creation_lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_tables_created(self):
|
||||
"""Ensure database tables are created. This is called lazily."""
|
||||
if self._tables_created:
|
||||
return
|
||||
|
||||
async with self._table_creation_lock:
|
||||
# Double-check after acquiring the lock
|
||||
if not self._tables_created:
|
||||
async with self.db_engine.begin() as conn:
|
||||
# Uncomment to recreate DB every time
|
||||
# await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
self._tables_created = True
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
@@ -467,22 +483,25 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 3. Add the object to the table
|
||||
# 4. Build the session object with generated id
|
||||
# 5. Return the session
|
||||
await self._ensure_tables_created()
|
||||
async with self.database_session_factory() as sql_session:
|
||||
|
||||
with self.database_session_factory() as sql_session:
|
||||
if session_id and sql_session.get(
|
||||
if session_id and await sql_session.get(
|
||||
StorageSession, (app_name, user_id, session_id)
|
||||
):
|
||||
raise AlreadyExistsError(
|
||||
f"Session with id {session_id} already exists."
|
||||
)
|
||||
# Fetch app and user states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
storage_app_state = await sql_session.get(StorageAppState, (app_name))
|
||||
storage_user_state = await sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
# Create state tables if not exist
|
||||
if not storage_app_state:
|
||||
storage_app_state = StorageAppState(app_name=app_name, state={})
|
||||
sql_session.add(storage_app_state)
|
||||
storage_user_state = sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
if not storage_user_state:
|
||||
storage_user_state = StorageUserState(
|
||||
app_name=app_name, user_id=user_id, state={}
|
||||
@@ -509,9 +528,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
state=session_state,
|
||||
)
|
||||
sql_session.add(storage_session)
|
||||
sql_session.commit()
|
||||
await sql_session.commit()
|
||||
|
||||
sql_session.refresh(storage_session)
|
||||
await sql_session.refresh(storage_session)
|
||||
|
||||
# Merge states for response
|
||||
merged_state = _merge_state(
|
||||
@@ -529,39 +548,39 @@ class DatabaseSessionService(BaseSessionService):
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
await self._ensure_tables_created()
|
||||
# 1. Get the storage session entry from session table
|
||||
# 2. Get all the events based on session id and filtering config
|
||||
# 3. Convert and return the session
|
||||
with self.database_session_factory() as sql_session:
|
||||
storage_session = sql_session.get(
|
||||
async with self.database_session_factory() as sql_session:
|
||||
storage_session = await sql_session.get(
|
||||
StorageSession, (app_name, user_id, session_id)
|
||||
)
|
||||
if storage_session is None:
|
||||
return None
|
||||
|
||||
query = sql_session.query(StorageEvent).filter(
|
||||
StorageEvent.app_name == app_name,
|
||||
StorageEvent.user_id == user_id,
|
||||
StorageEvent.session_id == storage_session.id,
|
||||
stmt = (
|
||||
select(StorageEvent)
|
||||
.filter(StorageEvent.app_name == app_name)
|
||||
.filter(StorageEvent.session_id == storage_session.id)
|
||||
.filter(StorageEvent.user_id == user_id)
|
||||
)
|
||||
|
||||
if config and config.after_timestamp:
|
||||
after_dt = datetime.fromtimestamp(config.after_timestamp)
|
||||
query = query.filter(StorageEvent.timestamp >= after_dt)
|
||||
stmt = stmt.filter(StorageEvent.timestamp >= after_dt)
|
||||
|
||||
storage_events = (
|
||||
query.order_by(StorageEvent.timestamp.desc())
|
||||
.limit(
|
||||
config.num_recent_events
|
||||
if config and config.num_recent_events
|
||||
else None
|
||||
)
|
||||
.all()
|
||||
)
|
||||
stmt = stmt.order_by(StorageEvent.timestamp.desc())
|
||||
|
||||
if config and config.num_recent_events:
|
||||
stmt = stmt.limit(config.num_recent_events)
|
||||
|
||||
result = await sql_session.execute(stmt)
|
||||
storage_events = result.scalars().all()
|
||||
|
||||
# Fetch states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
storage_user_state = sql_session.get(
|
||||
storage_app_state = await sql_session.get(StorageAppState, (app_name))
|
||||
storage_user_state = await sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
@@ -581,32 +600,33 @@ class DatabaseSessionService(BaseSessionService):
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
) -> ListSessionsResponse:
|
||||
with self.database_session_factory() as sql_session:
|
||||
query = sql_session.query(StorageSession).filter(
|
||||
StorageSession.app_name == app_name
|
||||
)
|
||||
await self._ensure_tables_created()
|
||||
async with self.database_session_factory() as sql_session:
|
||||
stmt = select(StorageSession).filter(StorageSession.app_name == app_name)
|
||||
if user_id is not None:
|
||||
query = query.filter(StorageSession.user_id == user_id)
|
||||
results = query.all()
|
||||
stmt = stmt.filter(StorageSession.user_id == user_id)
|
||||
|
||||
result = await sql_session.execute(stmt)
|
||||
results = result.scalars().all()
|
||||
|
||||
# Fetch app state from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
storage_app_state = await sql_session.get(StorageAppState, (app_name))
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
|
||||
# Fetch user state(s) from storage
|
||||
user_states_map = {}
|
||||
if user_id is not None:
|
||||
storage_user_state = sql_session.get(
|
||||
storage_user_state = await sql_session.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
if storage_user_state:
|
||||
user_states_map[user_id] = storage_user_state.state
|
||||
else:
|
||||
all_user_states_for_app = (
|
||||
sql_session.query(StorageUserState)
|
||||
.filter(StorageUserState.app_name == app_name)
|
||||
.all()
|
||||
user_state_stmt = select(StorageUserState).filter(
|
||||
StorageUserState.app_name == app_name
|
||||
)
|
||||
user_state_result = await sql_session.execute(user_state_stmt)
|
||||
all_user_states_for_app = user_state_result.scalars().all()
|
||||
for storage_user_state in all_user_states_for_app:
|
||||
user_states_map[storage_user_state.user_id] = storage_user_state.state
|
||||
|
||||
@@ -622,17 +642,19 @@ class DatabaseSessionService(BaseSessionService):
|
||||
async def delete_session(
|
||||
self, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
with self.database_session_factory() as sql_session:
|
||||
await self._ensure_tables_created()
|
||||
async with self.database_session_factory() as sql_session:
|
||||
stmt = delete(StorageSession).where(
|
||||
StorageSession.app_name == app_name,
|
||||
StorageSession.user_id == user_id,
|
||||
StorageSession.id == session_id,
|
||||
)
|
||||
sql_session.execute(stmt)
|
||||
sql_session.commit()
|
||||
await sql_session.execute(stmt)
|
||||
await sql_session.commit()
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
await self._ensure_tables_created()
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
@@ -642,8 +664,8 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 1. Check if timestamp is stale
|
||||
# 2. Update session attributes based on event config
|
||||
# 3. Store event to table
|
||||
with self.database_session_factory() as sql_session:
|
||||
storage_session = sql_session.get(
|
||||
async with self.database_session_factory() as sql_session:
|
||||
storage_session = await sql_session.get(
|
||||
StorageSession, (session.app_name, session.user_id, session.id)
|
||||
)
|
||||
|
||||
@@ -657,8 +679,10 @@ class DatabaseSessionService(BaseSessionService):
|
||||
)
|
||||
|
||||
# Fetch states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (session.app_name))
|
||||
storage_user_state = sql_session.get(
|
||||
storage_app_state = await sql_session.get(
|
||||
StorageAppState, (session.app_name)
|
||||
)
|
||||
storage_user_state = await sql_session.get(
|
||||
StorageUserState, (session.app_name, session.user_id)
|
||||
)
|
||||
|
||||
@@ -680,8 +704,8 @@ class DatabaseSessionService(BaseSessionService):
|
||||
|
||||
sql_session.add(StorageEvent.from_event(session, event))
|
||||
|
||||
sql_session.commit()
|
||||
sql_session.refresh(storage_session)
|
||||
await sql_session.commit()
|
||||
await sql_session.refresh(storage_session)
|
||||
|
||||
# Update timestamp with commit time
|
||||
session.last_update_time = storage_session.update_timestamp_tz
|
||||
@@ -691,8 +715,12 @@ class DatabaseSessionService(BaseSessionService):
|
||||
return event
|
||||
|
||||
|
||||
def _merge_state(app_state, user_state, session_state):
|
||||
# Merge states for response
|
||||
def _merge_state(
|
||||
app_state: dict[str, Any],
|
||||
user_state: dict[str, Any],
|
||||
session_state: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Merge app, user, and session states into a single state dictionary."""
|
||||
merged_state = copy.deepcopy(session_state)
|
||||
for key in app_state.keys():
|
||||
merged_state[State.APP_PREFIX + key] = app_state[key]
|
||||
|
||||
@@ -39,7 +39,7 @@ def get_session_service(
|
||||
):
|
||||
"""Creates a session service for testing."""
|
||||
if service_type == SessionServiceType.DATABASE:
|
||||
return DatabaseSessionService('sqlite:///:memory:')
|
||||
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
|
||||
if service_type == SessionServiceType.SQLITE:
|
||||
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
|
||||
return InMemorySessionService()
|
||||
|
||||
Reference in New Issue
Block a user