You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add database schema migration command and script
Final part of https://github.com/google/adk-python/discussions/3605. This change introduces: - A new `adk migrate session` CLI command to run database schema upgrades. - A migration script to upgrade from the old Pickle-based session schema (v0) to the new JSON-based schema (v1). - A migration runner that orchestrates the upgrade process, handling sequential migrations and using temporary SQLite databases for intermediate steps if needed. - Unit tests for the v0 to v1 migration. Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 852983323
This commit is contained in:
committed by
Copybara-Service
parent
0827d12ccd
commit
ce64787c3e
@@ -36,6 +36,7 @@ from . import cli_create
|
||||
from . import cli_deploy
|
||||
from .. import version
|
||||
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
from ..sessions.migration import migration_runner
|
||||
from .cli import run_cli
|
||||
from .fast_api import get_fast_api_app
|
||||
from .utils import envs
|
||||
@@ -1507,6 +1508,47 @@ def cli_deploy_cloud_run(
|
||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||
|
||||
|
||||
@main.group()
def migrate():
  """ADK migration commands."""
  # Intentionally empty: this function only anchors the `migrate` command
  # group; subcommands register themselves with `@migrate.command(...)`.
|
||||
|
||||
|
||||
@migrate.command("session", cls=HelpfulCommand)
@click.option(
    "--source_db_url",
    required=True,
    help=(
        "SQLAlchemy URL of source database in database session service, e.g."
        " sqlite:///source.db."
    ),
)
@click.option(
    "--dest_db_url",
    required=True,
    help=(
        "SQLAlchemy URL of destination database in database session service,"
        " e.g. sqlite:///dest.db."
    ),
)
@click.option(
    "--log_level",
    type=LOG_LEVELS,
    default="INFO",
    help="Optional. Set the logging level",
)
def cli_migrate_session(
    *, source_db_url: str, dest_db_url: str, log_level: str
):
  """Migrates a session database to the latest schema version."""
  # NOTE: the docstring above is kept short on purpose; click shows it as the
  # command's help text.
  #
  # Args:
  #   source_db_url: SQLAlchemy URL of the database to read from.
  #   dest_db_url: SQLAlchemy URL of the database to write to.
  #   log_level: Name of a `logging` level (e.g. "INFO"); resolved with
  #     `getattr(logging, ...)` below.
  logs.setup_adk_logger(getattr(logging, log_level.upper()))
  try:
    migration_runner.upgrade(source_db_url, dest_db_url)
  except Exception as e:
    click.secho(f"Migration failed: {e}", fg="red", err=True)
    # Exit with a non-zero status so shells, scripts and CI can detect the
    # failure instead of seeing a "successful" run that only printed an error.
    raise click.exceptions.Exit(1) from e
  # Only reached when the upgrade did not raise.
  click.secho("Migration check and upgrade process finished.", fg="green")
|
||||
|
||||
|
||||
@deploy.command("agent_engine")
|
||||
@click.option(
|
||||
"--api_key",
|
||||
|
||||
@@ -178,12 +178,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
self._db_schema_version = await conn.run_sync(
|
||||
_schema_check_utils.get_db_schema_version_from_connection
|
||||
)
|
||||
except Exception:
|
||||
# If inspection fails, assume the latest schema
|
||||
logger.warning(
|
||||
"Failed to inspect database tables, assuming the latest schema."
|
||||
)
|
||||
self._db_schema_version = _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
except Exception as e:
|
||||
logger.error("Failed to inspect database tables: %s", e)
|
||||
raise
|
||||
|
||||
# Check if tables are created and create them if not
|
||||
if self._tables_created:
|
||||
|
||||
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import create_engine as create_sync_engine
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
|
||||
@@ -38,14 +39,16 @@ def _get_schema_version_impl(inspector, connection) -> str:
|
||||
if result:
|
||||
return result[0]
|
||||
else:
|
||||
return LATEST_SCHEMA_VERSION
|
||||
raise ValueError(
|
||||
"Schema version not found in adk_internal_metadata. The database"
|
||||
" might be malformed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to query schema version from adk_internal_metadata,"
|
||||
" assuming the latest schema: %s.",
|
||||
logger.error(
|
||||
"Failed to query schema version from adk_internal_metadata: %s.",
|
||||
e,
|
||||
)
|
||||
return LATEST_SCHEMA_VERSION
|
||||
raise
|
||||
# Metadata table doesn't exist, check for v0 schema.
|
||||
# V0 schema has an 'events' table with an 'actions' column.
|
||||
if inspector.has_table("events"):
|
||||
@@ -57,13 +60,14 @@ def _get_schema_version_impl(inspector, connection) -> str:
|
||||
" serialize event actions. The v0 schema will not be supported"
|
||||
" going forward and will be deprecated in a few rollouts. Please"
|
||||
" migrate to the v1 schema which uses JSON serialization for event"
|
||||
" data. The migration command and script will be provided soon."
|
||||
" data. You can use `adk migrate session` command to migrate your"
|
||||
" database."
|
||||
)
|
||||
return SCHEMA_VERSION_0_PICKLE
|
||||
except Exception as e:
|
||||
logger.warning("Failed to inspect 'events' table columns: %s", e)
|
||||
return LATEST_SCHEMA_VERSION
|
||||
# New database, assume the latest schema.
|
||||
logger.error("Failed to inspect 'events' table columns: %s", e)
|
||||
raise
|
||||
# New database, use the latest schema.
|
||||
return LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
@@ -71,3 +75,42 @@ def get_db_schema_version_from_connection(connection) -> str:
|
||||
"""Gets DB schema version from a DB connection."""
|
||||
inspector = inspect(connection)
|
||||
return _get_schema_version_impl(inspector, connection)
|
||||
|
||||
|
||||
def _to_sync_url(db_url: str) -> str:
  """Strips the '+driver' suffix from the scheme of a SQLAlchemy URL.

  Args:
    db_url: A database URL such as 'dialect+driver://rest'.

  Returns:
    'dialect://rest' when the scheme contains a '+'; otherwise `db_url`
    unchanged (including when it contains no '://' at all).
  """
  scheme, separator, remainder = db_url.partition("://")
  if not separator:
    # Not in 'scheme://rest' form; nothing to rewrite.
    return db_url

  dialect, plus, _driver = scheme.partition("+")
  if not plus:
    # Scheme carries no explicit driver.
    return db_url

  return f"{dialect}://{remainder}"
|
||||
|
||||
|
||||
def get_db_schema_version(db_url: str) -> str:
  """Detects the schema version of the database behind `db_url`.

  A synchronous engine is created for the URL (with any '+driver' suffix
  removed by `_to_sync_url`), a connection is opened, and detection is
  delegated to `_get_schema_version_impl`.

  Args:
    db_url: The database URL.

  Returns:
    The detected schema version as a string, as reported by
    `_get_schema_version_impl`.

  Raises:
    Exception: Any error raised while creating the engine, connecting or
      inspecting is logged as a warning and re-raised unchanged.
  """
  sync_engine = None
  try:
    sync_engine = create_sync_engine(_to_sync_url(db_url))
    with sync_engine.connect() as conn:
      return _get_schema_version_impl(inspect(conn), conn)
  except Exception:
    logger.warning(
        "Failed to get schema version from database %s.",
        db_url,
    )
    raise
  finally:
    # The engine exists only for this check; release its connections.
    if sync_engine:
      sync_engine.dispose()
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.sessions import _session_util
|
||||
from google.adk.sessions.migration import _schema_check_utils
|
||||
from google.adk.sessions.schemas import v1
|
||||
from google.genai import types
|
||||
import sqlalchemy
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
def _to_datetime_obj(val: Any) -> datetime | Any:
  """Converts a string to a `datetime` when it looks like a timestamp.

  The two space-separated formats the original implementation accepted are
  tried first. If neither matches, `datetime.fromisoformat` is used as a
  fallback so ISO 8601 variants (for example a 'T' separator or a UTC offset)
  are converted too instead of being passed on as plain strings.

  Args:
    val: Any value. Only `str` values are inspected.

  Returns:
    A `datetime` if `val` is a string in a recognised format; otherwise `val`
    itself, unchanged.
  """
  if not isinstance(val, str):
    return val

  for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"):
    try:
      return datetime.strptime(val, fmt)
    except ValueError:
      continue

  try:
    return datetime.fromisoformat(val)
  except ValueError:
    # Not a timestamp we understand: return as is, as before.
    return val
|
||||
|
||||
|
||||
def _row_to_event(row: dict) -> Event:
  """Converts an event row (dict) to an `Event`, tolerating missing columns.

  The `actions` column is unpickled when it holds bytes, and the JSON columns
  are decoded. Values that cannot be decoded are logged and treated as
  absent rather than failing the whole row.

  Args:
    row: Mapping of column name to raw column value for one event row.

  Returns:
    The reconstructed `Event`.

  Raises:
    ValueError: If the row has no `id`, no `timestamp`, or a timestamp that
      could not be converted to a `datetime`.
  """
  actions = None
  actions_val = row.get("actions")
  if isinstance(actions_val, bytes):
    try:
      # NOTE(security): pickle.loads can execute arbitrary code embedded in
      # the payload. Only run this migration against a database you trust.
      actions = pickle.loads(actions_val)
    except Exception as e:
      logger.warning(
          f"Failed to unpickle actions for event {row.get('id')}: {e}"
      )
  elif actions_val is not None:
    # for spanner - it might return object directly
    actions = actions_val

  if actions and hasattr(actions, "model_dump"):
    # Rebuild from a plain dict; `model_validate` is a classmethod, so call it
    # on the class instead of constructing a throwaway instance first.
    actions = EventActions.model_validate(actions.model_dump())
  elif isinstance(actions, dict):
    actions = EventActions(**actions)
  else:
    actions = EventActions()

  def _safe_json_load(val):
    """Returns decoded JSON for str input, dicts as is, otherwise None."""
    if isinstance(val, dict):
      return val  # for postgres JSONB
    if isinstance(val, str):
      try:
        return json.loads(val)
      except json.JSONDecodeError:
        logger.warning(f"Failed to decode JSON for event {row.get('id')}")
    return None

  content_dict = _safe_json_load(row.get("content"))
  grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
  custom_metadata_dict = _safe_json_load(row.get("custom_metadata"))
  usage_metadata_dict = _safe_json_load(row.get("usage_metadata"))
  citation_metadata_dict = _safe_json_load(row.get("citation_metadata"))
  input_transcription_dict = _safe_json_load(row.get("input_transcription"))
  output_transcription_dict = _safe_json_load(row.get("output_transcription"))

  long_running_tool_ids = set()
  long_running_tool_ids_json = row.get("long_running_tool_ids_json")
  if long_running_tool_ids_json:
    try:
      long_running_tool_ids = set(json.loads(long_running_tool_ids_json))
    except (json.JSONDecodeError, TypeError):
      # TypeError covers a non-string column value or JSON that decodes to a
      # non-iterable; neither should cause the whole event to be dropped.
      logger.warning(
          "Failed to decode long_running_tool_ids_json for event"
          f" {row.get('id')}"
      )
      long_running_tool_ids = set()

  event_id = row.get("id")
  if not event_id:
    raise ValueError("Event must have an id.")
  timestamp = _to_datetime_obj(row.get("timestamp"))
  if not timestamp:
    raise ValueError(f"Event {event_id} must have a timestamp.")
  if not isinstance(timestamp, datetime):
    # Without this check the `.replace(tzinfo=...)` call below fails with an
    # unrelated-looking error when the value is e.g. an unparsed string.
    raise ValueError(
        f"Event {event_id} has an unsupported timestamp value: {timestamp!r}."
    )

  return Event(
      id=event_id,
      invocation_id=row.get("invocation_id", ""),
      author=row.get("author", "agent"),
      branch=row.get("branch"),
      actions=actions,
      # The stored value is interpreted as UTC.
      timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(),
      long_running_tool_ids=long_running_tool_ids,
      partial=row.get("partial"),
      turn_complete=row.get("turn_complete"),
      error_code=row.get("error_code"),
      error_message=row.get("error_message"),
      interrupted=row.get("interrupted"),
      custom_metadata=custom_metadata_dict,
      content=_session_util.decode_model(content_dict, types.Content),
      grounding_metadata=_session_util.decode_model(
          grounding_metadata_dict, types.GroundingMetadata
      ),
      usage_metadata=_session_util.decode_model(
          usage_metadata_dict, types.GenerateContentResponseUsageMetadata
      ),
      citation_metadata=_session_util.decode_model(
          citation_metadata_dict, types.CitationMetadata
      ),
      input_transcription=_session_util.decode_model(
          input_transcription_dict, types.Transcription
      ),
      output_transcription=_session_util.decode_model(
          output_transcription_dict, types.Transcription
      ),
  )
|
||||
|
||||
|
||||
def _get_state_dict(state_val: Any) -> dict:
  """Safely loads a state dict from a JSON string or an existing dict.

  Args:
    state_val: The raw `state` column value: a dict, a JSON string, or
      anything else (e.g. None).

  Returns:
    `state_val` itself when it is already a dict; the decoded object when it
    is a JSON string encoding an object; otherwise an empty dict.
  """
  if isinstance(state_val, dict):
    return state_val
  if not isinstance(state_val, str):
    return {}

  try:
    parsed = json.loads(state_val)
  except json.JSONDecodeError:
    logger.warning(
        "Failed to parse state JSON string, defaulting to empty dict."
    )
    return {}

  if not isinstance(parsed, dict):
    # Valid JSON that is not an object (e.g. "null" or "[1]") would break the
    # `-> dict` contract callers rely on.
    logger.warning(
        "State JSON decoded to %s instead of an object, defaulting to empty"
        " dict.",
        type(parsed).__name__,
    )
    return {}
  return parsed
|
||||
|
||||
|
||||
# --- Migration Logic ---
|
||||
def migrate(source_db_url: str, dest_db_url: str):
  """Migrates data from old pickle schema to new JSON schema.

  Creates the v1 tables in the destination, writes the schema-version
  metadata row, then copies `app_states`, `user_states`, `sessions` and
  `events` from the source. Source tables that do not exist are skipped.
  Event rows that fail to convert are logged and skipped (best effort); all
  other errors roll back the destination session and abort the migration.

  Args:
    source_db_url: SQLAlchemy URL of the database to read from.
    dest_db_url: SQLAlchemy URL of the database to write to.

  Raises:
    RuntimeError: If connecting to either database fails or an error occurs
      while migrating.
  """
  source_engine = None
  dest_engine = None
  try:
    logger.info(f"Connecting to source database: {source_db_url}")
    try:
      source_engine = create_engine(source_db_url)
      SourceSession = sessionmaker(bind=source_engine)
    except Exception as e:
      logger.error(f"Failed to connect to source database: {e}")
      raise RuntimeError(f"Failed to connect to source database: {e}") from e

    logger.info(f"Connecting to destination database: {dest_db_url}")
    try:
      dest_engine = create_engine(dest_db_url)
      v1.Base.metadata.create_all(dest_engine)
      DestSession = sessionmaker(bind=dest_engine)
    except Exception as e:
      logger.error(f"Failed to connect to destination database: {e}")
      raise RuntimeError(
          f"Failed to connect to destination database: {e}"
      ) from e

    with SourceSession() as source_session, DestSession() as dest_session:
      try:
        dest_session.merge(
            v1.StorageMetadata(
                key=_schema_check_utils.SCHEMA_VERSION_KEY,
                value=_schema_check_utils.SCHEMA_VERSION_1_JSON,
            )
        )
        logger.info("Created metadata table in destination database.")

        inspector = sqlalchemy.inspect(source_engine)

        logger.info("Migrating app_states...")
        if inspector.has_table("app_states"):
          num_rows = 0
          for row in source_session.execute(
              text("SELECT * FROM app_states")
          ).mappings():
            num_rows += 1
            dest_session.merge(
                v1.StorageAppState(
                    app_name=row["app_name"],
                    state=_get_state_dict(row.get("state")),
                    update_time=_to_datetime_obj(row["update_time"]),
                )
            )
          logger.info(f"Migrated {num_rows} app_states.")
        else:
          logger.info("No 'app_states' table found in source db.")

        logger.info("Migrating user_states...")
        if inspector.has_table("user_states"):
          num_rows = 0
          for row in source_session.execute(
              text("SELECT * FROM user_states")
          ).mappings():
            num_rows += 1
            dest_session.merge(
                v1.StorageUserState(
                    app_name=row["app_name"],
                    user_id=row["user_id"],
                    state=_get_state_dict(row.get("state")),
                    update_time=_to_datetime_obj(row["update_time"]),
                )
            )
          logger.info(f"Migrated {num_rows} user_states.")
        else:
          logger.info("No 'user_states' table found in source db.")

        logger.info("Migrating sessions...")
        if inspector.has_table("sessions"):
          num_rows = 0
          for row in source_session.execute(
              text("SELECT * FROM sessions")
          ).mappings():
            num_rows += 1
            dest_session.merge(
                v1.StorageSession(
                    app_name=row["app_name"],
                    user_id=row["user_id"],
                    id=row["id"],
                    state=_get_state_dict(row.get("state")),
                    create_time=_to_datetime_obj(row["create_time"]),
                    update_time=_to_datetime_obj(row["update_time"]),
                )
            )
          logger.info(f"Migrated {num_rows} sessions.")
        else:
          logger.info("No 'sessions' table found in source db.")

        logger.info("Migrating events...")
        num_rows = 0
        num_skipped = 0
        if inspector.has_table("events"):
          for row in source_session.execute(
              text("SELECT * FROM events")
          ).mappings():
            try:
              event_obj = _row_to_event(dict(row))
              new_event = v1.StorageEvent(
                  id=event_obj.id,
                  app_name=row["app_name"],
                  user_id=row["user_id"],
                  session_id=row["session_id"],
                  invocation_id=event_obj.invocation_id,
                  # Store as a naive datetime expressed in UTC.
                  timestamp=datetime.fromtimestamp(
                      event_obj.timestamp, timezone.utc
                  ).replace(tzinfo=None),
                  event_data=event_obj.model_dump(
                      mode="json", exclude_none=True
                  ),
              )
              dest_session.merge(new_event)
              num_rows += 1
            except Exception as e:
              # Best effort: one bad event must not abort the migration, but
              # count it so the loss is visible in the summary below.
              num_skipped += 1
              logger.warning(
                  f"Failed to migrate event row {row.get('id', 'N/A')}: {e}"
              )
          logger.info(f"Migrated {num_rows} events.")
          if num_skipped:
            logger.warning(
                "Skipped %d event row(s) that could not be migrated.",
                num_skipped,
            )
        else:
          logger.info("No 'events' table found in source database.")

        dest_session.commit()
        logger.info("Migration completed successfully.")
      except Exception as e:
        logger.error(f"An error occurred during migration: {e}", exc_info=True)
        dest_session.rollback()
        raise RuntimeError(f"An error occurred during migration: {e}") from e
  finally:
    # Release pooled connections (and, for SQLite, file handles) on both the
    # success and the failure path.
    for engine in (source_engine, dest_engine):
      if engine is not None:
        engine.dispose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Script entry point: parse the two database URLs and run the migration,
  # exiting with status 1 if it fails.
  cli_parser = argparse.ArgumentParser(
      description=(
          "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format."
      )
  )
  cli_parser.add_argument(
      "--source_db_url",
      required=True,
      help="SQLAlchemy URL of source database",
  )
  cli_parser.add_argument(
      "--dest_db_url",
      required=True,
      help="SQLAlchemy URL of destination database",
  )
  cli_args = cli_parser.parse_args()

  try:
    migrate(cli_args.source_db_url, cli_args.dest_db_url)
  except Exception as e:
    logger.error(f"Migration failed: {e}")
    sys.exit(1)
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Migration runner to upgrade schemas to the latest version."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from google.adk.sessions.migration import _schema_check_utils
|
||||
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
# Migration map, keyed by the schema version a database starts at.
# Each value is a tuple of
# (the schema version AFTER this migration step, the migration function to run).
# The migration function must accept (source_db_url, dest_db_url) as
# arguments.
|
||||
MIGRATIONS = {
|
||||
_schema_check_utils.SCHEMA_VERSION_0_PICKLE: (
|
||||
_schema_check_utils.SCHEMA_VERSION_1_JSON,
|
||||
migrate_from_sqlalchemy_pickle.migrate,
|
||||
),
|
||||
}
|
||||
# The most recent schema version. The migration process stops once this version
|
||||
# is reached.
|
||||
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def upgrade(source_db_url: str, dest_db_url: str):
  """Migrates a database from its current version to the latest version.

  If the source database schema is older than the latest version, this
  function applies migration scripts sequentially until the schema reaches the
  LATEST_VERSION. The full path is validated before any step runs, so a
  migration map that cannot reach LATEST_VERSION never produces a partially
  upgraded destination.

  If multiple migration steps are required, intermediate results are stored in
  temporary SQLite database files. This means a multi-step migration
  between other database types (e.g. PostgreSQL to PostgreSQL) will use
  SQLite for intermediate steps.

  In-place migration (source_db_url == dest_db_url) is not supported,
  as migrations always read from a source and write to a destination.

  Args:
    source_db_url: The SQLAlchemy URL of the database to migrate from.
    dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
      different from source_db_url.

  Raises:
    RuntimeError: If source_db_url and dest_db_url are the same, or if no
      migration path from the current version to LATEST_VERSION is found.
  """
  if source_db_url == dest_db_url:
    raise RuntimeError(
        "In-place migration is not supported. "
        "Please provide a different URL for dest_db_url."
    )

  current_version = _schema_check_utils.get_db_schema_version(source_db_url)
  if current_version == LATEST_VERSION:
    logger.info(
        f"Database {source_db_url} is already at latest version"
        f" {LATEST_VERSION}. No migration needed."
    )
    return

  # Build the list of migration steps required to reach LATEST_VERSION.
  migrations_to_run = []
  visited = set()
  ver = current_version
  # `visited` guards against a cyclic MIGRATIONS map looping forever.
  while ver != LATEST_VERSION and ver in MIGRATIONS and ver not in visited:
    visited.add(ver)
    step = MIGRATIONS[ver]
    migrations_to_run.append(step)
    ver = step[0]

  # Covers both "no step available at all" and "chain stops short of latest".
  if ver != LATEST_VERSION:
    raise RuntimeError(
        "Could not find migration path for schema version"
        f" {current_version} to {LATEST_VERSION}."
    )

  temp_files = []
  in_url = source_db_url
  last_index = len(migrations_to_run) - 1
  try:
    for i, (end_version, migrate_func) in enumerate(migrations_to_run):
      if i == last_index:
        out_url = dest_db_url
      else:
        # For intermediate steps, create a temporary SQLite DB to store the
        # result.
        fd, temp_path = tempfile.mkstemp(suffix=".db")
        # Only the path is needed; the migration opens the file itself.
        os.close(fd)
        out_url = f"sqlite:///{temp_path}"
        temp_files.append(temp_path)
        logger.debug("Created temp db %s for step %d", out_url, i + 1)

      logger.info(
          f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
      )
      migrate_func(in_url, out_url)
      logger.info("Finished migration step to schema %s.", end_version)
      # The output of this step becomes the input for the next step.
      in_url = out_url
  finally:
    # Ensure temporary files are cleaned up even if migration fails.
    for path in temp_files:
      try:
        os.remove(path)
        logger.debug("Removed temp db %s", path)
      except OSError as e:
        logger.warning("Failed to remove temp db file %s: %s", path, e)
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for migration scripts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.sessions.migration import _schema_check_utils
|
||||
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
|
||||
from google.adk.sessions.schemas import v0
|
||||
from google.adk.sessions.schemas import v1
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
def test_migrate_from_sqlalchemy_pickle(tmp_path):
  """Checks that a v0 (pickle) SQLite database migrates to the v1 schema.

  Seeds one app state, user state, session and event in a v0 database, runs
  `mfsp.migrate`, then verifies the metadata row and each migrated record in
  the destination database.
  """
  source_db_url = f"sqlite:///{tmp_path / 'source_pickle.db'}"
  dest_db_url = f"sqlite:///{tmp_path / 'dest_json.db'}"

  # Set up source DB with old pickle schema
  source_engine = create_engine(source_db_url)
  v0.Base.metadata.create_all(source_engine)
  source_session = sessionmaker(bind=source_engine)()

  # Populate source data
  now = datetime.now(timezone.utc)
  seed_rows = [
      v0.StorageAppState(app_name="app1", state={"akey": 1}, update_time=now),
      v0.StorageUserState(
          app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now
      ),
      v0.StorageSession(
          app_name="app1",
          user_id="user1",
          id="session1",
          state={"skey": 3},
          create_time=now,
          update_time=now,
      ),
      v0.StorageEvent(
          id="event1",
          app_name="app1",
          user_id="user1",
          session_id="session1",
          invocation_id="invoke1",
          author="user",
          actions=EventActions(state_delta={"skey": 4}),
          timestamp=now,
      ),
  ]
  source_session.add_all(seed_rows)
  source_session.commit()
  source_session.close()

  mfsp.migrate(source_db_url, dest_db_url)

  # Verify destination DB
  dest_session = sessionmaker(bind=create_engine(dest_db_url))()

  stored_metadata = dest_session.query(v1.StorageMetadata).first()
  assert stored_metadata is not None
  assert stored_metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
  assert stored_metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON

  stored_app_state = dest_session.query(v1.StorageAppState).first()
  assert stored_app_state is not None
  assert stored_app_state.app_name == "app1"
  assert stored_app_state.state == {"akey": 1}

  stored_user_state = dest_session.query(v1.StorageUserState).first()
  assert stored_user_state is not None
  assert stored_user_state.user_id == "user1"
  assert stored_user_state.state == {"ukey": 2}

  stored_session = dest_session.query(v1.StorageSession).first()
  assert stored_session is not None
  assert stored_session.id == "session1"
  assert stored_session.state == {"skey": 3}

  stored_event = dest_session.query(v1.StorageEvent).first()
  assert stored_event is not None
  assert stored_event.id == "event1"
  stored_actions = stored_event.event_data["actions"]
  assert "state_delta" in stored_actions
  assert stored_actions["state_delta"] == {"skey": 4}

  dest_session.close()
|
||||
Reference in New Issue
Block a user