You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Close database engines to avoid aiosqlite pytest hangs
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 852428755
This commit is contained in:
committed by
Copybara-Service
parent
8789ad8f16
commit
4ddb2cb2a8
+2
-2
@@ -26,8 +26,7 @@ classifiers = [ # List of https://pypi.org/classifiers/
|
||||
dependencies = [
|
||||
# go/keep-sorted start
|
||||
"PyYAML>=6.0.2, <7.0.0", # For APIHubToolset.
|
||||
# TODO: Update aiosqlite version once https://github.com/omnilib/aiosqlite/issues/369 is fixed.
|
||||
"aiosqlite==0.21.0", # For SQLite database
|
||||
"aiosqlite>=0.21.0", # For SQLite database
|
||||
"anyio>=4.9.0, <5.0.0", # For MCP Session Manager
|
||||
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
|
||||
"click>=8.1.8, <9.0.0", # For CLI tools
|
||||
@@ -110,6 +109,7 @@ eval = [
|
||||
"google-cloud-aiplatform[evaluation]>=1.100.0",
|
||||
"pandas>=2.2.3",
|
||||
"rouge-score>=0.1.2",
|
||||
"scipy<1.16; python_version<'3.11'",
|
||||
"tabulate>=0.9.0",
|
||||
# go/keep-sorted end
|
||||
]
|
||||
|
||||
@@ -25,12 +25,14 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from typing_extensions import override
|
||||
from tzlocal import get_localzone
|
||||
|
||||
@@ -103,7 +105,15 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 2. Create all tables based on schema
|
||||
# 3. Initialize all properties
|
||||
try:
|
||||
db_engine = create_async_engine(db_url, **kwargs)
|
||||
engine_kwargs = dict(kwargs)
|
||||
url = make_url(db_url)
|
||||
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
|
||||
engine_kwargs.setdefault("poolclass", StaticPool)
|
||||
connect_args = dict(engine_kwargs.get("connect_args", {}))
|
||||
connect_args.setdefault("check_same_thread", False)
|
||||
engine_kwargs["connect_args"] = connect_args
|
||||
|
||||
db_engine = create_async_engine(db_url, **engine_kwargs)
|
||||
if db_engine.dialect.name == "sqlite":
|
||||
# Set sqlite pragma to enable foreign keys constraints
|
||||
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
|
||||
@@ -477,3 +487,15 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# Also update the in-memory session
|
||||
await super().append_event(session=session, event=event)
|
||||
return event
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Disposes the SQLAlchemy engine and closes pooled connections."""
|
||||
await self.db_engine.dispose()
|
||||
|
||||
async def __aenter__(self) -> DatabaseSessionService:
|
||||
"""Enters the async context manager and returns this service."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Exits the async context manager and closes the service."""
|
||||
await self.close()
|
||||
|
||||
@@ -29,17 +29,20 @@ async def create_v0_db(db_path):
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
# Use async context managers so DatabaseSessionService always closes.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_db_uses_latest_schema(tmp_path):
|
||||
db_path = tmp_path / 'new_db.db'
|
||||
db_url = f'sqlite+aiosqlite:///{db_path}'
|
||||
session_service = DatabaseSessionService(db_url)
|
||||
assert session_service._db_schema_version is None
|
||||
await session_service.create_session(app_name='my_app', user_id='test_user')
|
||||
assert (
|
||||
session_service._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
async with DatabaseSessionService(db_url) as session_service:
|
||||
assert session_service._db_schema_version is None
|
||||
await session_service.create_session(app_name='my_app', user_id='test_user')
|
||||
assert (
|
||||
session_service._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
|
||||
# Verify metadata table
|
||||
engine = create_async_engine(db_url)
|
||||
@@ -71,21 +74,20 @@ async def test_existing_v0_db_uses_v0_schema(tmp_path):
|
||||
db_path = tmp_path / 'v0_db.db'
|
||||
await create_v0_db(db_path)
|
||||
db_url = f'sqlite+aiosqlite:///{db_path}'
|
||||
session_service = DatabaseSessionService(db_url)
|
||||
async with DatabaseSessionService(db_url) as session_service:
|
||||
assert session_service._db_schema_version is None
|
||||
await session_service.create_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert (
|
||||
session_service._db_schema_version
|
||||
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
|
||||
)
|
||||
|
||||
assert session_service._db_schema_version is None
|
||||
await session_service.create_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert (
|
||||
session_service._db_schema_version
|
||||
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
|
||||
)
|
||||
|
||||
session = await session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert session.id == 's1'
|
||||
session = await session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert session.id == 's1'
|
||||
|
||||
# Verify schema tables
|
||||
engine = create_async_engine(db_url)
|
||||
@@ -111,38 +113,38 @@ async def test_existing_latest_db_uses_latest_schema(tmp_path):
|
||||
db_url = f'sqlite+aiosqlite:///{db_path}'
|
||||
|
||||
# Create session service which creates db with latest schema
|
||||
session_service1 = DatabaseSessionService(db_url)
|
||||
await session_service1.create_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert (
|
||||
session_service1._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
async with DatabaseSessionService(db_url) as session_service1:
|
||||
await session_service1.create_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert (
|
||||
session_service1._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
|
||||
# Create another session service on same db and check it detects latest schema
|
||||
session_service2 = DatabaseSessionService(db_url)
|
||||
await session_service2.create_session(
|
||||
app_name='my_app', user_id='test_user2', session_id='s2'
|
||||
)
|
||||
assert (
|
||||
session_service2._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
s2 = await session_service2.get_session(
|
||||
app_name='my_app', user_id='test_user2', session_id='s2'
|
||||
)
|
||||
assert s2.id == 's2'
|
||||
# Create another session service on same db and check it detects latest schema
|
||||
async with DatabaseSessionService(db_url) as session_service2:
|
||||
await session_service2.create_session(
|
||||
app_name='my_app', user_id='test_user2', session_id='s2'
|
||||
)
|
||||
assert (
|
||||
session_service2._db_schema_version
|
||||
== _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
)
|
||||
s2 = await session_service2.get_session(
|
||||
app_name='my_app', user_id='test_user2', session_id='s2'
|
||||
)
|
||||
assert s2.id == 's2'
|
||||
|
||||
s1 = await session_service2.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert s1.id == 's1'
|
||||
s1 = await session_service2.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='s1'
|
||||
)
|
||||
assert s1.id == 's1'
|
||||
|
||||
list_sessions_response = await session_service2.list_sessions(
|
||||
app_name='my_app'
|
||||
)
|
||||
assert len(list_sessions_response.sessions) == 2
|
||||
list_sessions_response = await session_service2.list_sessions(
|
||||
app_name='my_app'
|
||||
)
|
||||
assert len(list_sessions_response.sessions) == 2
|
||||
|
||||
# Verify schema tables
|
||||
engine = create_async_engine(db_url)
|
||||
|
||||
@@ -45,33 +45,30 @@ def get_session_service(
|
||||
return InMemorySessionService()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
]
|
||||
)
|
||||
async def test_get_empty_session(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def session_service(request, tmp_path):
|
||||
"""Provides a session service and closes database backends on teardown."""
|
||||
service = get_session_service(request.param, tmp_path)
|
||||
yield service
|
||||
if isinstance(service, DatabaseSessionService):
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_empty_session(session_service):
|
||||
assert not await session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='123'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_create_get_session(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_create_get_session(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
state = {'key': 'value'}
|
||||
@@ -111,16 +108,7 @@ async def test_create_get_session(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_create_and_list_sessions(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_create_and_list_sessions(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
|
||||
@@ -144,16 +132,7 @@ async def test_create_and_list_sessions(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_list_sessions_all_users(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_list_sessions_all_users(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
user_id_2 = 'user2'
|
||||
@@ -209,16 +188,7 @@ async def test_list_sessions_all_users(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_app_state_is_shared_by_all_users_of_app(session_service):
|
||||
app_name = 'my_app'
|
||||
# User 1 creates a session, establishing app:k1
|
||||
session1 = await session_service.create_session(
|
||||
@@ -247,18 +217,7 @@ async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_user_state_is_shared_only_by_user_sessions(
|
||||
service_type, tmp_path
|
||||
):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_user_state_is_shared_only_by_user_sessions(session_service):
|
||||
app_name = 'my_app'
|
||||
# User 1 creates a session, establishing user:k1 for user 1
|
||||
session1 = await session_service.create_session(
|
||||
@@ -286,16 +245,7 @@ async def test_user_state_is_shared_only_by_user_sessions(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_session_state_is_not_shared(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_session_state_is_not_shared(session_service):
|
||||
app_name = 'my_app'
|
||||
# User 1 creates a session session1, establishing sk1 only for session1
|
||||
session1 = await session_service.create_session(
|
||||
@@ -324,18 +274,7 @@ async def test_session_state_is_not_shared(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_temp_state_is_not_persisted_in_state_or_events(
|
||||
service_type, tmp_path
|
||||
):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_temp_state_is_not_persisted_in_state_or_events(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'u1'
|
||||
session = await session_service.create_session(
|
||||
@@ -361,16 +300,7 @@ async def test_temp_state_is_not_persisted_in_state_or_events(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_get_session_respects_user_id(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_get_session_respects_user_id(session_service):
|
||||
app_name = 'my_app'
|
||||
# u1 creates session 's1' and adds an event
|
||||
session1 = await session_service.create_session(
|
||||
@@ -392,18 +322,7 @@ async def test_get_session_respects_user_id(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_create_session_with_existing_id_raises_error(
|
||||
service_type, tmp_path
|
||||
):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_create_session_with_existing_id_raises_error(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
session_id = 'existing_session'
|
||||
@@ -425,16 +344,7 @@ async def test_create_session_with_existing_id_raises_error(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_append_event_bytes(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_append_event_bytes(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
@@ -471,16 +381,7 @@ async def test_append_event_bytes(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_append_event_complete(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_append_event_complete(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
@@ -532,18 +433,7 @@ async def test_append_event_complete(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_session_last_update_time_updates_on_event(
|
||||
service_type, tmp_path
|
||||
):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_session_last_update_time_updates_on_event(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
@@ -573,16 +463,7 @@ async def test_session_last_update_time_updates_on_event(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_get_session_with_config(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
async def test_get_session_with_config(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
@@ -605,16 +486,7 @@ async def test_get_session_with_config(service_type):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_get_session_with_config(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_get_session_with_config(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
@@ -674,16 +546,7 @@ async def test_get_session_with_config(service_type, tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type',
|
||||
[
|
||||
SessionServiceType.IN_MEMORY,
|
||||
SessionServiceType.DATABASE,
|
||||
SessionServiceType.SQLITE,
|
||||
],
|
||||
)
|
||||
async def test_partial_events_are_not_persisted(service_type, tmp_path):
|
||||
session_service = get_session_service(service_type, tmp_path)
|
||||
async def test_partial_events_are_not_persisted(session_service):
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
session = await session_service.create_session(
|
||||
|
||||
Reference in New Issue
Block a user