chore: Close database engines to avoid aiosqlite pytest hangs

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 852428755
This commit is contained in:
George Weale
2026-01-05 13:17:27 -08:00
committed by Copybara-Service
parent 8789ad8f16
commit 4ddb2cb2a8
4 changed files with 105 additions and 218 deletions
+2 -2
View File
@@ -26,8 +26,7 @@ classifiers = [ # List of https://pypi.org/classifiers/
dependencies = [
# go/keep-sorted start
"PyYAML>=6.0.2, <7.0.0", # For APIHubToolset.
# TODO: Update aiosqlite version once https://github.com/omnilib/aiosqlite/issues/369 is fixed.
"aiosqlite==0.21.0", # For SQLite database
"aiosqlite>=0.21.0", # For SQLite database
"anyio>=4.9.0, <5.0.0", # For MCP Session Manager
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
"click>=8.1.8, <9.0.0", # For CLI tools
@@ -110,6 +109,7 @@ eval = [
"google-cloud-aiplatform[evaluation]>=1.100.0",
"pandas>=2.2.3",
"rouge-score>=0.1.2",
"scipy<1.16; python_version<'3.11'",
"tabulate>=0.9.0",
# go/keep-sorted end
]
@@ -25,12 +25,14 @@ from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.inspection import inspect
from sqlalchemy.pool import StaticPool
from typing_extensions import override
from tzlocal import get_localzone
@@ -103,7 +105,15 @@ class DatabaseSessionService(BaseSessionService):
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
db_engine = create_async_engine(db_url, **kwargs)
engine_kwargs = dict(kwargs)
url = make_url(db_url)
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
engine_kwargs.setdefault("poolclass", StaticPool)
connect_args = dict(engine_kwargs.get("connect_args", {}))
connect_args.setdefault("check_same_thread", False)
engine_kwargs["connect_args"] = connect_args
db_engine = create_async_engine(db_url, **engine_kwargs)
if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
@@ -477,3 +487,15 @@ class DatabaseSessionService(BaseSessionService):
# Also update the in-memory session
await super().append_event(session=session, event=event)
return event
async def close(self) -> None:
"""Disposes the SQLAlchemy engine and closes pooled connections."""
await self.db_engine.dispose()
async def __aenter__(self) -> DatabaseSessionService:
"""Enters the async context manager and returns this service."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exits the async context manager and closes the service."""
await self.close()
@@ -29,17 +29,20 @@ async def create_v0_db(db_path):
await engine.dispose()
# Use async context managers so DatabaseSessionService always closes.
@pytest.mark.asyncio
async def test_new_db_uses_latest_schema(tmp_path):
db_path = tmp_path / 'new_db.db'
db_url = f'sqlite+aiosqlite:///{db_path}'
session_service = DatabaseSessionService(db_url)
assert session_service._db_schema_version is None
await session_service.create_session(app_name='my_app', user_id='test_user')
assert (
session_service._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
async with DatabaseSessionService(db_url) as session_service:
assert session_service._db_schema_version is None
await session_service.create_session(app_name='my_app', user_id='test_user')
assert (
session_service._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
# Verify metadata table
engine = create_async_engine(db_url)
@@ -71,21 +74,20 @@ async def test_existing_v0_db_uses_v0_schema(tmp_path):
db_path = tmp_path / 'v0_db.db'
await create_v0_db(db_path)
db_url = f'sqlite+aiosqlite:///{db_path}'
session_service = DatabaseSessionService(db_url)
async with DatabaseSessionService(db_url) as session_service:
assert session_service._db_schema_version is None
await session_service.create_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert (
session_service._db_schema_version
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
)
assert session_service._db_schema_version is None
await session_service.create_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert (
session_service._db_schema_version
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
)
session = await session_service.get_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert session.id == 's1'
session = await session_service.get_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert session.id == 's1'
# Verify schema tables
engine = create_async_engine(db_url)
@@ -111,38 +113,38 @@ async def test_existing_latest_db_uses_latest_schema(tmp_path):
db_url = f'sqlite+aiosqlite:///{db_path}'
# Create session service which creates db with latest schema
session_service1 = DatabaseSessionService(db_url)
await session_service1.create_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert (
session_service1._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
async with DatabaseSessionService(db_url) as session_service1:
await session_service1.create_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert (
session_service1._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
# Create another session service on same db and check it detects latest schema
session_service2 = DatabaseSessionService(db_url)
await session_service2.create_session(
app_name='my_app', user_id='test_user2', session_id='s2'
)
assert (
session_service2._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
s2 = await session_service2.get_session(
app_name='my_app', user_id='test_user2', session_id='s2'
)
assert s2.id == 's2'
# Create another session service on same db and check it detects latest schema
async with DatabaseSessionService(db_url) as session_service2:
await session_service2.create_session(
app_name='my_app', user_id='test_user2', session_id='s2'
)
assert (
session_service2._db_schema_version
== _schema_check_utils.LATEST_SCHEMA_VERSION
)
s2 = await session_service2.get_session(
app_name='my_app', user_id='test_user2', session_id='s2'
)
assert s2.id == 's2'
s1 = await session_service2.get_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert s1.id == 's1'
s1 = await session_service2.get_session(
app_name='my_app', user_id='test_user', session_id='s1'
)
assert s1.id == 's1'
list_sessions_response = await session_service2.list_sessions(
app_name='my_app'
)
assert len(list_sessions_response.sessions) == 2
list_sessions_response = await session_service2.list_sessions(
app_name='my_app'
)
assert len(list_sessions_response.sessions) == 2
# Verify schema tables
engine = create_async_engine(db_url)
+28 -165
View File
@@ -45,33 +45,30 @@ def get_session_service(
return InMemorySessionService()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
@pytest.fixture(
params=[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
]
)
async def test_get_empty_session(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def session_service(request, tmp_path):
"""Provides a session service and closes database backends on teardown."""
service = get_session_service(request.param, tmp_path)
yield service
if isinstance(service, DatabaseSessionService):
await service.close()
@pytest.mark.asyncio
async def test_get_empty_session(session_service):
assert not await session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_create_get_session(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_create_get_session(session_service):
app_name = 'my_app'
user_id = 'test_user'
state = {'key': 'value'}
@@ -111,16 +108,7 @@ async def test_create_get_session(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_create_and_list_sessions(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_create_and_list_sessions(session_service):
app_name = 'my_app'
user_id = 'test_user'
@@ -144,16 +132,7 @@ async def test_create_and_list_sessions(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_list_sessions_all_users(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_list_sessions_all_users(session_service):
app_name = 'my_app'
user_id_1 = 'user1'
user_id_2 = 'user2'
@@ -209,16 +188,7 @@ async def test_list_sessions_all_users(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_app_state_is_shared_by_all_users_of_app(session_service):
app_name = 'my_app'
# User 1 creates a session, establishing app:k1
session1 = await session_service.create_session(
@@ -247,18 +217,7 @@ async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_user_state_is_shared_only_by_user_sessions(
service_type, tmp_path
):
session_service = get_session_service(service_type, tmp_path)
async def test_user_state_is_shared_only_by_user_sessions(session_service):
app_name = 'my_app'
# User 1 creates a session, establishing user:k1 for user 1
session1 = await session_service.create_session(
@@ -286,16 +245,7 @@ async def test_user_state_is_shared_only_by_user_sessions(
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_session_state_is_not_shared(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_session_state_is_not_shared(session_service):
app_name = 'my_app'
# User 1 creates a session session1, establishing sk1 only for session1
session1 = await session_service.create_session(
@@ -324,18 +274,7 @@ async def test_session_state_is_not_shared(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_temp_state_is_not_persisted_in_state_or_events(
service_type, tmp_path
):
session_service = get_session_service(service_type, tmp_path)
async def test_temp_state_is_not_persisted_in_state_or_events(session_service):
app_name = 'my_app'
user_id = 'u1'
session = await session_service.create_session(
@@ -361,16 +300,7 @@ async def test_temp_state_is_not_persisted_in_state_or_events(
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_get_session_respects_user_id(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_get_session_respects_user_id(session_service):
app_name = 'my_app'
# u1 creates session 's1' and adds an event
session1 = await session_service.create_session(
@@ -392,18 +322,7 @@ async def test_get_session_respects_user_id(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_create_session_with_existing_id_raises_error(
service_type, tmp_path
):
session_service = get_session_service(service_type, tmp_path)
async def test_create_session_with_existing_id_raises_error(session_service):
app_name = 'my_app'
user_id = 'test_user'
session_id = 'existing_session'
@@ -425,16 +344,7 @@ async def test_create_session_with_existing_id_raises_error(
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_append_event_bytes(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_append_event_bytes(session_service):
app_name = 'my_app'
user_id = 'user'
@@ -471,16 +381,7 @@ async def test_append_event_bytes(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_append_event_complete(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_append_event_complete(session_service):
app_name = 'my_app'
user_id = 'user'
@@ -532,18 +433,7 @@ async def test_append_event_complete(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_session_last_update_time_updates_on_event(
service_type, tmp_path
):
session_service = get_session_service(service_type, tmp_path)
async def test_session_last_update_time_updates_on_event(session_service):
app_name = 'my_app'
user_id = 'user'
@@ -573,16 +463,7 @@ async def test_session_last_update_time_updates_on_event(
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_get_session_with_config(service_type):
session_service = get_session_service(service_type)
async def test_get_session_with_config(session_service):
app_name = 'my_app'
user_id = 'user'
@@ -605,16 +486,7 @@ async def test_get_session_with_config(service_type):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_get_session_with_config(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_get_session_with_config(session_service):
app_name = 'my_app'
user_id = 'user'
@@ -674,16 +546,7 @@ async def test_get_session_with_config(service_type, tmp_path):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_partial_events_are_not_persisted(service_type, tmp_path):
session_service = get_session_service(service_type, tmp_path)
async def test_partial_events_are_not_persisted(session_service):
app_name = 'my_app'
user_id = 'user'
session = await session_service.create_session(