feat: Add database schema migration command and script

Final part of https://github.com/google/adk-python/discussions/3605.

This change introduces:
-   A new `adk migrate session` CLI command to run database schema upgrades.
-   A migration script to upgrade from the old Pickle-based session schema (v0) to the new JSON-based schema (v1).
-   A migration runner that orchestrates the upgrade process, handling sequential migrations and using temporary SQLite databases for intermediate steps if needed.
-   Unit tests for the v0 to v1 migration.

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 852983323
This commit is contained in:
Liang Wu
2026-01-06 16:35:30 -08:00
committed by Copybara-Service
parent 0827d12ccd
commit ce64787c3e
6 changed files with 640 additions and 15 deletions
+42
View File
@@ -36,6 +36,7 @@ from . import cli_create
from . import cli_deploy
from .. import version
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..sessions.migration import migration_runner
from .cli import run_cli
from .fast_api import get_fast_api_app
from .utils import envs
@@ -1507,6 +1508,47 @@ def cli_deploy_cloud_run(
click.secho(f"Deploy failed: {e}", fg="red", err=True)
@main.group()
def migrate():
"""ADK migration commands."""
pass
@migrate.command("session", cls=HelpfulCommand)
@click.option(
"--source_db_url",
required=True,
help=(
"SQLAlchemy URL of source database in database session service, e.g."
" sqlite:///source.db."
),
)
@click.option(
"--dest_db_url",
required=True,
help=(
"SQLAlchemy URL of destination database in database session service,"
" e.g. sqlite:///dest.db."
),
)
@click.option(
"--log_level",
type=LOG_LEVELS,
default="INFO",
help="Optional. Set the logging level",
)
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
migration_runner.upgrade(source_db_url, dest_db_url)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
@deploy.command("agent_engine")
@click.option(
"--api_key",
@@ -178,12 +178,9 @@ class DatabaseSessionService(BaseSessionService):
self._db_schema_version = await conn.run_sync(
_schema_check_utils.get_db_schema_version_from_connection
)
except Exception:
# If inspection fails, assume the latest schema
logger.warning(
"Failed to inspect database tables, assuming the latest schema."
)
self._db_schema_version = _schema_check_utils.LATEST_SCHEMA_VERSION
except Exception as e:
logger.error("Failed to inspect database tables: %s", e)
raise
# Check if tables are created and create them if not
if self._tables_created:
@@ -16,6 +16,7 @@ from __future__ import annotations
import logging
from sqlalchemy import create_engine as create_sync_engine
from sqlalchemy import inspect
from sqlalchemy import text
@@ -38,14 +39,16 @@ def _get_schema_version_impl(inspector, connection) -> str:
if result:
return result[0]
else:
return LATEST_SCHEMA_VERSION
raise ValueError(
"Schema version not found in adk_internal_metadata. The database"
" might be malformed."
)
except Exception as e:
logger.warning(
"Failed to query schema version from adk_internal_metadata,"
" assuming the latest schema: %s.",
logger.error(
"Failed to query schema version from adk_internal_metadata: %s.",
e,
)
return LATEST_SCHEMA_VERSION
raise
# Metadata table doesn't exist, check for v0 schema.
# V0 schema has an 'events' table with an 'actions' column.
if inspector.has_table("events"):
@@ -57,13 +60,14 @@ def _get_schema_version_impl(inspector, connection) -> str:
" serialize event actions. The v0 schema will not be supported"
" going forward and will be deprecated in a few rollouts. Please"
" migrate to the v1 schema which uses JSON serialization for event"
" data. The migration command and script will be provided soon."
" data. You can use `adk migrate session` command to migrate your"
" database."
)
return SCHEMA_VERSION_0_PICKLE
except Exception as e:
logger.warning("Failed to inspect 'events' table columns: %s", e)
return LATEST_SCHEMA_VERSION
# New database, assume the latest schema.
logger.error("Failed to inspect 'events' table columns: %s", e)
raise
# New database, use the latest schema.
return LATEST_SCHEMA_VERSION
@@ -71,3 +75,42 @@ def get_db_schema_version_from_connection(connection) -> str:
"""Gets DB schema version from a DB connection."""
inspector = inspect(connection)
return _get_schema_version_impl(inspector, connection)
def _to_sync_url(db_url: str) -> str:
"""Removes '+driver' from SQLAlchemy URL."""
if "://" in db_url:
scheme, _, rest = db_url.partition("://")
if "+" in scheme:
dialect = scheme.split("+", 1)[0]
return f"{dialect}://{rest}"
return db_url
def get_db_schema_version(db_url: str) -> str:
"""Reads schema version from DB.
Checks metadata table first and then falls back to table structure.
Args:
db_url: The database URL.
Returns:
The detected schema version as a string. Returns `LATEST_SCHEMA_VERSION`
if it's a new database.
"""
engine = None
try:
engine = create_sync_engine(_to_sync_url(db_url))
with engine.connect() as connection:
inspector = inspect(connection)
return _get_schema_version_impl(inspector, connection)
except Exception:
logger.warning(
"Failed to get schema version from database %s.",
db_url,
)
raise
finally:
if engine:
engine.dispose()
@@ -0,0 +1,311 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema."""
from __future__ import annotations
import argparse
from datetime import datetime
from datetime import timezone
import json
import logging
import pickle
import sys
from typing import Any
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions import _session_util
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.schemas import v1
from google.genai import types
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
logger = logging.getLogger("google_adk." + __name__)
def _to_datetime_obj(val: Any) -> datetime | Any:
"""Converts string to datetime if needed."""
if isinstance(val, str):
try:
return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f")
except ValueError:
try:
return datetime.strptime(val, "%Y-%m-%d %H:%M:%S")
except ValueError:
pass # return as is if not matching format
return val
def _row_to_event(row: dict) -> Event:
"""Converts event row (dict) to event object, handling missing columns and deserializing."""
actions_val = row.get("actions")
actions = None
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
logger.warning(
f"Failed to unpickle actions for event {row.get('id')}: {e}"
)
actions = None
if actions and hasattr(actions, "model_dump"):
actions = EventActions().model_validate(actions.model_dump())
elif isinstance(actions, dict):
actions = EventActions(**actions)
else:
actions = EventActions()
def _safe_json_load(val):
data = None
if isinstance(val, str):
try:
data = json.loads(val)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
return None
elif isinstance(val, dict):
data = val # for postgres JSONB
return data
content_dict = _safe_json_load(row.get("content"))
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
custom_metadata_dict = _safe_json_load(row.get("custom_metadata"))
usage_metadata_dict = _safe_json_load(row.get("usage_metadata"))
citation_metadata_dict = _safe_json_load(row.get("citation_metadata"))
input_transcription_dict = _safe_json_load(row.get("input_transcription"))
output_transcription_dict = _safe_json_load(row.get("output_transcription"))
long_running_tool_ids_json = row.get("long_running_tool_ids_json")
long_running_tool_ids = set()
if long_running_tool_ids_json:
try:
long_running_tool_ids = set(json.loads(long_running_tool_ids_json))
except json.JSONDecodeError:
logger.warning(
"Failed to decode long_running_tool_ids_json for event"
f" {row.get('id')}"
)
long_running_tool_ids = set()
event_id = row.get("id")
if not event_id:
raise ValueError("Event must have an id.")
timestamp = _to_datetime_obj(row.get("timestamp"))
if not timestamp:
raise ValueError(f"Event {event_id} must have a timestamp.")
return Event(
id=event_id,
invocation_id=row.get("invocation_id", ""),
author=row.get("author", "agent"),
branch=row.get("branch"),
actions=actions,
timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(),
long_running_tool_ids=long_running_tool_ids,
partial=row.get("partial"),
turn_complete=row.get("turn_complete"),
error_code=row.get("error_code"),
error_message=row.get("error_message"),
interrupted=row.get("interrupted"),
custom_metadata=custom_metadata_dict,
content=_session_util.decode_model(content_dict, types.Content),
grounding_metadata=_session_util.decode_model(
grounding_metadata_dict, types.GroundingMetadata
),
usage_metadata=_session_util.decode_model(
usage_metadata_dict, types.GenerateContentResponseUsageMetadata
),
citation_metadata=_session_util.decode_model(
citation_metadata_dict, types.CitationMetadata
),
input_transcription=_session_util.decode_model(
input_transcription_dict, types.Transcription
),
output_transcription=_session_util.decode_model(
output_transcription_dict, types.Transcription
),
)
def _get_state_dict(state_val: Any) -> dict:
"""Safely load dict from JSON string or return dict if already dict."""
if isinstance(state_val, dict):
return state_val
if isinstance(state_val, str):
try:
return json.loads(state_val)
except json.JSONDecodeError:
logger.warning(
"Failed to parse state JSON string, defaulting to empty dict."
)
return {}
return {}
# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
"""Migrates data from old pickle schema to new JSON schema."""
logger.info(f"Connecting to source database: {source_db_url}")
try:
source_engine = create_engine(source_db_url)
SourceSession = sessionmaker(bind=source_engine)
except Exception as e:
logger.error(f"Failed to connect to source database: {e}")
raise RuntimeError(f"Failed to connect to source database: {e}") from e
logger.info(f"Connecting to destination database: {dest_db_url}")
try:
dest_engine = create_engine(dest_db_url)
v1.Base.metadata.create_all(dest_engine)
DestSession = sessionmaker(bind=dest_engine)
except Exception as e:
logger.error(f"Failed to connect to destination database: {e}")
raise RuntimeError(f"Failed to connect to destination database: {e}") from e
with SourceSession() as source_session, DestSession() as dest_session:
try:
dest_session.merge(
v1.StorageMetadata(
key=_schema_check_utils.SCHEMA_VERSION_KEY,
value=_schema_check_utils.SCHEMA_VERSION_1_JSON,
)
)
logger.info("Created metadata table in destination database.")
inspector = sqlalchemy.inspect(source_engine)
logger.info("Migrating app_states...")
if inspector.has_table("app_states"):
num_rows = 0
for row in source_session.execute(
text("SELECT * FROM app_states")
).mappings():
num_rows += 1
dest_session.merge(
v1.StorageAppState(
app_name=row["app_name"],
state=_get_state_dict(row.get("state")),
update_time=_to_datetime_obj(row["update_time"]),
)
)
logger.info(f"Migrated {num_rows} app_states.")
else:
logger.info("No 'app_states' table found in source db.")
logger.info("Migrating user_states...")
if inspector.has_table("user_states"):
num_rows = 0
for row in source_session.execute(
text("SELECT * FROM user_states")
).mappings():
num_rows += 1
dest_session.merge(
v1.StorageUserState(
app_name=row["app_name"],
user_id=row["user_id"],
state=_get_state_dict(row.get("state")),
update_time=_to_datetime_obj(row["update_time"]),
)
)
logger.info(f"Migrated {num_rows} user_states.")
else:
logger.info("No 'user_states' table found in source db.")
logger.info("Migrating sessions...")
if inspector.has_table("sessions"):
num_rows = 0
for row in source_session.execute(
text("SELECT * FROM sessions")
).mappings():
num_rows += 1
dest_session.merge(
v1.StorageSession(
app_name=row["app_name"],
user_id=row["user_id"],
id=row["id"],
state=_get_state_dict(row.get("state")),
create_time=_to_datetime_obj(row["create_time"]),
update_time=_to_datetime_obj(row["update_time"]),
)
)
logger.info(f"Migrated {num_rows} sessions.")
else:
logger.info("No 'sessions' table found in source db.")
logger.info("Migrating events...")
num_rows = 0
if inspector.has_table("events"):
for row in source_session.execute(
text("SELECT * FROM events")
).mappings():
try:
event_obj = _row_to_event(dict(row))
new_event = v1.StorageEvent(
id=event_obj.id,
app_name=row["app_name"],
user_id=row["user_id"],
session_id=row["session_id"],
invocation_id=event_obj.invocation_id,
timestamp=datetime.fromtimestamp(
event_obj.timestamp, timezone.utc
).replace(tzinfo=None),
event_data=event_obj.model_dump(mode="json", exclude_none=True),
)
dest_session.merge(new_event)
num_rows += 1
except Exception as e:
logger.warning(
f"Failed to migrate event row {row.get('id', 'N/A')}: {e}"
)
logger.info(f"Migrated {num_rows} events.")
else:
logger.info("No 'events' table found in source database.")
dest_session.commit()
logger.info("Migration completed successfully.")
except Exception as e:
logger.error(f"An error occurred during migration: {e}", exc_info=True)
dest_session.rollback()
raise RuntimeError(f"An error occurred during migration: {e}") from e
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"Migrate ADK sessions from SQLAlchemy Pickle format to JSON format."
)
)
parser.add_argument(
"--source_db_url", required=True, help="SQLAlchemy URL of source database"
)
parser.add_argument(
"--dest_db_url",
required=True,
help="SQLAlchemy URL of destination database",
)
args = parser.parse_args()
try:
migrate(args.source_db_url, args.dest_db_url)
except Exception as e:
logger.error(f"Migration failed: {e}")
sys.exit(1)
@@ -0,0 +1,126 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration runner to upgrade schemas to the latest version."""
from __future__ import annotations
import logging
import os
import tempfile
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
logger = logging.getLogger("google_adk." + __name__)
# Migration map where key is start_version and value is
# (end_version, migration_function).
# Each key is a schema version, and its value is a tuple containing:
# (the schema version AFTER this migration step, the migration function to run).
# The migration function should accept (source_db_url, dest_db_url) as
# arguments.
MIGRATIONS = {
_schema_check_utils.SCHEMA_VERSION_0_PICKLE: (
_schema_check_utils.SCHEMA_VERSION_1_JSON,
migrate_from_sqlalchemy_pickle.migrate,
),
}
# The most recent schema version. The migration process stops once this version
# is reached.
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION
def upgrade(source_db_url: str, dest_db_url: str):
"""Migrates a database from its current version to the latest version.
If the source database schema is older than the latest version, this
function applies migration scripts sequentially until the schema reaches the
LATEST_VERSION.
If multiple migration steps are required, intermediate results are stored in
temporary SQLite database files. This means a multi-step migration
between other database types (e.g. PostgreSQL to PostgreSQL) will use
SQLite for intermediate steps.
In-place migration (source_db_url == dest_db_url) is not supported,
as migrations always read from a source and write to a destination.
Args:
source_db_url: The SQLAlchemy URL of the database to migrate from.
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
different from source_db_url.
Raises:
RuntimeError: If source_db_url and dest_db_url are the same, or if no
migration path is found.
"""
if source_db_url == dest_db_url:
raise RuntimeError(
"In-place migration is not supported. "
"Please provide a different URL for dest_db_url."
)
current_version = _schema_check_utils.get_db_schema_version(source_db_url)
if current_version == LATEST_VERSION:
logger.info(
f"Database {source_db_url} is already at latest version"
f" {LATEST_VERSION}. No migration needed."
)
return
# Build the list of migration steps required to reach LATEST_VERSION.
migrations_to_run = []
ver = current_version
while ver in MIGRATIONS and ver != LATEST_VERSION:
migrations_to_run.append(MIGRATIONS[ver])
ver = MIGRATIONS[ver][0]
if not migrations_to_run:
raise RuntimeError(
"Could not find migration path for schema version"
f" {current_version} to {LATEST_VERSION}."
)
temp_files = []
in_url = source_db_url
try:
for i, (end_version, migrate_func) in enumerate(migrations_to_run):
is_last_step = i == len(migrations_to_run) - 1
if is_last_step:
out_url = dest_db_url
else:
# For intermediate steps, create a temporary SQLite DB to store the
# result.
fd, temp_path = tempfile.mkstemp(suffix=".db")
os.close(fd)
out_url = f"sqlite:///{temp_path}"
temp_files.append(temp_path)
logger.debug("Created temp db %s for step %d", out_url, i + 1)
logger.info(
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
)
migrate_func(in_url, out_url)
logger.info("Finished migration step to schema %s.", end_version)
# The output of this step becomes the input for the next step.
in_url = out_url
finally:
# Ensure temporary files are cleaned up even if migration fails.
for path in temp_files:
try:
os.remove(path)
logger.debug("Removed temp db %s", path)
except OSError as e:
logger.warning("Failed to remove temp db file %s: %s", path, e)
@@ -0,0 +1,106 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for migration scripts."""
from __future__ import annotations
from datetime import datetime
from datetime import timezone
from google.adk.events.event_actions import EventActions
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
from google.adk.sessions.schemas import v0
from google.adk.sessions.schemas import v1
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
def test_migrate_from_sqlalchemy_pickle(tmp_path):
"""Tests for migrate_from_sqlalchemy_pickle."""
source_db_path = tmp_path / "source_pickle.db"
dest_db_path = tmp_path / "dest_json.db"
source_db_url = f"sqlite:///{source_db_path}"
dest_db_url = f"sqlite:///{dest_db_path}"
# Set up source DB with old pickle schema
source_engine = create_engine(source_db_url)
v0.Base.metadata.create_all(source_engine)
SourceSession = sessionmaker(bind=source_engine)
source_session = SourceSession()
# Populate source data
now = datetime.now(timezone.utc)
app_state = v0.StorageAppState(
app_name="app1", state={"akey": 1}, update_time=now
)
user_state = v0.StorageUserState(
app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now
)
session = v0.StorageSession(
app_name="app1",
user_id="user1",
id="session1",
state={"skey": 3},
create_time=now,
update_time=now,
)
event = v0.StorageEvent(
id="event1",
app_name="app1",
user_id="user1",
session_id="session1",
invocation_id="invoke1",
author="user",
actions=EventActions(state_delta={"skey": 4}),
timestamp=now,
)
source_session.add_all([app_state, user_state, session, event])
source_session.commit()
source_session.close()
mfsp.migrate(source_db_url, dest_db_url)
# Verify destination DB
dest_engine = create_engine(dest_db_url)
DestSession = sessionmaker(bind=dest_engine)
dest_session = DestSession()
metadata = dest_session.query(v1.StorageMetadata).first()
assert metadata is not None
assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON
app_state_res = dest_session.query(v1.StorageAppState).first()
assert app_state_res is not None
assert app_state_res.app_name == "app1"
assert app_state_res.state == {"akey": 1}
user_state_res = dest_session.query(v1.StorageUserState).first()
assert user_state_res is not None
assert user_state_res.user_id == "user1"
assert user_state_res.state == {"ukey": 2}
session_res = dest_session.query(v1.StorageSession).first()
assert session_res is not None
assert session_res.id == "session1"
assert session_res.state == {"skey": 3}
event_res = dest_session.query(v1.StorageEvent).first()
assert event_res is not None
assert event_res.id == "event1"
assert "state_delta" in event_res.event_data["actions"]
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}
dest_session.close()