You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Handle async driver URLs in migration tool
The migration tool uses synchronous SQLAlchemy engines but users often provide async driver URLs (e.g., postgresql+asyncpg://) since that's what ADK requires at runtime. This fix: - Makes `to_sync_url()` public in `_schema_check_utils.py` for reuse - Updates `migrate_from_sqlalchemy_pickle.py` to convert async URLs - Updates `migrate_from_sqlalchemy_sqlite.py` to convert async URLs - Adds comprehensive unit tests for `to_sync_url()` function - Adds integration test for migration with async driver URLs Fixes #4176 Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 858359061
This commit is contained in:
committed by
Copybara-Service
parent
3dd7e3f1b9
commit
4b29d15b3e
@@ -82,8 +82,28 @@ def get_db_schema_version_from_connection(connection) -> str:
|
||||
return _get_schema_version_impl(inspector, connection)
|
||||
|
||||
|
||||
def _to_sync_url(db_url: str) -> str:
|
||||
"""Removes '+driver' from SQLAlchemy URL."""
|
||||
def to_sync_url(db_url: str) -> str:
|
||||
"""Removes '+driver' from SQLAlchemy URL.
|
||||
|
||||
This is useful when you need to use a synchronous SQLAlchemy engine with
|
||||
a database URL that specifies an async driver (e.g., postgresql+asyncpg://
|
||||
or sqlite+aiosqlite://).
|
||||
|
||||
Args:
|
||||
db_url: The database URL, potentially with a driver specification.
|
||||
|
||||
Returns:
|
||||
The database URL with the driver specification removed (e.g.,
|
||||
'postgresql+asyncpg://host/db' becomes 'postgresql://host/db').
|
||||
|
||||
Examples:
|
||||
>>> to_sync_url('postgresql+asyncpg://localhost/mydb')
|
||||
'postgresql://localhost/mydb'
|
||||
>>> to_sync_url('sqlite+aiosqlite:///path/to/db.sqlite')
|
||||
'sqlite:///path/to/db.sqlite'
|
||||
>>> to_sync_url('mysql://localhost/mydb') # No driver, returns unchanged
|
||||
'mysql://localhost/mydb'
|
||||
"""
|
||||
if "://" in db_url:
|
||||
scheme, _, rest = db_url.partition("://")
|
||||
if "+" in scheme:
|
||||
@@ -106,7 +126,7 @@ def get_db_schema_version(db_url: str) -> str:
|
||||
"""
|
||||
engine = None
|
||||
try:
|
||||
engine = create_sync_engine(_to_sync_url(db_url))
|
||||
engine = create_sync_engine(to_sync_url(db_url))
|
||||
with engine.connect() as connection:
|
||||
inspector = inspect(connection)
|
||||
return _get_schema_version_impl(inspector, connection)
|
||||
|
||||
@@ -165,9 +165,15 @@ def _get_state_dict(state_val: Any) -> dict:
|
||||
# --- Migration Logic ---
|
||||
def migrate(source_db_url: str, dest_db_url: str):
|
||||
"""Migrates data from old pickle schema to new JSON schema."""
|
||||
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
|
||||
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
|
||||
# them automatically converted to 'postgresql://...' for migration.
|
||||
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
|
||||
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)
|
||||
|
||||
logger.info(f"Connecting to source database: {source_db_url}")
|
||||
try:
|
||||
source_engine = create_engine(source_db_url)
|
||||
source_engine = create_engine(source_sync_url)
|
||||
SourceSession = sessionmaker(bind=source_engine)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to source database: {e}")
|
||||
@@ -175,7 +181,7 @@ def migrate(source_db_url: str, dest_db_url: str):
|
||||
|
||||
logger.info(f"Connecting to destination database: {dest_db_url}")
|
||||
try:
|
||||
dest_engine = create_engine(dest_db_url)
|
||||
dest_engine = create_engine(dest_sync_url)
|
||||
v1.Base.metadata.create_all(dest_engine)
|
||||
DestSession = sessionmaker(bind=dest_engine)
|
||||
except Exception as e:
|
||||
|
||||
@@ -23,6 +23,7 @@ import sqlite3
|
||||
import sys
|
||||
|
||||
from google.adk.sessions import sqlite_session_service as sss
|
||||
from google.adk.sessions.migration import _schema_check_utils
|
||||
from google.adk.sessions.schemas import v0 as v0_schema
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@@ -32,9 +33,14 @@ logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
def migrate(source_db_url: str, dest_db_path: str):
|
||||
"""Migrates data from a SQLAlchemy-based SQLite DB to the new schema."""
|
||||
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
|
||||
# This allows users to provide URLs like 'sqlite+aiosqlite://...' and have
|
||||
# them automatically converted to 'sqlite://...' for migration.
|
||||
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
|
||||
|
||||
logger.info(f"Connecting to source database: {source_db_url}")
|
||||
try:
|
||||
engine = create_engine(source_db_url)
|
||||
engine = create_engine(source_sync_url)
|
||||
v0_schema.Base.metadata.create_all(
|
||||
engine
|
||||
) # Ensure tables exist for inspection
|
||||
|
||||
@@ -23,10 +23,88 @@ from google.adk.sessions.migration import _schema_check_utils
|
||||
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
|
||||
from google.adk.sessions.schemas import v0
|
||||
from google.adk.sessions.schemas import v1
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
class TestToSyncUrl:
|
||||
"""Tests for the to_sync_url function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_url,expected_url",
|
||||
[
|
||||
# PostgreSQL async drivers
|
||||
(
|
||||
"postgresql+asyncpg://localhost/mydb",
|
||||
"postgresql://localhost/mydb",
|
||||
),
|
||||
(
|
||||
"postgresql+asyncpg://user:pass@localhost:5432/mydb",
|
||||
"postgresql://user:pass@localhost:5432/mydb",
|
||||
),
|
||||
# PostgreSQL sync drivers (should still strip)
|
||||
(
|
||||
"postgresql+psycopg2://localhost/mydb",
|
||||
"postgresql://localhost/mydb",
|
||||
),
|
||||
# MySQL async drivers
|
||||
(
|
||||
"mysql+aiomysql://localhost/mydb",
|
||||
"mysql://localhost/mydb",
|
||||
),
|
||||
(
|
||||
"mysql+asyncmy://user:pass@localhost:3306/mydb",
|
||||
"mysql://user:pass@localhost:3306/mydb",
|
||||
),
|
||||
# SQLite async driver
|
||||
(
|
||||
"sqlite+aiosqlite:///path/to/db.sqlite",
|
||||
"sqlite:///path/to/db.sqlite",
|
||||
),
|
||||
(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
"sqlite:///:memory:",
|
||||
),
|
||||
# URLs without driver specification (unchanged)
|
||||
(
|
||||
"postgresql://localhost/mydb",
|
||||
"postgresql://localhost/mydb",
|
||||
),
|
||||
(
|
||||
"mysql://localhost/mydb",
|
||||
"mysql://localhost/mydb",
|
||||
),
|
||||
(
|
||||
"sqlite:///path/to/db.sqlite",
|
||||
"sqlite:///path/to/db.sqlite",
|
||||
),
|
||||
# Edge cases
|
||||
(
|
||||
"sqlite:///:memory:",
|
||||
"sqlite:///:memory:",
|
||||
),
|
||||
# Complex URL with query parameters
|
||||
(
|
||||
"postgresql+asyncpg://user:pass@host/db?ssl=require",
|
||||
"postgresql://user:pass@host/db?ssl=require",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_to_sync_url(self, input_url, expected_url):
|
||||
"""Test that async driver specifications are correctly removed."""
|
||||
assert _schema_check_utils.to_sync_url(input_url) == expected_url
|
||||
|
||||
def test_to_sync_url_no_scheme_separator(self):
|
||||
"""Test that URLs without :// are returned unchanged."""
|
||||
# This is an invalid URL but the function should handle it gracefully
|
||||
assert _schema_check_utils.to_sync_url("not-a-url") == "not-a-url"
|
||||
|
||||
def test_to_sync_url_empty_string(self):
|
||||
"""Test that empty string is returned unchanged."""
|
||||
assert _schema_check_utils.to_sync_url("") == ""
|
||||
|
||||
|
||||
def test_migrate_from_sqlalchemy_pickle(tmp_path):
|
||||
"""Tests for migrate_from_sqlalchemy_pickle."""
|
||||
source_db_path = tmp_path / "source_pickle.db"
|
||||
@@ -104,3 +182,66 @@ def test_migrate_from_sqlalchemy_pickle(tmp_path):
|
||||
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}
|
||||
|
||||
dest_session.close()
|
||||
|
||||
|
||||
def test_migrate_from_sqlalchemy_pickle_with_async_driver_urls(tmp_path):
|
||||
"""Tests that migration works with async driver URLs (fixes issue #4176).
|
||||
|
||||
Users often provide async driver URLs (e.g., postgresql+asyncpg://) since
|
||||
that's what ADK requires at runtime. The migration tool should handle these
|
||||
by automatically converting them to sync URLs.
|
||||
"""
|
||||
source_db_path = tmp_path / "source_pickle_async.db"
|
||||
dest_db_path = tmp_path / "dest_json_async.db"
|
||||
# Use async driver URLs like users would typically provide
|
||||
source_db_url = f"sqlite+aiosqlite:///{source_db_path}"
|
||||
dest_db_url = f"sqlite+aiosqlite:///{dest_db_path}"
|
||||
|
||||
# Set up source DB with old pickle schema using sync URL
|
||||
sync_source_url = f"sqlite:///{source_db_path}"
|
||||
source_engine = create_engine(sync_source_url)
|
||||
v0.Base.metadata.create_all(source_engine)
|
||||
SourceSession = sessionmaker(bind=source_engine)
|
||||
source_session = SourceSession()
|
||||
|
||||
# Populate source data
|
||||
now = datetime.now(timezone.utc)
|
||||
app_state = v0.StorageAppState(
|
||||
app_name="async_app", state={"key": "value"}, update_time=now
|
||||
)
|
||||
session = v0.StorageSession(
|
||||
app_name="async_app",
|
||||
user_id="async_user",
|
||||
id="async_session",
|
||||
state={},
|
||||
create_time=now,
|
||||
update_time=now,
|
||||
)
|
||||
source_session.add_all([app_state, session])
|
||||
source_session.commit()
|
||||
source_session.close()
|
||||
|
||||
# This should NOT raise an error about async drivers (the fix for #4176)
|
||||
mfsp.migrate(source_db_url, dest_db_url)
|
||||
|
||||
# Verify destination DB
|
||||
sync_dest_url = f"sqlite:///{dest_db_path}"
|
||||
dest_engine = create_engine(sync_dest_url)
|
||||
DestSession = sessionmaker(bind=dest_engine)
|
||||
dest_session = DestSession()
|
||||
|
||||
metadata = dest_session.query(v1.StorageMetadata).first()
|
||||
assert metadata is not None
|
||||
assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
|
||||
assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON
|
||||
|
||||
app_state_res = dest_session.query(v1.StorageAppState).first()
|
||||
assert app_state_res is not None
|
||||
assert app_state_res.app_name == "async_app"
|
||||
assert app_state_res.state == {"key": "value"}
|
||||
|
||||
session_res = dest_session.query(v1.StorageSession).first()
|
||||
assert session_res is not None
|
||||
assert session_res.id == "async_session"
|
||||
|
||||
dest_session.close()
|
||||
|
||||
Reference in New Issue
Block a user