diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 97964833..e37f91d0 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -26,6 +26,7 @@ import uuid from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect +from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func from sqlalchemy import Text @@ -366,6 +367,12 @@ class StorageUserState(Base): ) +def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + class DatabaseSessionService(BaseSessionService): """A session service that uses a database for storage.""" @@ -374,9 +381,13 @@ class DatabaseSessionService(BaseSessionService): # 1. Create DB engine for db connection # 2. Create all tables based on schema # 3. Initialize all properties - try: db_engine = create_engine(db_url, **kwargs) + + if db_engine.dialect.name == "sqlite": + # Set sqlite pragma to enable foreign keys constraints + event.listen(db_engine, "connect", set_sqlite_pragma) + except Exception as e: if isinstance(e, ArgumentError): raise ValueError(