From 4ddb2cb2a8d1d026a43418b2dd698e6ea199594e Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 5 Jan 2026 13:17:27 -0800 Subject: [PATCH] chore: Close database engines to avoid aiosqlite pytest hangs Co-authored-by: George Weale PiperOrigin-RevId: 852428755 --- pyproject.toml | 4 +- .../adk/sessions/database_session_service.py | 24 ++- .../migration/test_database_schema.py | 102 ++++----- .../sessions/test_session_service.py | 193 +++--------------- 4 files changed, 105 insertions(+), 218 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index af7e1840..19abaa3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ dependencies = [ # go/keep-sorted start "PyYAML>=6.0.2, <7.0.0", # For APIHubToolset. - # TODO: Update aiosqlite version once https://github.com/omnilib/aiosqlite/issues/369 is fixed. - "aiosqlite==0.21.0", # For SQLite database + "aiosqlite>=0.21.0", # For SQLite database "anyio>=4.9.0, <5.0.0", # For MCP Session Manager "authlib>=1.5.1, <2.0.0", # For RestAPI Tool "click>=8.1.8, <9.0.0", # For CLI tools @@ -110,6 +109,7 @@ eval = [ "google-cloud-aiplatform[evaluation]>=1.100.0", "pandas>=2.2.3", "rouge-score>=0.1.2", + "scipy<1.16; python_version<'3.11'", "tabulate>=0.9.0", # go/keep-sorted end ] diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 3cc9bb6a..c9762ad0 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -25,12 +25,14 @@ from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import select from sqlalchemy import text +from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.inspection import inspect +from sqlalchemy.pool import StaticPool from typing_extensions import override from tzlocal import get_localzone @@ -103,7 +105,15 @@ class DatabaseSessionService(BaseSessionService): # 2. Create all tables based on schema # 3. Initialize all properties try: - db_engine = create_async_engine(db_url, **kwargs) + engine_kwargs = dict(kwargs) + url = make_url(db_url) + if url.get_backend_name() == "sqlite" and url.database == ":memory:": + engine_kwargs.setdefault("poolclass", StaticPool) + connect_args = dict(engine_kwargs.get("connect_args", {})) + connect_args.setdefault("check_same_thread", False) + engine_kwargs["connect_args"] = connect_args + + db_engine = create_async_engine(db_url, **engine_kwargs) if db_engine.dialect.name == "sqlite": # Set sqlite pragma to enable foreign keys constraints event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma) @@ -477,3 +487,15 @@ class DatabaseSessionService(BaseSessionService): # Also update the in-memory session await super().append_event(session=session, event=event) return event + + async def close(self) -> None: + """Disposes the SQLAlchemy engine and closes pooled connections.""" + await self.db_engine.dispose() + + async def __aenter__(self) -> DatabaseSessionService: + """Enters the async context manager and returns this service.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exits the async context manager and closes the service.""" + await self.close() diff --git a/tests/unittests/sessions/migration/test_database_schema.py b/tests/unittests/sessions/migration/test_database_schema.py index 4fc0d03d..239da2f1 100644 --- a/tests/unittests/sessions/migration/test_database_schema.py +++ b/tests/unittests/sessions/migration/test_database_schema.py @@ -29,17 +29,20 @@ async def create_v0_db(db_path): await engine.dispose() +# Use async context managers so DatabaseSessionService always closes. + + @pytest.mark.asyncio async def test_new_db_uses_latest_schema(tmp_path): db_path = tmp_path / 'new_db.db' db_url = f'sqlite+aiosqlite:///{db_path}' - session_service = DatabaseSessionService(db_url) - assert session_service._db_schema_version is None - await session_service.create_session(app_name='my_app', user_id='test_user') - assert ( - session_service._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) + async with DatabaseSessionService(db_url) as session_service: + assert session_service._db_schema_version is None + await session_service.create_session(app_name='my_app', user_id='test_user') + assert ( + session_service._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) # Verify metadata table engine = create_async_engine(db_url) @@ -71,21 +74,20 @@ async def test_existing_v0_db_uses_v0_schema(tmp_path): db_path = tmp_path / 'v0_db.db' await create_v0_db(db_path) db_url = f'sqlite+aiosqlite:///{db_path}' - session_service = DatabaseSessionService(db_url) + async with DatabaseSessionService(db_url) as session_service: + assert session_service._db_schema_version is None + await session_service.create_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert ( + session_service._db_schema_version + == _schema_check_utils.SCHEMA_VERSION_0_PICKLE + ) - assert session_service._db_schema_version is None - await session_service.create_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert ( - session_service._db_schema_version - == _schema_check_utils.SCHEMA_VERSION_0_PICKLE - ) - - session = await session_service.get_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert session.id == 's1' + session = await session_service.get_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert session.id == 's1' # Verify schema tables engine = create_async_engine(db_url) @@ -111,38 +113,38 @@ async def test_existing_latest_db_uses_latest_schema(tmp_path): db_url = f'sqlite+aiosqlite:///{db_path}' # Create session service which creates db with latest schema - session_service1 = DatabaseSessionService(db_url) - await session_service1.create_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert ( - session_service1._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) + async with DatabaseSessionService(db_url) as session_service1: + await session_service1.create_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert ( + session_service1._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) - # Create another session service on same db and check it detects latest schema - session_service2 = DatabaseSessionService(db_url) - await session_service2.create_session( - app_name='my_app', user_id='test_user2', session_id='s2' - ) - assert ( - session_service2._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) - s2 = await session_service2.get_session( - app_name='my_app', user_id='test_user2', session_id='s2' - ) - assert s2.id == 's2' + # Create another session service on same db and check it detects latest schema + async with DatabaseSessionService(db_url) as session_service2: + await session_service2.create_session( + app_name='my_app', user_id='test_user2', session_id='s2' + ) + assert ( + session_service2._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) + s2 = await session_service2.get_session( + app_name='my_app', user_id='test_user2', session_id='s2' + ) + assert s2.id == 's2' - s1 = await session_service2.get_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert s1.id == 's1' + s1 = await session_service2.get_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert s1.id == 's1' - list_sessions_response = await session_service2.list_sessions( - app_name='my_app' - ) - assert len(list_sessions_response.sessions) == 2 + list_sessions_response = await session_service2.list_sessions( + app_name='my_app' + ) + assert len(list_sessions_response.sessions) == 2 # Verify schema tables engine = create_async_engine(db_url) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 45aa3fee..556d78ae 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -45,33 +45,30 @@ def get_session_service( return InMemorySessionService() -@pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ +@pytest.fixture( + params=[ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, SessionServiceType.SQLITE, - ], + ] ) -async def test_get_empty_session(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def session_service(request, tmp_path): + """Provides a session service and closes database backends on teardown.""" + service = get_session_service(request.param, tmp_path) + yield service + if isinstance(service, DatabaseSessionService): + await service.close() + + +@pytest.mark.asyncio +async def test_get_empty_session(session_service): assert not await session_service.get_session( app_name='my_app', user_id='test_user', session_id='123' ) @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_get_session(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_create_get_session(session_service): app_name = 'my_app' user_id = 'test_user' state = {'key': 'value'} @@ -111,16 +108,7 @@ async def test_create_get_session(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_and_list_sessions(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_create_and_list_sessions(session_service): app_name = 'my_app' user_id = 'test_user' @@ -144,16 +132,7 @@ async def test_create_and_list_sessions(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_list_sessions_all_users(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_list_sessions_all_users(session_service): app_name = 'my_app' user_id_1 = 'user1' user_id_2 = 'user2' @@ -209,16 +188,7 @@ async def test_list_sessions_all_users(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_app_state_is_shared_by_all_users_of_app(session_service): app_name = 'my_app' # User 1 creates a session, establishing app:k1 session1 = await session_service.create_session( @@ -247,18 +217,7 @@ async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_user_state_is_shared_only_by_user_sessions( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_user_state_is_shared_only_by_user_sessions(session_service): app_name = 'my_app' # User 1 creates a session, establishing user:k1 for user 1 session1 = await session_service.create_session( @@ -286,16 +245,7 @@ async def test_user_state_is_shared_only_by_user_sessions( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_session_state_is_not_shared(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_session_state_is_not_shared(session_service): app_name = 'my_app' # User 1 creates a session session1, establishing sk1 only for session1 session1 = await session_service.create_session( @@ -324,18 +274,7 @@ async def test_session_state_is_not_shared(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_temp_state_is_not_persisted_in_state_or_events( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_temp_state_is_not_persisted_in_state_or_events(session_service): app_name = 'my_app' user_id = 'u1' session = await session_service.create_session( @@ -361,16 +300,7 @@ async def test_temp_state_is_not_persisted_in_state_or_events( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_respects_user_id(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_get_session_respects_user_id(session_service): app_name = 'my_app' # u1 creates session 's1' and adds an event session1 = await session_service.create_session( @@ -392,18 +322,7 @@ async def test_get_session_respects_user_id(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_session_with_existing_id_raises_error( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_create_session_with_existing_id_raises_error(session_service): app_name = 'my_app' user_id = 'test_user' session_id = 'existing_session' @@ -425,16 +344,7 @@ async def test_create_session_with_existing_id_raises_error( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_append_event_bytes(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_append_event_bytes(session_service): app_name = 'my_app' user_id = 'user' @@ -471,16 +381,7 @@ async def test_append_event_bytes(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_append_event_complete(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_append_event_complete(session_service): app_name = 'my_app' user_id = 'user' @@ -532,18 +433,7 @@ async def test_append_event_complete(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_session_last_update_time_updates_on_event( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_session_last_update_time_updates_on_event(session_service): app_name = 'my_app' user_id = 'user' @@ -573,16 +463,7 @@ async def test_session_last_update_time_updates_on_event( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_with_config(service_type): - session_service = get_session_service(service_type) +async def test_get_session_with_config(session_service): app_name = 'my_app' user_id = 'user' @@ -605,16 +486,7 @@ async def test_get_session_with_config(service_type): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_with_config(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_get_session_with_config(session_service): app_name = 'my_app' user_id = 'user' @@ -674,16 +546,7 @@ async def test_get_session_with_config(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_partial_events_are_not_persisted(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_partial_events_are_not_persisted(session_service): app_name = 'my_app' user_id = 'user' session = await session_service.create_session(