diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index e0d44b38..1d9516ec 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -22,6 +22,8 @@ import sqlite3 import time from typing import Any from typing import Optional +from urllib.parse import unquote +from urllib.parse import urlparse import uuid import aiosqlite @@ -91,6 +93,42 @@ CREATE_SCHEMA_SQL = "\n".join([ ]) +def _parse_db_path(db_path: str) -> tuple[str, str, bool]: + """Normalizes a SQLite db path from a URL or filesystem path. + + Returns: + A tuple of: + - filesystem path (for `os.path.exists` and user-facing messages) + - value to pass to sqlite/aiosqlite connect + - whether to pass `uri=True` to sqlite/aiosqlite connect + + Notes: + When a SQLAlchemy-style SQLite URL is provided, this follows SQLAlchemy's + conventions: + - `sqlite:///relative.db` is a path relative to the current working dir. + - `sqlite:////absolute.db` is an absolute filesystem path. + """ + if not db_path.startswith(("sqlite:", "sqlite+aiosqlite:")): + return db_path, db_path, False + + parsed = urlparse(db_path) + raw_path = unquote(parsed.path) + if not raw_path: + return db_path, db_path, False + + normalized_path = raw_path + if normalized_path.startswith("//"): + normalized_path = normalized_path[1:] + elif normalized_path.startswith("/"): + normalized_path = normalized_path[1:] + + if parsed.query: + # sqlite3 only treats the filename as a URI when it starts with `file:`. + return normalized_path, f"file:{normalized_path}?{parsed.query}", True + + return normalized_path, normalized_path, False + + class SqliteSessionService(BaseSessionService): """A session service that uses an SQLite database for storage via aiosqlite. @@ -100,17 +138,19 @@ class SqliteSessionService(BaseSessionService): def __init__(self, db_path: str): """Initializes the SQLite session service with a database path.""" - self._db_path = db_path + self._db_path, self._db_connect_path, self._db_connect_uri = _parse_db_path( + db_path + ) if self._is_migration_needed(): raise RuntimeError( - f"Database {db_path} seems to use an old schema." + f"Database {self._db_path} seems to use an old schema." " Please run the migration command to" " migrate it to the new schema. Example: `python -m" " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" - f" --source_db_path {db_path} --dest_db_path" - f" {db_path}.new` then backup {db_path} and rename" - f" {db_path}.new to {db_path}." + f" --source_db_path {self._db_path} --dest_db_path" + f" {self._db_path}.new` then backup {self._db_path} and rename" + f" {self._db_path}.new to {self._db_path}." ) @override @@ -415,7 +455,9 @@ class SqliteSessionService(BaseSessionService): @asynccontextmanager async def _get_db_connection(self): """Connects to the db and performs initial setup.""" - async with aiosqlite.connect(self._db_path) as db: + async with aiosqlite.connect( + self._db_connect_path, uri=self._db_connect_uri + ) as db: db.row_factory = aiosqlite.Row await db.execute(PRAGMA_FOREIGN_KEYS) await db.executescript(CREATE_SCHEMA_SQL) @@ -514,7 +556,9 @@ class SqliteSessionService(BaseSessionService): if not os.path.exists(self._db_path): return False try: - with sqlite3.connect(self._db_path) as conn: + with sqlite3.connect( + self._db_connect_path, uri=self._db_connect_uri + ) as conn: cursor = conn.cursor() # Check if events table exists cursor.execute( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 556d78ae..96d2f387 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -15,6 +15,7 @@ from datetime import datetime from datetime import timezone import enum +import sqlite3 from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event @@ -60,6 +61,46 @@ async def session_service(request, tmp_path): await service.close() +@pytest.mark.asyncio +async def test_sqlite_session_service_accepts_sqlite_urls( + tmp_path, monkeypatch +): + monkeypatch.chdir(tmp_path) + + service = SqliteSessionService('sqlite+aiosqlite:///./sessions.db') + await service.create_session(app_name='app', user_id='user') + assert (tmp_path / 'sessions.db').exists() + + service = SqliteSessionService('sqlite:///./sessions2.db') + await service.create_session(app_name='app', user_id='user') + assert (tmp_path / 'sessions2.db').exists() + + +@pytest.mark.asyncio +async def test_sqlite_session_service_preserves_uri_query_parameters( + tmp_path, monkeypatch +): + monkeypatch.chdir(tmp_path) + db_path = tmp_path / 'readonly.db' + with sqlite3.connect(db_path) as conn: + conn.execute('CREATE TABLE IF NOT EXISTS t (id INTEGER)') + conn.commit() + + service = SqliteSessionService(f'sqlite+aiosqlite:///{db_path}?mode=ro') + # `mode=ro` opens the DB read-only; schema creation should fail. + with pytest.raises(sqlite3.OperationalError, match=r'readonly'): + await service.create_session(app_name='app', user_id='user') + + +@pytest.mark.asyncio +async def test_sqlite_session_service_accepts_absolute_sqlite_urls(tmp_path): + abs_db_path = tmp_path / 'absolute.db' + abs_url = 'sqlite+aiosqlite:////' + str(abs_db_path).lstrip('/') + service = SqliteSessionService(abs_url) + await service.create_session(app_name='app', user_id='user') + assert abs_db_path.exists() + + @pytest.mark.asyncio async def test_get_empty_session(session_service): assert not await session_service.get_session(