diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 92b37229..529bbfd4 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -113,6 +113,8 @@ class DatabaseSessionService(BaseSessionService): connect_args = dict(engine_kwargs.get("connect_args", {})) connect_args.setdefault("check_same_thread", False) engine_kwargs["connect_args"] = connect_args + elif url.get_backend_name() != "sqlite": + engine_kwargs.setdefault("pool_pre_ping", True) db_engine = create_async_engine(db_url, **engine_kwargs) if db_engine.dialect.name == "sqlite": diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index a305f53c..f6445934 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -16,10 +16,12 @@ from datetime import datetime from datetime import timezone import enum import sqlite3 +from unittest import mock from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.sessions import database_session_service from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.database_session_service import DatabaseSessionService from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -61,6 +63,51 @@ async def session_service(request, tmp_path): await service.close() +def test_database_session_service_enables_pool_pre_ping_by_default(): + captured_kwargs = {} + + def fake_create_async_engine(_db_url: str, **kwargs): + captured_kwargs.update(kwargs) + fake_engine = mock.Mock() + fake_engine.dialect.name = 'postgresql' + fake_engine.sync_engine = mock.Mock() + return fake_engine + + with mock.patch.object( + database_session_service, + 'create_async_engine', + side_effect=fake_create_async_engine, + ): + database_session_service.DatabaseSessionService( + 'postgresql+psycopg2://user:pass@localhost:5432/db' + ) + + assert captured_kwargs.get('pool_pre_ping') is True + + +def test_database_session_service_respects_pool_pre_ping_override(): + captured_kwargs = {} + + def fake_create_async_engine(_db_url: str, **kwargs): + captured_kwargs.update(kwargs) + fake_engine = mock.Mock() + fake_engine.dialect.name = 'postgresql' + fake_engine.sync_engine = mock.Mock() + return fake_engine + + with mock.patch.object( + database_session_service, + 'create_async_engine', + side_effect=fake_create_async_engine, + ): + database_session_service.DatabaseSessionService( + 'postgresql+psycopg2://user:pass@localhost:5432/db', + pool_pre_ping=False, + ) + + assert captured_kwargs.get('pool_pre_ping') is False + + @pytest.mark.asyncio async def test_sqlite_session_service_accepts_sqlite_urls( tmp_path, monkeypatch