fix: Handle SQLite URLs in SqliteSessionService

The SqliteSessionService now accepts database paths in the form of SQLite URLs (e.g., "sqlite:///./sessions.db", "sqlite+aiosqlite:////absolute.db")

Close #4077

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 853922433
This commit is contained in:
George Weale
2026-01-08 16:13:37 -08:00
committed by Copybara-Service
parent 3c51ee7f48
commit b8917bc80e
2 changed files with 92 additions and 7 deletions
@@ -22,6 +22,8 @@ import sqlite3
import time
from typing import Any
from typing import Optional
from urllib.parse import unquote
from urllib.parse import urlparse
import uuid
import aiosqlite
@@ -91,6 +93,42 @@ CREATE_SCHEMA_SQL = "\n".join([
])
def _parse_db_path(db_path: str) -> tuple[str, str, bool]:
"""Normalizes a SQLite db path from a URL or filesystem path.
Returns:
A tuple of:
- filesystem path (for `os.path.exists` and user-facing messages)
- value to pass to sqlite/aiosqlite connect
- whether to pass `uri=True` to sqlite/aiosqlite connect
Notes:
When a SQLAlchemy-style SQLite URL is provided, this follows SQLAlchemy's
conventions:
- `sqlite:///relative.db` is a path relative to the current working dir.
- `sqlite:////absolute.db` is an absolute filesystem path.
"""
if not db_path.startswith(("sqlite:", "sqlite+aiosqlite:")):
return db_path, db_path, False
parsed = urlparse(db_path)
raw_path = unquote(parsed.path)
if not raw_path:
return db_path, db_path, False
normalized_path = raw_path
if normalized_path.startswith("//"):
normalized_path = normalized_path[1:]
elif normalized_path.startswith("/"):
normalized_path = normalized_path[1:]
if parsed.query:
# sqlite3 only treats the filename as a URI when it starts with `file:`.
return normalized_path, f"file:{normalized_path}?{parsed.query}", True
return normalized_path, normalized_path, False
class SqliteSessionService(BaseSessionService):
"""A session service that uses an SQLite database for storage via aiosqlite.
@@ -100,17 +138,19 @@ class SqliteSessionService(BaseSessionService):
def __init__(self, db_path: str):
"""Initializes the SQLite session service with a database path."""
self._db_path = db_path
self._db_path, self._db_connect_path, self._db_connect_uri = _parse_db_path(
db_path
)
if self._is_migration_needed():
raise RuntimeError(
f"Database {db_path} seems to use an old schema."
f"Database {self._db_path} seems to use an old schema."
" Please run the migration command to"
" migrate it to the new schema. Example: `python -m"
" google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite"
f" --source_db_path {db_path} --dest_db_path"
f" {db_path}.new` then backup {db_path} and rename"
f" {db_path}.new to {db_path}."
f" --source_db_path {self._db_path} --dest_db_path"
f" {self._db_path}.new` then backup {self._db_path} and rename"
f" {self._db_path}.new to {self._db_path}."
)
@override
@@ -415,7 +455,9 @@ class SqliteSessionService(BaseSessionService):
@asynccontextmanager
async def _get_db_connection(self):
"""Connects to the db and performs initial setup."""
async with aiosqlite.connect(self._db_path) as db:
async with aiosqlite.connect(
self._db_connect_path, uri=self._db_connect_uri
) as db:
db.row_factory = aiosqlite.Row
await db.execute(PRAGMA_FOREIGN_KEYS)
await db.executescript(CREATE_SCHEMA_SQL)
@@ -514,7 +556,9 @@ class SqliteSessionService(BaseSessionService):
if not os.path.exists(self._db_path):
return False
try:
with sqlite3.connect(self._db_path) as conn:
with sqlite3.connect(
self._db_connect_path, uri=self._db_connect_uri
) as conn:
cursor = conn.cursor()
# Check if events table exists
cursor.execute(
@@ -15,6 +15,7 @@
from datetime import datetime
from datetime import timezone
import enum
import sqlite3
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
@@ -60,6 +61,46 @@ async def session_service(request, tmp_path):
await service.close()
@pytest.mark.asyncio
async def test_sqlite_session_service_accepts_sqlite_urls(
tmp_path, monkeypatch
):
monkeypatch.chdir(tmp_path)
service = SqliteSessionService('sqlite+aiosqlite:///./sessions.db')
await service.create_session(app_name='app', user_id='user')
assert (tmp_path / 'sessions.db').exists()
service = SqliteSessionService('sqlite:///./sessions2.db')
await service.create_session(app_name='app', user_id='user')
assert (tmp_path / 'sessions2.db').exists()
@pytest.mark.asyncio
async def test_sqlite_session_service_preserves_uri_query_parameters(
tmp_path, monkeypatch
):
monkeypatch.chdir(tmp_path)
db_path = tmp_path / 'readonly.db'
with sqlite3.connect(db_path) as conn:
conn.execute('CREATE TABLE IF NOT EXISTS t (id INTEGER)')
conn.commit()
service = SqliteSessionService(f'sqlite+aiosqlite:///{db_path}?mode=ro')
# `mode=ro` opens the DB read-only; schema creation should fail.
with pytest.raises(sqlite3.OperationalError, match=r'readonly'):
await service.create_session(app_name='app', user_id='user')
@pytest.mark.asyncio
async def test_sqlite_session_service_accepts_absolute_sqlite_urls(tmp_path):
abs_db_path = tmp_path / 'absolute.db'
abs_url = 'sqlite+aiosqlite:////' + str(abs_db_path).lstrip('/')
service = SqliteSessionService(abs_url)
await service.create_session(app_name='app', user_id='user')
assert abs_db_path.exists()
@pytest.mark.asyncio
async def test_get_empty_session(session_service):
assert not await session_service.get_session(