fix: Handle async driver URLs in migration tool

The migration tool uses synchronous SQLAlchemy engines, but users often provide async driver URLs (e.g., postgresql+asyncpg://) because that is what ADK requires at runtime.

This fix:
- Makes `to_sync_url()` public in `_schema_check_utils.py` for reuse
- Updates `migrate_from_sqlalchemy_pickle.py` to convert async URLs
- Updates `migrate_from_sqlalchemy_sqlite.py` to convert async URLs
- Adds comprehensive unit tests for the `to_sync_url()` function
- Adds an integration test for migration with async driver URLs

Fixes #4176

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 858359061
This commit is contained in:
Liang Wu
2026-01-19 19:38:52 -08:00
committed by Copybara-Service
parent 3dd7e3f1b9
commit 4b29d15b3e
4 changed files with 179 additions and 6 deletions
@@ -82,8 +82,28 @@ def get_db_schema_version_from_connection(connection) -> str:
return _get_schema_version_impl(inspector, connection)
def _to_sync_url(db_url: str) -> str:
"""Removes '+driver' from SQLAlchemy URL."""
def to_sync_url(db_url: str) -> str:
"""Removes '+driver' from SQLAlchemy URL.
This is useful when you need to use a synchronous SQLAlchemy engine with
a database URL that specifies an async driver (e.g., postgresql+asyncpg://
or sqlite+aiosqlite://).
Args:
db_url: The database URL, potentially with a driver specification.
Returns:
The database URL with the driver specification removed (e.g.,
'postgresql+asyncpg://host/db' becomes 'postgresql://host/db').
Examples:
>>> to_sync_url('postgresql+asyncpg://localhost/mydb')
'postgresql://localhost/mydb'
>>> to_sync_url('sqlite+aiosqlite:///path/to/db.sqlite')
'sqlite:///path/to/db.sqlite'
>>> to_sync_url('mysql://localhost/mydb') # No driver, returns unchanged
'mysql://localhost/mydb'
"""
if "://" in db_url:
scheme, _, rest = db_url.partition("://")
if "+" in scheme:
@@ -106,7 +126,7 @@ def get_db_schema_version(db_url: str) -> str:
"""
engine = None
try:
engine = create_sync_engine(_to_sync_url(db_url))
engine = create_sync_engine(to_sync_url(db_url))
with engine.connect() as connection:
inspector = inspect(connection)
return _get_schema_version_impl(inspector, connection)
@@ -165,9 +165,15 @@ def _get_state_dict(state_val: Any) -> dict:
# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
"""Migrates data from old pickle schema to new JSON schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
# them automatically converted to 'postgresql://...' for migration.
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)
logger.info(f"Connecting to source database: {source_db_url}")
try:
source_engine = create_engine(source_db_url)
source_engine = create_engine(source_sync_url)
SourceSession = sessionmaker(bind=source_engine)
except Exception as e:
logger.error(f"Failed to connect to source database: {e}")
@@ -175,7 +181,7 @@ def migrate(source_db_url: str, dest_db_url: str):
logger.info(f"Connecting to destination database: {dest_db_url}")
try:
dest_engine = create_engine(dest_db_url)
dest_engine = create_engine(dest_sync_url)
v1.Base.metadata.create_all(dest_engine)
DestSession = sessionmaker(bind=dest_engine)
except Exception as e:
@@ -23,6 +23,7 @@ import sqlite3
import sys
from google.adk.sessions import sqlite_session_service as sss
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.schemas import v0 as v0_schema
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@@ -32,9 +33,14 @@ logger = logging.getLogger("google_adk." + __name__)
def migrate(source_db_url: str, dest_db_path: str):
"""Migrates data from a SQLAlchemy-based SQLite DB to the new schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'sqlite+aiosqlite://...' and have
# them automatically converted to 'sqlite://...' for migration.
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
logger.info(f"Connecting to source database: {source_db_url}")
try:
engine = create_engine(source_db_url)
engine = create_engine(source_sync_url)
v0_schema.Base.metadata.create_all(
engine
) # Ensure tables exist for inspection
@@ -23,10 +23,88 @@ from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
from google.adk.sessions.schemas import v0
from google.adk.sessions.schemas import v1
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
class TestToSyncUrl:
  """Unit tests covering `_schema_check_utils.to_sync_url`."""

  @pytest.mark.parametrize(
      ("input_url", "expected_url"),
      [
          # PostgreSQL async drivers
          (
              "postgresql+asyncpg://localhost/mydb",
              "postgresql://localhost/mydb",
          ),
          (
              "postgresql+asyncpg://user:pass@localhost:5432/mydb",
              "postgresql://user:pass@localhost:5432/mydb",
          ),
          # PostgreSQL sync drivers (the driver suffix is stripped as well)
          (
              "postgresql+psycopg2://localhost/mydb",
              "postgresql://localhost/mydb",
          ),
          # MySQL async drivers
          (
              "mysql+aiomysql://localhost/mydb",
              "mysql://localhost/mydb",
          ),
          (
              "mysql+asyncmy://user:pass@localhost:3306/mydb",
              "mysql://user:pass@localhost:3306/mydb",
          ),
          # SQLite async driver
          (
              "sqlite+aiosqlite:///path/to/db.sqlite",
              "sqlite:///path/to/db.sqlite",
          ),
          (
              "sqlite+aiosqlite:///:memory:",
              "sqlite:///:memory:",
          ),
          # URLs that carry no driver suffix must come back untouched
          (
              "postgresql://localhost/mydb",
              "postgresql://localhost/mydb",
          ),
          (
              "mysql://localhost/mydb",
              "mysql://localhost/mydb",
          ),
          (
              "sqlite:///path/to/db.sqlite",
              "sqlite:///path/to/db.sqlite",
          ),
          # Edge case: in-memory SQLite without a driver suffix
          (
              "sqlite:///:memory:",
              "sqlite:///:memory:",
          ),
          # Credentials plus query parameters must survive the conversion
          (
              "postgresql+asyncpg://user:pass@host/db?ssl=require",
              "postgresql://user:pass@host/db?ssl=require",
          ),
      ],
  )
  def test_to_sync_url(self, input_url, expected_url):
    """Checks each input URL against its expected driver-less form."""
    converted = _schema_check_utils.to_sync_url(input_url)
    assert converted == expected_url

  def test_to_sync_url_no_scheme_separator(self):
    """Checks that a string lacking '://' is passed through as-is."""
    # Not a valid URL, but the function is expected to tolerate it.
    malformed = "not-a-url"
    assert _schema_check_utils.to_sync_url(malformed) == "not-a-url"

  def test_to_sync_url_empty_string(self):
    """Checks that the empty string is passed through as-is."""
    converted = _schema_check_utils.to_sync_url("")
    assert converted == ""
def test_migrate_from_sqlalchemy_pickle(tmp_path):
"""Tests for migrate_from_sqlalchemy_pickle."""
source_db_path = tmp_path / "source_pickle.db"
@@ -104,3 +182,66 @@ def test_migrate_from_sqlalchemy_pickle(tmp_path):
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}
dest_session.close()
def test_migrate_from_sqlalchemy_pickle_with_async_driver_urls(tmp_path):
  """Tests that migration works with async driver URLs (fixes issue #4176).

  Users often provide async driver URLs (e.g., postgresql+asyncpg://) since
  that's what ADK requires at runtime. The migration tool should handle these
  by automatically converting them to sync URLs.

  The test seeds a source SQLite database with the v0 (pickle) schema through
  a synchronous engine, runs `mfsp.migrate` with `sqlite+aiosqlite://` URLs,
  and then inspects the destination database through a synchronous engine.

  Args:
    tmp_path: pytest-provided temporary directory holding both database files.
  """
  source_db_path = tmp_path / "source_pickle_async.db"
  dest_db_path = tmp_path / "dest_json_async.db"

  # Use async driver URLs like users would typically provide
  source_db_url = f"sqlite+aiosqlite:///{source_db_path}"
  dest_db_url = f"sqlite+aiosqlite:///{dest_db_path}"

  # Set up source DB with old pickle schema using sync URL
  sync_source_url = f"sqlite:///{source_db_path}"
  source_engine = create_engine(sync_source_url)
  try:
    v0.Base.metadata.create_all(source_engine)
    SourceSession = sessionmaker(bind=source_engine)

    # Populate source data
    now = datetime.now(timezone.utc)
    app_state = v0.StorageAppState(
        app_name="async_app", state={"key": "value"}, update_time=now
    )
    session = v0.StorageSession(
        app_name="async_app",
        user_id="async_user",
        id="async_session",
        state={},
        create_time=now,
        update_time=now,
    )
    # The context manager closes the session even if the commit raises.
    with SourceSession() as source_session:
      source_session.add_all([app_state, session])
      source_session.commit()
  finally:
    # Release the pooled SQLite connection so the file handle is not left
    # open while the migration (and tmp_path cleanup) runs.
    source_engine.dispose()

  # This should NOT raise an error about async drivers (the fix for #4176)
  mfsp.migrate(source_db_url, dest_db_url)

  # Verify destination DB
  sync_dest_url = f"sqlite:///{dest_db_path}"
  dest_engine = create_engine(sync_dest_url)
  try:
    DestSession = sessionmaker(bind=dest_engine)
    # The context manager closes the session even when an assertion fails.
    with DestSession() as dest_session:
      metadata = dest_session.query(v1.StorageMetadata).first()
      assert metadata is not None
      assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
      assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON

      app_state_res = dest_session.query(v1.StorageAppState).first()
      assert app_state_res is not None
      assert app_state_res.app_name == "async_app"
      assert app_state_res.state == {"key": "value"}

      session_res = dest_session.query(v1.StorageSession).first()
      assert session_res is not None
      assert session_res.id == "async_session"
  finally:
    dest_engine.dispose()