You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Handle SQLite URLs in SqliteSessionService
The SqliteSessionService now accepts database paths in the form of SQLite URLs (e.g., "sqlite:///./sessions.db", "sqlite+aiosqlite:////absolute.db") Close #4077 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 853922433
This commit is contained in:
committed by
Copybara-Service
parent
3c51ee7f48
commit
b8917bc80e
@@ -22,6 +22,8 @@ import sqlite3
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
import uuid
|
||||
|
||||
import aiosqlite
|
||||
@@ -91,6 +93,42 @@ CREATE_SCHEMA_SQL = "\n".join([
|
||||
])
|
||||
|
||||
|
||||
def _parse_db_path(db_path: str) -> tuple[str, str, bool]:
|
||||
"""Normalizes a SQLite db path from a URL or filesystem path.
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- filesystem path (for `os.path.exists` and user-facing messages)
|
||||
- value to pass to sqlite/aiosqlite connect
|
||||
- whether to pass `uri=True` to sqlite/aiosqlite connect
|
||||
|
||||
Notes:
|
||||
When a SQLAlchemy-style SQLite URL is provided, this follows SQLAlchemy's
|
||||
conventions:
|
||||
- `sqlite:///relative.db` is a path relative to the current working dir.
|
||||
- `sqlite:////absolute.db` is an absolute filesystem path.
|
||||
"""
|
||||
if not db_path.startswith(("sqlite:", "sqlite+aiosqlite:")):
|
||||
return db_path, db_path, False
|
||||
|
||||
parsed = urlparse(db_path)
|
||||
raw_path = unquote(parsed.path)
|
||||
if not raw_path:
|
||||
return db_path, db_path, False
|
||||
|
||||
normalized_path = raw_path
|
||||
if normalized_path.startswith("//"):
|
||||
normalized_path = normalized_path[1:]
|
||||
elif normalized_path.startswith("/"):
|
||||
normalized_path = normalized_path[1:]
|
||||
|
||||
if parsed.query:
|
||||
# sqlite3 only treats the filename as a URI when it starts with `file:`.
|
||||
return normalized_path, f"file:{normalized_path}?{parsed.query}", True
|
||||
|
||||
return normalized_path, normalized_path, False
|
||||
|
||||
|
||||
class SqliteSessionService(BaseSessionService):
|
||||
"""A session service that uses an SQLite database for storage via aiosqlite.
|
||||
|
||||
@@ -100,17 +138,19 @@ class SqliteSessionService(BaseSessionService):
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""Initializes the SQLite session service with a database path."""
|
||||
self._db_path = db_path
|
||||
self._db_path, self._db_connect_path, self._db_connect_uri = _parse_db_path(
|
||||
db_path
|
||||
)
|
||||
|
||||
if self._is_migration_needed():
|
||||
raise RuntimeError(
|
||||
f"Database {db_path} seems to use an old schema."
|
||||
f"Database {self._db_path} seems to use an old schema."
|
||||
" Please run the migration command to"
|
||||
" migrate it to the new schema. Example: `python -m"
|
||||
" google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite"
|
||||
f" --source_db_path {db_path} --dest_db_path"
|
||||
f" {db_path}.new` then backup {db_path} and rename"
|
||||
f" {db_path}.new to {db_path}."
|
||||
f" --source_db_path {self._db_path} --dest_db_path"
|
||||
f" {self._db_path}.new` then backup {self._db_path} and rename"
|
||||
f" {self._db_path}.new to {self._db_path}."
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -415,7 +455,9 @@ class SqliteSessionService(BaseSessionService):
|
||||
@asynccontextmanager
|
||||
async def _get_db_connection(self):
|
||||
"""Connects to the db and performs initial setup."""
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
async with aiosqlite.connect(
|
||||
self._db_connect_path, uri=self._db_connect_uri
|
||||
) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute(PRAGMA_FOREIGN_KEYS)
|
||||
await db.executescript(CREATE_SCHEMA_SQL)
|
||||
@@ -514,7 +556,9 @@ class SqliteSessionService(BaseSessionService):
|
||||
if not os.path.exists(self._db_path):
|
||||
return False
|
||||
try:
|
||||
with sqlite3.connect(self._db_path) as conn:
|
||||
with sqlite3.connect(
|
||||
self._db_connect_path, uri=self._db_connect_uri
|
||||
) as conn:
|
||||
cursor = conn.cursor()
|
||||
# Check if events table exists
|
||||
cursor.execute(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import enum
|
||||
import sqlite3
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
@@ -60,6 +61,46 @@ async def session_service(request, tmp_path):
|
||||
await service.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sqlite_session_service_accepts_sqlite_urls(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
service = SqliteSessionService('sqlite+aiosqlite:///./sessions.db')
|
||||
await service.create_session(app_name='app', user_id='user')
|
||||
assert (tmp_path / 'sessions.db').exists()
|
||||
|
||||
service = SqliteSessionService('sqlite:///./sessions2.db')
|
||||
await service.create_session(app_name='app', user_id='user')
|
||||
assert (tmp_path / 'sessions2.db').exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sqlite_session_service_preserves_uri_query_parameters(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
db_path = tmp_path / 'readonly.db'
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute('CREATE TABLE IF NOT EXISTS t (id INTEGER)')
|
||||
conn.commit()
|
||||
|
||||
service = SqliteSessionService(f'sqlite+aiosqlite:///{db_path}?mode=ro')
|
||||
# `mode=ro` opens the DB read-only; schema creation should fail.
|
||||
with pytest.raises(sqlite3.OperationalError, match=r'readonly'):
|
||||
await service.create_session(app_name='app', user_id='user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sqlite_session_service_accepts_absolute_sqlite_urls(tmp_path):
|
||||
abs_db_path = tmp_path / 'absolute.db'
|
||||
abs_url = 'sqlite+aiosqlite:////' + str(abs_db_path).lstrip('/')
|
||||
service = SqliteSessionService(abs_url)
|
||||
await service.create_session(app_name='app', user_id='user')
|
||||
assert abs_db_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_empty_session(session_service):
|
||||
assert not await session_service.get_session(
|
||||
|
||||
Reference in New Issue
Block a user