feat!: Migrate DatabaseSessionService to use JSON serialization schema

Also provide a command-line tool, `adk migrate session`, for DB migration.

Addresses https://github.com/google/adk-python/discussions/3605
Addresses https://github.com/google/adk-python/issues/3681

To verify:
```
# Start a Postgres DB
docker run --name my-postgres -d \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v pgvolume:/var/lib/postgresql/data \
  -p 5532:5432 \
  postgres

# Run an older version of ADK against this DB and generate some session data
adk web --session_service_uri=postgresql://agent:agent@localhost:5532/agent

# Check out the latest branch and restart ADK web
# You should see an error log asking you to migrate the DB

# Start a new DB
docker run --name migration-test-db -d --rm \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v migration_test_vol:/var/lib/postgresql/data \
  -p 5533:5432 \
  postgres

# DB Migration
adk migrate session \
  --source_db_url="postgresql://agent:agent@localhost:5532/agent" \
  --dest_db_url="postgresql://agent:agent@localhost:5533/agent"

# Run ADK web with the new DB
adk web --session_service_uri=postgresql+asyncpg://agent:agent@localhost:5533/agent

# You should see that the data from the old DB has been migrated
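
# Optional sanity check (assumes the psql client is installed): the new
# adk_internal_metadata table should be stamped with schema_version = 1.0
psql postgresql://agent:agent@localhost:5533/agent \
  -c "SELECT key, value FROM adk_internal_metadata;"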
```
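
For programmatic checks, the schema probe added in this change can be called directly. A minimal sketch, assuming the module layout introduced by this commit:

```python
from google.adk.sessions.migration import _schema_check

# get_db_schema_version() returns "0.1" for the legacy pickle schema and
# "1.0" for the JSON schema (it also falls back to "1.0" for a new/empty DB).
version = _schema_check.get_db_schema_version(
    "postgresql://agent:agent@localhost:5533/agent"
)
print(version)  # expect "1.0" after the migration above
```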

Co-authored-by: Shangjie Chen <deanchen@google.com>
PiperOrigin-RevId: 837341139
@@ -36,6 +36,7 @@ from . import cli_create
from . import cli_deploy
from .. import version
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..sessions.migration import migration_runner
from .cli import run_cli
from .fast_api import get_fast_api_app
from .utils import envs
@@ -1485,6 +1486,41 @@ def cli_deploy_cloud_run(
click.secho(f"Deploy failed: {e}", fg="red", err=True)
@main.group()
def migrate():
"""Migrate ADK database schemas."""
pass
@migrate.command("session", cls=HelpfulCommand)
@click.option(
"--source_db_url",
required=True,
help="SQLAlchemy URL of source database.",
)
@click.option(
"--dest_db_url",
required=True,
help="SQLAlchemy URL of destination database.",
)
@click.option(
"--log_level",
type=LOG_LEVELS,
default="INFO",
help="Optional. Set the logging level",
)
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
migration_runner.upgrade(source_db_url, dest_db_url)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
@deploy.command("agent_engine")
@click.option(
"--api_key",
@@ -19,18 +19,16 @@ from datetime import datetime
from datetime import timezone
import json
import logging
import pickle
from typing import Any
from typing import Optional
import uuid
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
@@ -41,14 +39,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
from sqlalchemy.types import String
from sqlalchemy.types import TypeDecorator
from typing_extensions import override
@@ -57,10 +52,10 @@ from tzlocal import get_localzone
from . import _session_util
from ..errors.already_exists_error import AlreadyExistsError
from ..events.event import Event
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .migration import _schema_check
from .session import Session
from .state import State
@@ -111,41 +106,22 @@ class PreciseTimestamp(TypeDecorator):
return self.impl
class DynamicPickleType(TypeDecorator):
"""Represents a type that can be pickled."""
impl = PickleType
def load_dialect_impl(self, dialect):
if dialect.name == "mysql":
return dialect.type_descriptor(mysql.LONGBLOB)
if dialect.name == "spanner+spanner":
from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
return dialect.type_descriptor(SpannerPickleType)
return self.impl
def process_bind_param(self, value, dialect):
"""Ensures the pickled value is a bytes object before passing it to the database dialect."""
if value is not None:
if dialect.name in ("spanner+spanner", "mysql"):
return pickle.dumps(value)
return value
def process_result_value(self, value, dialect):
"""Ensures the raw bytes from the database are unpickled back into a Python object."""
if value is not None:
if dialect.name in ("spanner+spanner", "mysql"):
return pickle.loads(value)
return value
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageMetadata(Base):
"""Represents internal metadata stored in the database."""
__tablename__ = "adk_internal_metadata"
key: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
class StorageSession(Base):
"""Represents a session stored in the database."""
@@ -237,46 +213,10 @@ class StorageEvent(Base):
)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[PreciseTimestamp] = mapped_column(
PreciseTimestamp, default=func.now()
)
# === Fields from llm_response.py ===
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
input_transcription: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
output_transcription: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
@@ -291,102 +231,27 @@ class StorageEvent(Base):
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
@classmethod
def from_event(cls, session: Session, event: Event) -> StorageEvent:
storage_event = StorageEvent(
"""Creates a StorageEvent from an Event."""
return StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
event_data=event.model_dump(exclude_none=True, mode="json"),
)
if event.content:
storage_event.content = event.content.model_dump(
exclude_none=True, mode="json"
)
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
exclude_none=True, mode="json"
)
if event.custom_metadata:
storage_event.custom_metadata = event.custom_metadata
if event.usage_metadata:
storage_event.usage_metadata = event.usage_metadata.model_dump(
exclude_none=True, mode="json"
)
if event.citation_metadata:
storage_event.citation_metadata = event.citation_metadata.model_dump(
exclude_none=True, mode="json"
)
if event.input_transcription:
storage_event.input_transcription = event.input_transcription.model_dump(
exclude_none=True, mode="json"
)
if event.output_transcription:
storage_event.output_transcription = (
event.output_transcription.model_dump(exclude_none=True, mode="json")
)
return storage_event
def to_event(self) -> Event:
return Event(
id=self.id,
invocation_id=self.invocation_id,
author=self.author,
branch=self.branch,
# This is needed as previous ADK version pickled actions might not have
# value defined in the current version of the EventActions model.
actions=EventActions().model_copy(update=self.actions.model_dump()),
timestamp=self.timestamp.timestamp(),
long_running_tool_ids=self.long_running_tool_ids,
partial=self.partial,
turn_complete=self.turn_complete,
error_code=self.error_code,
error_message=self.error_message,
interrupted=self.interrupted,
custom_metadata=self.custom_metadata,
content=_session_util.decode_model(self.content, types.Content),
grounding_metadata=_session_util.decode_model(
self.grounding_metadata, types.GroundingMetadata
),
usage_metadata=_session_util.decode_model(
self.usage_metadata, types.GenerateContentResponseUsageMetadata
),
citation_metadata=_session_util.decode_model(
self.citation_metadata, types.CitationMetadata
),
input_transcription=_session_util.decode_model(
self.input_transcription, types.Transcription
),
output_transcription=_session_util.decode_model(
self.output_transcription, types.Transcription
),
)
"""Converts the StorageEvent to an Event."""
return Event.model_validate({
**self.event_data,
"id": self.id,
"invocation_id": self.invocation_id,
"timestamp": self.timestamp.timestamp(),
})
class StorageAppState(Base):
@@ -463,7 +328,6 @@ class DatabaseSessionService(BaseSessionService):
logger.info("Local timezone: %s", local_timezone)
self.db_engine: AsyncEngine = db_engine
self.metadata: MetaData = MetaData()
# DB session factory method
self.database_session_factory: async_sessionmaker[
@@ -483,10 +347,46 @@ class DatabaseSessionService(BaseSessionService):
async with self._table_creation_lock:
# Double-check after acquiring the lock
if not self._tables_created:
# Check schema version BEFORE creating tables.
# This prevents creating metadata table on a v0.1 DB.
async with self.database_session_factory() as sql_session:
version, is_v01 = await sql_session.run_sync(
_schema_check.get_version_and_v01_status_sync
)
if is_v01:
raise RuntimeError(
"Database schema appears to be v0.1, but"
f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please"
" migrate the database using 'adk migrate session'."
)
elif version and version < _schema_check.CURRENT_SCHEMA_VERSION:
raise RuntimeError(
f"Database schema version is {version}, but current version is"
f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate"
" the database to the latest version using 'adk migrate"
" session'."
)
async with self.db_engine.begin() as conn:
# Uncomment to recreate DB every time
# await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
# If we are here, DB is either new or >= current version.
# If new or without metadata row, stamp it as current version.
async with self.database_session_factory() as sql_session:
metadata = await sql_session.get(
StorageMetadata, _schema_check.SCHEMA_VERSION_KEY
)
if not metadata:
sql_session.add(
StorageMetadata(
key=_schema_check.SCHEMA_VERSION_KEY,
value=_schema_check.CURRENT_SCHEMA_VERSION,
)
)
await sql_session.commit()
self._tables_created = True
@override
@@ -723,7 +623,9 @@ class DatabaseSessionService(BaseSessionService):
storage_session.state = storage_session.state | session_state_delta
if storage_session._dialect_name == "sqlite":
update_time = datetime.utcfromtimestamp(event.timestamp)
update_time = datetime.fromtimestamp(
event.timestamp, timezone.utc
).replace(tzinfo=None)
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
@@ -0,0 +1,114 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Database schema version check utility."""
from __future__ import annotations
import logging
import sqlalchemy
from sqlalchemy import create_engine as create_sync_engine
from sqlalchemy import inspect
from sqlalchemy import text
logger = logging.getLogger("google_adk." + __name__)
SCHEMA_VERSION_KEY = "schema_version"
SCHEMA_VERSION_0_1_PICKLE = "0.1"
SCHEMA_VERSION_1_0_JSON = "1.0"
CURRENT_SCHEMA_VERSION = "1.0"
def _to_sync_url(db_url: str) -> str:
"""Removes +driver from SQLAlchemy URL."""
if "://" in db_url:
scheme, _, rest = db_url.partition("://")
if "+" in scheme:
dialect = scheme.split("+", 1)[0]
return f"{dialect}://{rest}"
return db_url
def get_version_and_v01_status_sync(
sess: sqlalchemy.orm.Session,
) -> tuple[str | None, bool]:
"""Returns (version, is_v01) inspecting the database."""
inspector = sqlalchemy.inspect(sess.get_bind())
if inspector.has_table("adk_internal_metadata"):
try:
result = sess.execute(
text("SELECT value FROM adk_internal_metadata WHERE key = :key"),
{"key": SCHEMA_VERSION_KEY},
).fetchone()
# If table exists, with or without key, it's 1.0 or newer.
return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False
except Exception as e:
logger.warning(
"Could not read from adk_internal_metadata: %s. Assuming v1.0.",
e,
)
return SCHEMA_VERSION_1_0_JSON, False
if inspector.has_table("events"):
try:
cols = {c["name"] for c in inspector.get_columns("events")}
if "actions" in cols and "event_data" not in cols:
return None, True # 0.1 schema
except Exception as e:
logger.warning("Could not inspect 'events' table columns: %s", e)
return None, False # New DB
def get_db_schema_version(db_url: str) -> str | None:
"""Reads schema version from DB.
Checks metadata table first, falls back to table structure for 0.1 vs 1.0.
"""
engine = None
try:
engine = create_sync_engine(_to_sync_url(db_url))
inspector = inspect(engine)
if inspector.has_table("adk_internal_metadata"):
with engine.connect() as connection:
result = connection.execute(
text("SELECT value FROM adk_internal_metadata WHERE key = :key"),
parameters={"key": SCHEMA_VERSION_KEY},
).fetchone()
# If table exists, with or without key, it's 1.0 or newer.
return result[0] if result else SCHEMA_VERSION_1_0_JSON
# Metadata table doesn't exist, check for 0.1 schema.
# 0.1 schema has an 'events' table with an 'actions' column.
if inspector.has_table("events"):
try:
cols = {c["name"] for c in inspector.get_columns("events")}
if "actions" in cols and "event_data" not in cols:
return SCHEMA_VERSION_0_1_PICKLE
except Exception as e:
logger.warning("Could not inspect 'events' table columns: %s", e)
# If no metadata table and not identifiable as 0.1,
# assume it is a new/empty DB requiring schema 1.0.
return SCHEMA_VERSION_1_0_JSON
except Exception as e:
logger.info(
"Could not determine schema version by inspecting database: %s."
" Assuming v1.0.",
e,
)
return SCHEMA_VERSION_1_0_JSON
finally:
if engine:
engine.dispose()
@@ -0,0 +1,492 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema."""
from __future__ import annotations
import argparse
from datetime import datetime
from datetime import timezone
import json
import logging
import pickle
import sys
from typing import Any
from typing import Optional
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions import _session_util
from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check
from google.genai import types
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import create_engine
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import PickleType
from sqlalchemy.types import String
from sqlalchemy.types import TypeDecorator
logger = logging.getLogger("google_adk." + __name__)
# --- Old Schema Definitions ---
class DynamicPickleType(TypeDecorator):
"""Represents a type that can be pickled."""
impl = PickleType
def load_dialect_impl(self, dialect):
if dialect.name == "mysql":
return dialect.type_descriptor(mysql.LONGBLOB)
if dialect.name == "spanner+spanner":
from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
return dialect.type_descriptor(SpannerPickleType)
return self.impl
def process_bind_param(self, value, dialect):
"""Ensures the pickled value is a bytes object before passing it to the database dialect."""
if value is not None:
if dialect.name in ("spanner+spanner", "mysql"):
return pickle.dumps(value)
return value
def process_result_value(self, value, dialect):
"""Ensures the raw bytes from the database are unpickled back into a Python object."""
if value is not None:
if dialect.name in ("spanner+spanner", "mysql"):
return pickle.loads(value)
return value
class OldBase(DeclarativeBase):
pass
class OldStorageSession(OldBase):
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(dss.DynamicJSON), default={}
)
create_time: Mapped[datetime] = mapped_column(
dss.PreciseTimestamp, default=func.now()
)
update_time: Mapped[datetime] = mapped_column(
dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
)
class OldStorageEvent(OldBase):
"""Old storage event with pickle."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_VARCHAR_LENGTH)
)
author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH))
actions: Mapped[Any] = mapped_column(DynamicPickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
branch: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[datetime] = mapped_column(
dss.PreciseTimestamp, default=func.now()
)
content: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
input_transcription: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
output_transcription: Mapped[dict[str, Any]] = mapped_column(
dss.DynamicJSON, nullable=True
)
__table_args__ = (
ForeignKeyConstraint(
["app_name", "user_id", "session_id"],
["sessions.app_name", "sessions.user_id", "sessions.id"],
ondelete="CASCADE",
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
def _to_datetime_obj(val: Any) -> datetime | Any:
"""Converts string to datetime if needed."""
if isinstance(val, str):
try:
return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f")
except ValueError:
try:
return datetime.strptime(val, "%Y-%m-%d %H:%M:%S")
except ValueError:
pass # return as is if not matching format
return val
def _row_to_event(row: dict) -> Event:
"""Converts event row (dict) to event object, handling missing columns and deserializing."""
actions_val = row.get("actions")
actions = None
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
logger.warning(
f"Failed to unpickle actions for event {row.get('id')}: {e}"
)
actions = None
if actions and hasattr(actions, "model_dump"):
actions = EventActions().model_copy(update=actions.model_dump())
elif isinstance(actions, dict):
actions = EventActions(**actions)
else:
actions = EventActions()
def _safe_json_load(val):
data = None
if isinstance(val, str):
try:
data = json.loads(val)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
return None
elif isinstance(val, dict):
data = val # for postgres JSONB
return data
content_dict = _safe_json_load(row.get("content"))
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
custom_metadata_dict = _safe_json_load(row.get("custom_metadata"))
usage_metadata_dict = _safe_json_load(row.get("usage_metadata"))
citation_metadata_dict = _safe_json_load(row.get("citation_metadata"))
input_transcription_dict = _safe_json_load(row.get("input_transcription"))
output_transcription_dict = _safe_json_load(row.get("output_transcription"))
long_running_tool_ids_json = row.get("long_running_tool_ids_json")
long_running_tool_ids = set()
if long_running_tool_ids_json:
try:
long_running_tool_ids = set(json.loads(long_running_tool_ids_json))
except json.JSONDecodeError:
logger.warning(
"Failed to decode long_running_tool_ids_json for event"
f" {row.get('id')}"
)
long_running_tool_ids = set()
event_id = row.get("id")
if not event_id:
raise ValueError("Event must have an id.")
timestamp = _to_datetime_obj(row.get("timestamp"))
if not timestamp:
raise ValueError(f"Event {event_id} must have a timestamp.")
return Event(
id=event_id,
invocation_id=row.get("invocation_id", ""),
author=row.get("author", "agent"),
branch=row.get("branch"),
actions=actions,
timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(),
long_running_tool_ids=long_running_tool_ids,
partial=row.get("partial"),
turn_complete=row.get("turn_complete"),
error_code=row.get("error_code"),
error_message=row.get("error_message"),
interrupted=row.get("interrupted"),
custom_metadata=custom_metadata_dict,
content=_session_util.decode_model(content_dict, types.Content),
grounding_metadata=_session_util.decode_model(
grounding_metadata_dict, types.GroundingMetadata
),
usage_metadata=_session_util.decode_model(
usage_metadata_dict, types.GenerateContentResponseUsageMetadata
),
citation_metadata=_session_util.decode_model(
citation_metadata_dict, types.CitationMetadata
),
input_transcription=_session_util.decode_model(
input_transcription_dict, types.Transcription
),
output_transcription=_session_util.decode_model(
output_transcription_dict, types.Transcription
),
)
def _get_state_dict(state_val: Any) -> dict:
"""Safely load dict from JSON string or return dict if already dict."""
if isinstance(state_val, dict):
return state_val
if isinstance(state_val, str):
try:
return json.loads(state_val)
except json.JSONDecodeError:
logger.warning(
"Failed to parse state JSON string, defaulting to empty dict."
)
return {}
return {}
class OldStorageAppState(OldBase):
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(dss.DynamicJSON), default={}
)
update_time: Mapped[datetime] = mapped_column(
dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
)
class OldStorageUserState(OldBase):
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(dss.DynamicJSON), default={}
)
update_time: Mapped[datetime] = mapped_column(
dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
)
# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
"""Migrates data from old pickle schema to new JSON schema."""
logger.info(f"Connecting to source database: {source_db_url}")
try:
source_engine = create_engine(source_db_url)
SourceSession = sessionmaker(bind=source_engine)
except Exception as e:
logger.error(f"Failed to connect to source database: {e}")
raise RuntimeError(f"Failed to connect to source database: {e}") from e
logger.info(f"Connecting to destination database: {dest_db_url}")
try:
dest_engine = create_engine(dest_db_url)
dss.Base.metadata.create_all(dest_engine)
DestSession = sessionmaker(bind=dest_engine)
except Exception as e:
logger.error(f"Failed to connect to destination database: {e}")
raise RuntimeError(f"Failed to connect to destination database: {e}") from e
with SourceSession() as source_session, DestSession() as dest_session:
dest_session.merge(
dss.StorageMetadata(
key=_schema_check.SCHEMA_VERSION_KEY,
value=_schema_check.SCHEMA_VERSION_1_0_JSON,
)
)
dest_session.commit()
try:
inspector = sqlalchemy.inspect(source_engine)
logger.info("Migrating app_states...")
if inspector.has_table("app_states"):
rows = (
source_session.execute(text("SELECT * FROM app_states"))
.mappings()
.all()
)
for row in rows:
dest_session.merge(
dss.StorageAppState(
app_name=row["app_name"],
state=_get_state_dict(row.get("state")),
update_time=_to_datetime_obj(row["update_time"]),
)
)
dest_session.commit()
logger.info(f"Migrated {len(rows)} app_states.")
else:
logger.info("No 'app_states' table found in source db.")
logger.info("Migrating user_states...")
if inspector.has_table("user_states"):
rows = (
source_session.execute(text("SELECT * FROM user_states"))
.mappings()
.all()
)
for row in rows:
dest_session.merge(
dss.StorageUserState(
app_name=row["app_name"],
user_id=row["user_id"],
state=_get_state_dict(row.get("state")),
update_time=_to_datetime_obj(row["update_time"]),
)
)
dest_session.commit()
logger.info(f"Migrated {len(rows)} user_states.")
else:
logger.info("No 'user_states' table found in source db.")
logger.info("Migrating sessions...")
if inspector.has_table("sessions"):
rows = (
source_session.execute(text("SELECT * FROM sessions"))
.mappings()
.all()
)
for row in rows:
dest_session.merge(
dss.StorageSession(
app_name=row["app_name"],
user_id=row["user_id"],
id=row["id"],
state=_get_state_dict(row.get("state")),
create_time=_to_datetime_obj(row["create_time"]),
update_time=_to_datetime_obj(row["update_time"]),
)
)
dest_session.commit()
logger.info(f"Migrated {len(rows)} sessions.")
else:
logger.info("No 'sessions' table found in source db.")
logger.info("Migrating events...")
events = []
if inspector.has_table("events"):
rows = (
source_session.execute(text("SELECT * FROM events"))
.mappings()
.all()
)
for row in rows:
try:
event_obj = _row_to_event(dict(row))
new_event = dss.StorageEvent(
id=event_obj.id,
app_name=row["app_name"],
user_id=row["user_id"],
session_id=row["session_id"],
invocation_id=event_obj.invocation_id,
timestamp=datetime.fromtimestamp(
event_obj.timestamp, timezone.utc
).replace(tzinfo=None),
event_data=event_obj.model_dump(mode="json", exclude_none=True),
)
dest_session.merge(new_event)
events.append(new_event)
except Exception as e:
logger.warning(
f"Failed to migrate event row {row.get('id', 'N/A')}: {e}"
)
dest_session.commit()
logger.info(f"Migrated {len(events)} events.")
else:
logger.info("No 'events' table found in source database.")
logger.info("Migration completed successfully.")
except Exception as e:
logger.error(f"An error occurred during migration: {e}", exc_info=True)
dest_session.rollback()
raise RuntimeError(f"An error occurred during migration: {e}") from e
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"Migrate ADK sessions from SQLAlchemy Pickle format to JSON format."
)
)
parser.add_argument(
"--source_db_url", required=True, help="SQLAlchemy URL of source database"
)
parser.add_argument(
"--dest_db_url",
required=True,
help="SQLAlchemy URL of destination database",
)
args = parser.parse_args()
try:
migrate(args.source_db_url, args.dest_db_url)
except Exception as e:
logger.error(f"Migration failed: {e}")
sys.exit(1)
@@ -0,0 +1,128 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration runner to upgrade schemas to the latest version."""
from __future__ import annotations
import logging
import os
import tempfile
from google.adk.sessions.migration import _schema_check
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
logger = logging.getLogger("google_adk." + __name__)
# Migration map: each key is a starting schema version, and its value is a
# tuple of (the schema version AFTER this migration step, the migration
# function to run). Migration functions accept (source_db_url, dest_db_url).
MIGRATIONS = {
_schema_check.SCHEMA_VERSION_0_1_PICKLE: (
_schema_check.SCHEMA_VERSION_1_0_JSON,
migrate_from_sqlalchemy_pickle.migrate,
),
}
# The most recent schema version. The migration process stops once this version
# is reached.
LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION
def upgrade(source_db_url: str, dest_db_url: str):
"""Migrates a database from its current version to the latest version.
If the source database schema is older than the latest version, this
function applies migration scripts sequentially until the schema reaches the
LATEST_VERSION.
If multiple migration steps are required, intermediate results are stored in
temporary SQLite database files. This means a multi-step migration
between other database types (e.g. PostgreSQL to PostgreSQL) will use
SQLite for intermediate steps.
In-place migration (source_db_url == dest_db_url) is not supported,
as migrations always read from a source and write to a destination.
Args:
source_db_url: The SQLAlchemy URL of the database to migrate from.
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
different from source_db_url.
Raises:
RuntimeError: If source_db_url and dest_db_url are the same, or if no
migration path is found.
"""
current_version = _schema_check.get_db_schema_version(source_db_url)
if current_version == LATEST_VERSION:
logger.info(
f"Database {source_db_url} is already at latest version"
f" {LATEST_VERSION}. No migration needed."
)
return
if source_db_url == dest_db_url:
raise RuntimeError(
"In-place migration is not supported. "
"Please provide a different file for dest_db_url."
)
# Build the list of migration steps required to reach LATEST_VERSION.
migrations_to_run = []
ver = current_version
while ver in MIGRATIONS and ver != LATEST_VERSION:
migrations_to_run.append(MIGRATIONS[ver])
ver = MIGRATIONS[ver][0]
if not migrations_to_run:
raise RuntimeError(
"Could not find migration path for schema version"
f" {current_version} to {LATEST_VERSION}."
)
temp_files = []
in_url = source_db_url
try:
for i, (end_version, migrate_func) in enumerate(migrations_to_run):
is_last_step = i == len(migrations_to_run) - 1
if is_last_step:
out_url = dest_db_url
else:
# For intermediate steps, create a temporary SQLite DB to store the
# result.
fd, temp_path = tempfile.mkstemp(suffix=".db")
os.close(fd)
out_url = f"sqlite:///{temp_path}"
temp_files.append(temp_path)
logger.debug(f"Created temp db {out_url} for step {i+1}")
logger.info(
f"Migrating from {in_url} to {out_url} (schema {end_version})..."
)
migrate_func(in_url, out_url)
logger.info(f"Finished migration step to schema {end_version}.")
# The output of this step becomes the input for the next step.
in_url = out_url
finally:
# Ensure temporary files are cleaned up even if migration fails.
for path in temp_files:
try:
os.remove(path)
logger.debug(f"Removed temp db {path}")
except OSError as e:
logger.warning(f"Failed to remove temp db file {path}: {e}")
@@ -107,7 +107,7 @@ class SqliteSessionService(BaseSessionService):
f"Database {db_path} seems to use an old schema."
" Please run the migration command to"
" migrate it to the new schema. Example: `python -m"
" google.adk.sessions.migrate_from_sqlalchemy_sqlite"
" google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite"
f" --source_db_path {db_path} --dest_db_path"
f" {db_path}.new` then backup {db_path} and rename"
f" {db_path}.new to {db_path}."
@@ -0,0 +1,106 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for migration scripts."""
from __future__ import annotations
from datetime import datetime
from datetime import timezone
from google.adk.events.event_actions import EventActions
from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
def test_migrate_from_sqlalchemy_pickle(tmp_path):
"""Tests for migrate_from_sqlalchemy_pickle."""
source_db_path = tmp_path / "source_pickle.db"
dest_db_path = tmp_path / "dest_json.db"
source_db_url = f"sqlite:///{source_db_path}"
dest_db_url = f"sqlite:///{dest_db_path}"
# Setup source DB with old pickle schema
source_engine = create_engine(source_db_url)
mfsp.OldBase.metadata.create_all(source_engine)
SourceSession = sessionmaker(bind=source_engine)
source_session = SourceSession()
# Populate source data
now = datetime.now(timezone.utc)
app_state = mfsp.OldStorageAppState(
app_name="app1", state={"akey": 1}, update_time=now
)
user_state = mfsp.OldStorageUserState(
app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now
)
session = mfsp.OldStorageSession(
app_name="app1",
user_id="user1",
id="session1",
state={"skey": 3},
create_time=now,
update_time=now,
)
event = mfsp.OldStorageEvent(
id="event1",
app_name="app1",
user_id="user1",
session_id="session1",
invocation_id="invoke1",
author="user",
actions=EventActions(state_delta={"skey": 4}),
timestamp=now,
)
source_session.add_all([app_state, user_state, session, event])
source_session.commit()
source_session.close()
mfsp.migrate(source_db_url, dest_db_url)
# Verify destination DB
dest_engine = create_engine(dest_db_url)
DestSession = sessionmaker(bind=dest_engine)
dest_session = DestSession()
metadata = dest_session.query(dss.StorageMetadata).first()
assert metadata is not None
assert metadata.key == _schema_check.SCHEMA_VERSION_KEY
assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON
app_state_res = dest_session.query(dss.StorageAppState).first()
assert app_state_res is not None
assert app_state_res.app_name == "app1"
assert app_state_res.state == {"akey": 1}
user_state_res = dest_session.query(dss.StorageUserState).first()
assert user_state_res is not None
assert user_state_res.user_id == "user1"
assert user_state_res.state == {"ukey": 2}
session_res = dest_session.query(dss.StorageSession).first()
assert session_res is not None
assert session_res.id == "session1"
assert session_res.state == {"skey": 3}
event_res = dest_session.query(dss.StorageEvent).first()
assert event_res is not None
assert event_res.id == "event1"
assert "state_delta" in event_res.event_data["actions"]
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}
dest_session.close()
@@ -1,181 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pickle
from unittest import mock
from google.adk.sessions.database_session_service import DynamicPickleType
import pytest
from sqlalchemy import create_engine
from sqlalchemy.dialects import mysql
@pytest.fixture
def pickle_type():
"""Fixture for DynamicPickleType instance."""
return DynamicPickleType()
def test_load_dialect_impl_mysql(pickle_type):
"""Test that MySQL dialect uses LONGBLOB."""
# Mock the MySQL dialect
mock_dialect = mock.Mock()
mock_dialect.name = "mysql"
# Mock the return value of type_descriptor
mock_longblob_type = mock.Mock()
mock_dialect.type_descriptor.return_value = mock_longblob_type
impl = pickle_type.load_dialect_impl(mock_dialect)
# Verify type_descriptor was called once with mysql.LONGBLOB
mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB)
# Verify the return value is what we expect
assert impl == mock_longblob_type
def test_load_dialect_impl_spanner(pickle_type):
"""Test that Spanner dialect uses SpannerPickleType."""
# Mock the spanner dialect
mock_dialect = mock.Mock()
mock_dialect.name = "spanner+spanner"
with mock.patch(
"google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType"
) as mock_spanner_type:
pickle_type.load_dialect_impl(mock_dialect)
mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type)
def test_load_dialect_impl_default(pickle_type):
"""Test that other dialects use default PickleType."""
engine = create_engine("sqlite:///:memory:")
dialect = engine.dialect
impl = pickle_type.load_dialect_impl(dialect)
# Should return the default impl (PickleType)
assert impl == pickle_type.impl
@pytest.mark.parametrize(
"dialect_name",
[
pytest.param("mysql", id="mysql"),
pytest.param("spanner+spanner", id="spanner"),
],
)
def test_process_bind_param_pickle_dialects(pickle_type, dialect_name):
"""Test that MySQL and Spanner dialects pickle the value."""
mock_dialect = mock.Mock()
mock_dialect.name = dialect_name
test_data = {"key": "value", "nested": [1, 2, 3]}
result = pickle_type.process_bind_param(test_data, mock_dialect)
# Should be pickled bytes
assert isinstance(result, bytes)
# Should be able to unpickle back to original
assert pickle.loads(result) == test_data
def test_process_bind_param_default(pickle_type):
"""Test that other dialects return value as-is."""
mock_dialect = mock.Mock()
mock_dialect.name = "sqlite"
test_data = {"key": "value"}
result = pickle_type.process_bind_param(test_data, mock_dialect)
# Should return value unchanged (SQLAlchemy's PickleType handles it)
assert result == test_data
def test_process_bind_param_none(pickle_type):
"""Test that None values are handled correctly."""
mock_dialect = mock.Mock()
mock_dialect.name = "mysql"
result = pickle_type.process_bind_param(None, mock_dialect)
assert result is None
@pytest.mark.parametrize(
"dialect_name",
[
pytest.param("mysql", id="mysql"),
pytest.param("spanner+spanner", id="spanner"),
],
)
def test_process_result_value_pickle_dialects(pickle_type, dialect_name):
"""Test that MySQL and Spanner dialects unpickle the value."""
mock_dialect = mock.Mock()
mock_dialect.name = dialect_name
test_data = {"key": "value", "nested": [1, 2, 3]}
pickled_data = pickle.dumps(test_data)
result = pickle_type.process_result_value(pickled_data, mock_dialect)
# Should be unpickled back to original
assert result == test_data
def test_process_result_value_default(pickle_type):
"""Test that other dialects return value as-is."""
mock_dialect = mock.Mock()
mock_dialect.name = "sqlite"
test_data = {"key": "value"}
result = pickle_type.process_result_value(test_data, mock_dialect)
# Should return value unchanged (SQLAlchemy's PickleType handles it)
assert result == test_data
def test_process_result_value_none(pickle_type):
"""Test that None values are handled correctly."""
mock_dialect = mock.Mock()
mock_dialect.name = "mysql"
result = pickle_type.process_result_value(None, mock_dialect)
assert result is None
@pytest.mark.parametrize(
"dialect_name",
[
pytest.param("mysql", id="mysql"),
pytest.param("spanner+spanner", id="spanner"),
],
)
def test_roundtrip_pickle_dialects(pickle_type, dialect_name):
"""Test full roundtrip for MySQL and Spanner: bind -> result."""
mock_dialect = mock.Mock()
mock_dialect.name = dialect_name
original_data = {
"string": "test",
"number": 42,
"list": [1, 2, 3],
"nested": {"a": 1, "b": 2},
}
# Simulate bind (Python -> DB)
bound_value = pickle_type.process_bind_param(original_data, mock_dialect)
assert isinstance(bound_value, bytes)
# Simulate result (DB -> Python)
result_value = pickle_type.process_result_value(bound_value, mock_dialect)
assert result_value == original_data