feat!: Migrate DatabaseSessionService to use JSON serialization schema
Also provide a command line tool `adk migrate session` for DB migration.

Addresses https://github.com/google/adk-python/discussions/3605
Addresses https://github.com/google/adk-python/issues/3681

To verify:

```
# Start one postgres DB
docker run --name my-postgres -d \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v pgvolume:/var/lib/postgresql/data \
  -p 5532:5432 postgres

# Connect to an old version of ADK and produce some query data
adk web --session_service_uri=postgresql://agent:agent@localhost:5532/agent

# Check out the latest branch and restart ADK web.
# You should see an error log asking you to migrate the DB.

# Start a new DB
docker run --name migration-test-db \
  -d \
  --rm \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v migration_test_vol:/var/lib/postgresql/data \
  -p 5533:5432 postgres

# DB migration
adk migrate session \
  --source_db_url="postgresql://agent:agent@localhost:5532/agent" \
  --dest_db_url="postgresql://agent:agent@localhost:5533/agent"

# Run ADK web with the new DB
adk web --session_service_uri=postgresql+asyncpg://agent:agent@localhost:5533/agent

# You should see that the data from the old DB has been migrated.
```

Co-authored-by: Shangjie Chen <deanchen@google.com>
PiperOrigin-RevId: 837341139
Committed by: Copybara-Service
Parent: 786aaed335
Commit: 0094eea3ca
@@ -36,6 +36,7 @@ from . import cli_create
from . import cli_deploy
from .. import version
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..sessions.migration import migration_runner
from .cli import run_cli
from .fast_api import get_fast_api_app
from .utils import envs
@@ -1485,6 +1486,41 @@ def cli_deploy_cloud_run(
    click.secho(f"Deploy failed: {e}", fg="red", err=True)


@main.group()
def migrate():
  """Migrate ADK database schemas."""
  pass


@migrate.command("session", cls=HelpfulCommand)
@click.option(
    "--source_db_url",
    required=True,
    help="SQLAlchemy URL of source database.",
)
@click.option(
    "--dest_db_url",
    required=True,
    help="SQLAlchemy URL of destination database.",
)
@click.option(
    "--log_level",
    type=LOG_LEVELS,
    default="INFO",
    help="Optional. Set the logging level",
)
def cli_migrate_session(
    *, source_db_url: str, dest_db_url: str, log_level: str
):
  """Migrates a session database to the latest schema version."""
  logs.setup_adk_logger(getattr(logging, log_level.upper()))
  try:
    migration_runner.upgrade(source_db_url, dest_db_url)
    click.secho("Migration check and upgrade process finished.", fg="green")
  except Exception as e:
    click.secho(f"Migration failed: {e}", fg="red", err=True)


@deploy.command("agent_engine")
@click.option(
    "--api_key",
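As the commit message shows, the new command is invoked as `adk migrate session --source_db_url=<old> --dest_db_url=<new>`; on success it prints "Migration check and upgrade process finished." in green, and on failure it prints a red "Migration failed: ..." message to stderr.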
@@ -19,18 +19,16 @@ from datetime import datetime
from datetime import timezone
import json
import logging
import pickle
from typing import Any
from typing import Optional
import uuid

from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
@@ -41,14 +39,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
from sqlalchemy.types import String
from sqlalchemy.types import TypeDecorator
from typing_extensions import override
@@ -57,10 +52,10 @@ from tzlocal import get_localzone
from . import _session_util
from ..errors.already_exists_error import AlreadyExistsError
from ..events.event import Event
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .migration import _schema_check
from .session import Session
from .state import State

@@ -111,41 +106,22 @@ class PreciseTimestamp(TypeDecorator):
    return self.impl


class DynamicPickleType(TypeDecorator):
  """Represents a type that can be pickled."""

  impl = PickleType

  def load_dialect_impl(self, dialect):
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGBLOB)
    if dialect.name == "spanner+spanner":
      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType

      return dialect.type_descriptor(SpannerPickleType)
    return self.impl

  def process_bind_param(self, value, dialect):
    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
    if value is not None:
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.dumps(value)
    return value

  def process_result_value(self, value, dialect):
    """Ensures the raw bytes from the database are unpickled back into a Python object."""
    if value is not None:
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.loads(value)
    return value


class Base(DeclarativeBase):
  """Base class for database tables."""

  pass


class StorageMetadata(Base):
  """Represents internal metadata stored in the database."""

  __tablename__ = "adk_internal_metadata"
  key: Mapped[str] = mapped_column(
      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))


class StorageSession(Base):
  """Represents a session stored in the database."""

@@ -237,46 +213,10 @@ class StorageEvent(Base):
  )

  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
      Text, nullable=True
  )
  branch: Mapped[str] = mapped_column(
      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
  )
  timestamp: Mapped[PreciseTimestamp] = mapped_column(
      PreciseTimestamp, default=func.now()
  )

  # === Fields from llm_response.py ===
  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )
  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )
  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )
  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )

  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
  error_code: Mapped[str] = mapped_column(
      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
  )
  error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
  input_transcription: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )
  output_transcription: Mapped[dict[str, Any]] = mapped_column(
      DynamicJSON, nullable=True
  )
  event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)

  storage_session: Mapped[StorageSession] = relationship(
      "StorageSession",
@@ -291,102 +231,27 @@ class StorageEvent(Base):
      ),
  )

  @property
  def long_running_tool_ids(self) -> set[str]:
    return (
        set(json.loads(self.long_running_tool_ids_json))
        if self.long_running_tool_ids_json
        else set()
    )

  @long_running_tool_ids.setter
  def long_running_tool_ids(self, value: set[str]):
    if value is None:
      self.long_running_tool_ids_json = None
    else:
      self.long_running_tool_ids_json = json.dumps(list(value))

  @classmethod
  def from_event(cls, session: Session, event: Event) -> StorageEvent:
    storage_event = StorageEvent(
    """Creates a StorageEvent from an Event."""
    return StorageEvent(
        id=event.id,
        invocation_id=event.invocation_id,
        author=event.author,
        branch=event.branch,
        actions=event.actions,
        session_id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        timestamp=datetime.fromtimestamp(event.timestamp),
        long_running_tool_ids=event.long_running_tool_ids,
        partial=event.partial,
        turn_complete=event.turn_complete,
        error_code=event.error_code,
        error_message=event.error_message,
        interrupted=event.interrupted,
        event_data=event.model_dump(exclude_none=True, mode="json"),
    )
    if event.content:
      storage_event.content = event.content.model_dump(
          exclude_none=True, mode="json"
      )
    if event.grounding_metadata:
      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
          exclude_none=True, mode="json"
      )
    if event.custom_metadata:
      storage_event.custom_metadata = event.custom_metadata
    if event.usage_metadata:
      storage_event.usage_metadata = event.usage_metadata.model_dump(
          exclude_none=True, mode="json"
      )
    if event.citation_metadata:
      storage_event.citation_metadata = event.citation_metadata.model_dump(
          exclude_none=True, mode="json"
      )
    if event.input_transcription:
      storage_event.input_transcription = event.input_transcription.model_dump(
          exclude_none=True, mode="json"
      )
    if event.output_transcription:
      storage_event.output_transcription = (
          event.output_transcription.model_dump(exclude_none=True, mode="json")
      )
    return storage_event

  def to_event(self) -> Event:
    return Event(
        id=self.id,
        invocation_id=self.invocation_id,
        author=self.author,
        branch=self.branch,
        # This is needed as previous ADK version pickled actions might not have
        # value defined in the current version of the EventActions model.
        actions=EventActions().model_copy(update=self.actions.model_dump()),
        timestamp=self.timestamp.timestamp(),
        long_running_tool_ids=self.long_running_tool_ids,
        partial=self.partial,
        turn_complete=self.turn_complete,
        error_code=self.error_code,
        error_message=self.error_message,
        interrupted=self.interrupted,
        custom_metadata=self.custom_metadata,
        content=_session_util.decode_model(self.content, types.Content),
        grounding_metadata=_session_util.decode_model(
            self.grounding_metadata, types.GroundingMetadata
        ),
        usage_metadata=_session_util.decode_model(
            self.usage_metadata, types.GenerateContentResponseUsageMetadata
        ),
        citation_metadata=_session_util.decode_model(
            self.citation_metadata, types.CitationMetadata
        ),
        input_transcription=_session_util.decode_model(
            self.input_transcription, types.Transcription
        ),
        output_transcription=_session_util.decode_model(
            self.output_transcription, types.Transcription
        ),
    )
    """Converts the StorageEvent to an Event."""
    return Event.model_validate({
        **self.event_data,
        "id": self.id,
        "invocation_id": self.invocation_id,
        "timestamp": self.timestamp.timestamp(),
    })


class StorageAppState(Base):
@@ -463,7 +328,6 @@ class DatabaseSessionService(BaseSessionService):
    logger.info("Local timezone: %s", local_timezone)

    self.db_engine: AsyncEngine = db_engine
    self.metadata: MetaData = MetaData()

    # DB session factory method
    self.database_session_factory: async_sessionmaker[
@@ -483,10 +347,46 @@ class DatabaseSessionService(BaseSessionService):
    async with self._table_creation_lock:
      # Double-check after acquiring the lock
      if not self._tables_created:
        # Check schema version BEFORE creating tables.
        # This prevents creating metadata table on a v0.1 DB.
        async with self.database_session_factory() as sql_session:
          version, is_v01 = await sql_session.run_sync(
              _schema_check.get_version_and_v01_status_sync
          )

          if is_v01:
            raise RuntimeError(
                "Database schema appears to be v0.1, but"
                f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please"
                " migrate the database using 'adk migrate session'."
            )
          elif version and version < _schema_check.CURRENT_SCHEMA_VERSION:
            raise RuntimeError(
                f"Database schema version is {version}, but current version is"
                f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate"
                " the database to the latest version using 'adk migrate"
                " session'."
            )

        async with self.db_engine.begin() as conn:
          # Uncomment to recreate DB every time
          # await conn.run_sync(Base.metadata.drop_all)
          await conn.run_sync(Base.metadata.create_all)

        # If we are here, DB is either new or >= current version.
        # If new or without metadata row, stamp it as current version.
        async with self.database_session_factory() as sql_session:
          metadata = await sql_session.get(
              StorageMetadata, _schema_check.SCHEMA_VERSION_KEY
          )
          if not metadata:
            sql_session.add(
                StorageMetadata(
                    key=_schema_check.SCHEMA_VERSION_KEY,
                    value=_schema_check.CURRENT_SCHEMA_VERSION,
                )
            )
            await sql_session.commit()
        self._tables_created = True

  @override
@@ -723,7 +623,9 @@ class DatabaseSessionService(BaseSessionService):
        storage_session.state = storage_session.state | session_state_delta

        if storage_session._dialect_name == "sqlite":
          update_time = datetime.utcfromtimestamp(event.timestamp)
          update_time = datetime.fromtimestamp(
              event.timestamp, timezone.utc
          ).replace(tzinfo=None)
        else:
          update_time = datetime.fromtimestamp(event.timestamp)
        storage_session.update_time = update_time

@@ -0,0 +1,114 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Database schema version check utility."""

from __future__ import annotations

import logging

import sqlalchemy
from sqlalchemy import create_engine as create_sync_engine
from sqlalchemy import inspect
from sqlalchemy import text

logger = logging.getLogger("google_adk." + __name__)

SCHEMA_VERSION_KEY = "schema_version"
SCHEMA_VERSION_0_1_PICKLE = "0.1"
SCHEMA_VERSION_1_0_JSON = "1.0"
CURRENT_SCHEMA_VERSION = "1.0"


def _to_sync_url(db_url: str) -> str:
  """Removes +driver from SQLAlchemy URL."""
  if "://" in db_url:
    scheme, _, rest = db_url.partition("://")
    if "+" in scheme:
      dialect = scheme.split("+", 1)[0]
      return f"{dialect}://{rest}"
  return db_url


def get_version_and_v01_status_sync(
    sess: sqlalchemy.orm.Session,
) -> tuple[str | None, bool]:
  """Returns (version, is_v01) inspecting the database."""
  inspector = sqlalchemy.inspect(sess.get_bind())
  if inspector.has_table("adk_internal_metadata"):
    try:
      result = sess.execute(
          text("SELECT value FROM adk_internal_metadata WHERE key = :key"),
          {"key": SCHEMA_VERSION_KEY},
      ).fetchone()
      # If table exists, with or without key, it's 1.0 or newer.
      return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False
    except Exception as e:
      logger.warning(
          "Could not read from adk_internal_metadata: %s. Assuming v1.0.",
          e,
      )
      return SCHEMA_VERSION_1_0_JSON, False

  if inspector.has_table("events"):
    try:
      cols = {c["name"] for c in inspector.get_columns("events")}
      if "actions" in cols and "event_data" not in cols:
        return None, True  # 0.1 schema
    except Exception as e:
      logger.warning("Could not inspect 'events' table columns: %s", e)
  return None, False  # New DB


def get_db_schema_version(db_url: str) -> str | None:
  """Reads schema version from DB.

  Checks metadata table first, falls back to table structure for 0.1 vs 1.0.
  """
  engine = None
  try:
    engine = create_sync_engine(_to_sync_url(db_url))
    inspector = inspect(engine)

    if inspector.has_table("adk_internal_metadata"):
      with engine.connect() as connection:
        result = connection.execute(
            text("SELECT value FROM adk_internal_metadata WHERE key = :key"),
            parameters={"key": SCHEMA_VERSION_KEY},
        ).fetchone()
        # If table exists, with or without key, it's 1.0 or newer.
        return result[0] if result else SCHEMA_VERSION_1_0_JSON

    # Metadata table doesn't exist, check for 0.1 schema.
    # 0.1 schema has an 'events' table with an 'actions' column.
    if inspector.has_table("events"):
      try:
        cols = {c["name"] for c in inspector.get_columns("events")}
        if "actions" in cols and "event_data" not in cols:
          return SCHEMA_VERSION_0_1_PICKLE
      except Exception as e:
        logger.warning("Could not inspect 'events' table columns: %s", e)

    # If no metadata table and not identifiable as 0.1,
    # assume it is a new/empty DB requiring schema 1.0.
    return SCHEMA_VERSION_1_0_JSON
  except Exception as e:
    logger.info(
        "Could not determine schema version by inspecting database: %s."
        " Assuming v1.0.",
        e,
    )
    return SCHEMA_VERSION_1_0_JSON
  finally:
    if engine:
      engine.dispose()
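For orientation, the version check above is what drives the "please migrate" errors and the runner's decision to migrate at all. A minimal sketch of how it can be used, assuming a placeholder `db_url`:

```python
from google.adk.sessions.migration import _schema_check

# Hypothetical URL for illustration; point this at a real session database.
db_url = "sqlite:///./sessions.db"

version = _schema_check.get_db_schema_version(db_url)
if version != _schema_check.CURRENT_SCHEMA_VERSION:
  # An older (e.g. 0.1 pickle-based) schema was detected; running
  # `adk migrate session` produces a database at the current schema.
  print(f"Schema {version} is behind {_schema_check.CURRENT_SCHEMA_VERSION}.")
```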
@@ -0,0 +1,492 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema."""

from __future__ import annotations

import argparse
from datetime import datetime
from datetime import timezone
import json
import logging
import pickle
import sys
from typing import Any
from typing import Optional

from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions import _session_util
from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check
from google.genai import types
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import create_engine
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import PickleType
from sqlalchemy.types import String
from sqlalchemy.types import TypeDecorator

logger = logging.getLogger("google_adk." + __name__)


# --- Old Schema Definitions ---
class DynamicPickleType(TypeDecorator):
  """Represents a type that can be pickled."""

  impl = PickleType

  def load_dialect_impl(self, dialect):
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGBLOB)
    if dialect.name == "spanner+spanner":
      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType

      return dialect.type_descriptor(SpannerPickleType)
    return self.impl

  def process_bind_param(self, value, dialect):
    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
    if value is not None:
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.dumps(value)
    return value

  def process_result_value(self, value, dialect):
    """Ensures the raw bytes from the database are unpickled back into a Python object."""
    if value is not None:
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.loads(value)
    return value


class OldBase(DeclarativeBase):
  pass


class OldStorageSession(OldBase):
  __tablename__ = "sessions"
  app_name: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  user_id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  state: Mapped[MutableDict[str, Any]] = mapped_column(
      MutableDict.as_mutable(dss.DynamicJSON), default={}
  )
  create_time: Mapped[datetime] = mapped_column(
      dss.PreciseTimestamp, default=func.now()
  )
  update_time: Mapped[datetime] = mapped_column(
      dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
  )


class OldStorageEvent(OldBase):
  """Old storage event with pickle."""

  __tablename__ = "events"
  id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  app_name: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  user_id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  session_id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  invocation_id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_VARCHAR_LENGTH)
  )
  author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH))
  actions: Mapped[Any] = mapped_column(DynamicPickleType)
  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
      Text, nullable=True
  )
  branch: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
  )
  timestamp: Mapped[datetime] = mapped_column(
      dss.PreciseTimestamp, default=func.now()
  )
  content: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
  error_code: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
  )
  error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
  input_transcription: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  output_transcription: Mapped[dict[str, Any]] = mapped_column(
      dss.DynamicJSON, nullable=True
  )
  __table_args__ = (
      ForeignKeyConstraint(
          ["app_name", "user_id", "session_id"],
          ["sessions.app_name", "sessions.user_id", "sessions.id"],
          ondelete="CASCADE",
      ),
  )

  @property
  def long_running_tool_ids(self) -> set[str]:
    return (
        set(json.loads(self.long_running_tool_ids_json))
        if self.long_running_tool_ids_json
        else set()
    )


def _to_datetime_obj(val: Any) -> datetime | Any:
  """Converts string to datetime if needed."""
  if isinstance(val, str):
    try:
      return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
      try:
        return datetime.strptime(val, "%Y-%m-%d %H:%M:%S")
      except ValueError:
        pass  # return as is if not matching format
  return val


def _row_to_event(row: dict) -> Event:
  """Converts event row (dict) to event object, handling missing columns and deserializing."""

  actions_val = row.get("actions")
  actions = None
  if actions_val is not None:
    try:
      if isinstance(actions_val, bytes):
        actions = pickle.loads(actions_val)
      else:  # for spanner - it might return object directly
        actions = actions_val
    except Exception as e:
      logger.warning(
          f"Failed to unpickle actions for event {row.get('id')}: {e}"
      )
      actions = None

  if actions and hasattr(actions, "model_dump"):
    actions = EventActions().model_copy(update=actions.model_dump())
  elif isinstance(actions, dict):
    actions = EventActions(**actions)
  else:
    actions = EventActions()

  def _safe_json_load(val):
    data = None
    if isinstance(val, str):
      try:
        data = json.loads(val)
      except json.JSONDecodeError:
        logger.warning(f"Failed to decode JSON for event {row.get('id')}")
        return None
    elif isinstance(val, dict):
      data = val  # for postgres JSONB
    return data

  content_dict = _safe_json_load(row.get("content"))
  grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
  custom_metadata_dict = _safe_json_load(row.get("custom_metadata"))
  usage_metadata_dict = _safe_json_load(row.get("usage_metadata"))
  citation_metadata_dict = _safe_json_load(row.get("citation_metadata"))
  input_transcription_dict = _safe_json_load(row.get("input_transcription"))
  output_transcription_dict = _safe_json_load(row.get("output_transcription"))

  long_running_tool_ids_json = row.get("long_running_tool_ids_json")
  long_running_tool_ids = set()
  if long_running_tool_ids_json:
    try:
      long_running_tool_ids = set(json.loads(long_running_tool_ids_json))
    except json.JSONDecodeError:
      logger.warning(
          "Failed to decode long_running_tool_ids_json for event"
          f" {row.get('id')}"
      )
      long_running_tool_ids = set()

  event_id = row.get("id")
  if not event_id:
    raise ValueError("Event must have an id.")
  timestamp = _to_datetime_obj(row.get("timestamp"))
  if not timestamp:
    raise ValueError(f"Event {event_id} must have a timestamp.")

  return Event(
      id=event_id,
      invocation_id=row.get("invocation_id", ""),
      author=row.get("author", "agent"),
      branch=row.get("branch"),
      actions=actions,
      timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(),
      long_running_tool_ids=long_running_tool_ids,
      partial=row.get("partial"),
      turn_complete=row.get("turn_complete"),
      error_code=row.get("error_code"),
      error_message=row.get("error_message"),
      interrupted=row.get("interrupted"),
      custom_metadata=custom_metadata_dict,
      content=_session_util.decode_model(content_dict, types.Content),
      grounding_metadata=_session_util.decode_model(
          grounding_metadata_dict, types.GroundingMetadata
      ),
      usage_metadata=_session_util.decode_model(
          usage_metadata_dict, types.GenerateContentResponseUsageMetadata
      ),
      citation_metadata=_session_util.decode_model(
          citation_metadata_dict, types.CitationMetadata
      ),
      input_transcription=_session_util.decode_model(
          input_transcription_dict, types.Transcription
      ),
      output_transcription=_session_util.decode_model(
          output_transcription_dict, types.Transcription
      ),
  )


def _get_state_dict(state_val: Any) -> dict:
  """Safely load dict from JSON string or return dict if already dict."""
  if isinstance(state_val, dict):
    return state_val
  if isinstance(state_val, str):
    try:
      return json.loads(state_val)
    except json.JSONDecodeError:
      logger.warning(
          "Failed to parse state JSON string, defaulting to empty dict."
      )
      return {}
  return {}


class OldStorageAppState(OldBase):
  __tablename__ = "app_states"
  app_name: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  state: Mapped[MutableDict[str, Any]] = mapped_column(
      MutableDict.as_mutable(dss.DynamicJSON), default={}
  )
  update_time: Mapped[datetime] = mapped_column(
      dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
  )


class OldStorageUserState(OldBase):
  __tablename__ = "user_states"
  app_name: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  user_id: Mapped[str] = mapped_column(
      String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True
  )
  state: Mapped[MutableDict[str, Any]] = mapped_column(
      MutableDict.as_mutable(dss.DynamicJSON), default={}
  )
  update_time: Mapped[datetime] = mapped_column(
      dss.PreciseTimestamp, default=func.now(), onupdate=func.now()
  )


# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
  """Migrates data from old pickle schema to new JSON schema."""
  logger.info(f"Connecting to source database: {source_db_url}")
  try:
    source_engine = create_engine(source_db_url)
    SourceSession = sessionmaker(bind=source_engine)
  except Exception as e:
    logger.error(f"Failed to connect to source database: {e}")
    raise RuntimeError(f"Failed to connect to source database: {e}") from e

  logger.info(f"Connecting to destination database: {dest_db_url}")
  try:
    dest_engine = create_engine(dest_db_url)
    dss.Base.metadata.create_all(dest_engine)
    DestSession = sessionmaker(bind=dest_engine)
  except Exception as e:
    logger.error(f"Failed to connect to destination database: {e}")
    raise RuntimeError(f"Failed to connect to destination database: {e}") from e

  with SourceSession() as source_session, DestSession() as dest_session:
    dest_session.merge(
        dss.StorageMetadata(
            key=_schema_check.SCHEMA_VERSION_KEY,
            value=_schema_check.SCHEMA_VERSION_1_0_JSON,
        )
    )
    dest_session.commit()
    try:
      inspector = sqlalchemy.inspect(source_engine)

      logger.info("Migrating app_states...")
      if inspector.has_table("app_states"):
        rows = (
            source_session.execute(text("SELECT * FROM app_states"))
            .mappings()
            .all()
        )
        for row in rows:
          dest_session.merge(
              dss.StorageAppState(
                  app_name=row["app_name"],
                  state=_get_state_dict(row.get("state")),
                  update_time=_to_datetime_obj(row["update_time"]),
              )
          )
        dest_session.commit()
        logger.info(f"Migrated {len(rows)} app_states.")
      else:
        logger.info("No 'app_states' table found in source db.")

      logger.info("Migrating user_states...")
      if inspector.has_table("user_states"):
        rows = (
            source_session.execute(text("SELECT * FROM user_states"))
            .mappings()
            .all()
        )
        for row in rows:
          dest_session.merge(
              dss.StorageUserState(
                  app_name=row["app_name"],
                  user_id=row["user_id"],
                  state=_get_state_dict(row.get("state")),
                  update_time=_to_datetime_obj(row["update_time"]),
              )
          )
        dest_session.commit()
        logger.info(f"Migrated {len(rows)} user_states.")
      else:
        logger.info("No 'user_states' table found in source db.")

      logger.info("Migrating sessions...")
      if inspector.has_table("sessions"):
        rows = (
            source_session.execute(text("SELECT * FROM sessions"))
            .mappings()
            .all()
        )
        for row in rows:
          dest_session.merge(
              dss.StorageSession(
                  app_name=row["app_name"],
                  user_id=row["user_id"],
                  id=row["id"],
                  state=_get_state_dict(row.get("state")),
                  create_time=_to_datetime_obj(row["create_time"]),
                  update_time=_to_datetime_obj(row["update_time"]),
              )
          )
        dest_session.commit()
        logger.info(f"Migrated {len(rows)} sessions.")
      else:
        logger.info("No 'sessions' table found in source db.")

      logger.info("Migrating events...")
      events = []
      if inspector.has_table("events"):
        rows = (
            source_session.execute(text("SELECT * FROM events"))
            .mappings()
            .all()
        )
        for row in rows:
          try:
            event_obj = _row_to_event(dict(row))
            new_event = dss.StorageEvent(
                id=event_obj.id,
                app_name=row["app_name"],
                user_id=row["user_id"],
                session_id=row["session_id"],
                invocation_id=event_obj.invocation_id,
                timestamp=datetime.fromtimestamp(
                    event_obj.timestamp, timezone.utc
                ).replace(tzinfo=None),
                event_data=event_obj.model_dump(mode="json", exclude_none=True),
            )
            dest_session.merge(new_event)
            events.append(new_event)
          except Exception as e:
            logger.warning(
                f"Failed to migrate event row {row.get('id', 'N/A')}: {e}"
            )
        dest_session.commit()
        logger.info(f"Migrated {len(events)} events.")
      else:
        logger.info("No 'events' table found in source database.")

      logger.info("Migration completed successfully.")
    except Exception as e:
      logger.error(f"An error occurred during migration: {e}", exc_info=True)
      dest_session.rollback()
      raise RuntimeError(f"An error occurred during migration: {e}") from e


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      description=(
          "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format."
      )
  )
  parser.add_argument(
      "--source_db_url", required=True, help="SQLAlchemy URL of source database"
  )
  parser.add_argument(
      "--dest_db_url",
      required=True,
      help="SQLAlchemy URL of destination database",
  )
  args = parser.parse_args()
  try:
    migrate(args.source_db_url, args.dest_db_url)
  except Exception as e:
    logger.error(f"Migration failed: {e}")
    sys.exit(1)
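Based on the `__main__` block above, this script can also be invoked directly, e.g. `python -m google.adk.sessions.migration.migrate_from_sqlalchemy_pickle --source_db_url <src> --dest_db_url <dst>`, although the `adk migrate session` command added in this change is the intended entry point.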
@@ -0,0 +1,128 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Migration runner to upgrade schemas to the latest version."""

from __future__ import annotations

import logging
import os
import tempfile

from google.adk.sessions.migration import _schema_check
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle

logger = logging.getLogger("google_adk." + __name__)

# Migration map where key is start_version and value is
# (end_version, migration_function).
# Each key is a schema version, and its value is a tuple containing:
# (the schema version AFTER this migration step, the migration function to run).
# The migration function should accept (source_db_url, dest_db_url) as
# arguments.
MIGRATIONS = {
    _schema_check.SCHEMA_VERSION_0_1_PICKLE: (
        _schema_check.SCHEMA_VERSION_1_0_JSON,
        migrate_from_sqlalchemy_pickle.migrate,
    ),
}
# The most recent schema version. The migration process stops once this version
# is reached.
LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
  """Migrates a database from its current version to the latest version.

  If the source database schema is older than the latest version, this
  function applies migration scripts sequentially until the schema reaches the
  LATEST_VERSION.

  If multiple migration steps are required, intermediate results are stored in
  temporary SQLite database files. This means a multi-step migration
  between other database types (e.g. PostgreSQL to PostgreSQL) will use
  SQLite for intermediate steps.

  In-place migration (source_db_url == dest_db_url) is not supported,
  as migrations always read from a source and write to a destination.

  Args:
    source_db_url: The SQLAlchemy URL of the database to migrate from.
    dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
      different from source_db_url.

  Raises:
    RuntimeError: If source_db_url and dest_db_url are the same, or if no
      migration path is found.
  """
  current_version = _schema_check.get_db_schema_version(source_db_url)

  if current_version == LATEST_VERSION:
    logger.info(
        f"Database {source_db_url} is already at latest version"
        f" {LATEST_VERSION}. No migration needed."
    )
    return

  if source_db_url == dest_db_url:
    raise RuntimeError(
        "In-place migration is not supported. "
        "Please provide a different file for dest_db_url."
    )

  # Build the list of migration steps required to reach LATEST_VERSION.
  migrations_to_run = []
  ver = current_version
  while ver in MIGRATIONS and ver != LATEST_VERSION:
    migrations_to_run.append(MIGRATIONS[ver])
    ver = MIGRATIONS[ver][0]

  if not migrations_to_run:
    raise RuntimeError(
        "Could not find migration path for schema version"
        f" {current_version} to {LATEST_VERSION}."
    )

  temp_files = []
  in_url = source_db_url
  try:
    for i, (end_version, migrate_func) in enumerate(migrations_to_run):
      is_last_step = i == len(migrations_to_run) - 1

      if is_last_step:
        out_url = dest_db_url
      else:
        # For intermediate steps, create a temporary SQLite DB to store the
        # result.
        fd, temp_path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        out_url = f"sqlite:///{temp_path}"
        temp_files.append(temp_path)
        logger.debug(f"Created temp db {out_url} for step {i+1}")

      logger.info(
          f"Migrating from {in_url} to {out_url} (schema {end_version})..."
      )
      migrate_func(in_url, out_url)
      logger.info(f"Finished migration step to schema {end_version}.")
      # The output of this step becomes the input for the next step.
      in_url = out_url
  finally:
    # Ensure temporary files are cleaned up even if migration fails.
    # Cleanup temp files
    for path in temp_files:
      try:
        os.remove(path)
        logger.debug(f"Removed temp db {path}")
      except OSError as e:
        logger.warning(f"Failed to remove temp db file {path}: {e}")
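The `adk migrate session` command shown earlier is a thin wrapper around this `upgrade` function, so the same migration can be driven programmatically. A minimal sketch, using the placeholder URLs from the commit message:

```python
from google.adk.sessions.migration import migration_runner

# Placeholder URLs; point these at your real source and destination databases.
source_db_url = "postgresql://agent:agent@localhost:5532/agent"
dest_db_url = "postgresql://agent:agent@localhost:5533/agent"

# Reads the source schema version and applies migration steps until the
# destination database is at the latest schema; raises RuntimeError on failure.
migration_runner.upgrade(source_db_url, dest_db_url)
```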
@@ -107,7 +107,7 @@ class SqliteSessionService(BaseSessionService):
          f"Database {db_path} seems to use an old schema."
          " Please run the migration command to"
          " migrate it to the new schema. Example: `python -m"
          " google.adk.sessions.migrate_from_sqlalchemy_sqlite"
          " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite"
          f" --source_db_path {db_path} --dest_db_path"
          f" {db_path}.new` then backup {db_path} and rename"
          f" {db_path}.new to {db_path}."

@@ -0,0 +1,106 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for migration scripts."""

from __future__ import annotations

from datetime import datetime
from datetime import timezone

from google.adk.events.event_actions import EventActions
from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def test_migrate_from_sqlalchemy_pickle(tmp_path):
  """Tests for migrate_from_sqlalchemy_pickle."""
  source_db_path = tmp_path / "source_pickle.db"
  dest_db_path = tmp_path / "dest_json.db"
  source_db_url = f"sqlite:///{source_db_path}"
  dest_db_url = f"sqlite:///{dest_db_path}"

  # Setup source DB with old pickle schema
  source_engine = create_engine(source_db_url)
  mfsp.OldBase.metadata.create_all(source_engine)
  SourceSession = sessionmaker(bind=source_engine)
  source_session = SourceSession()

  # Populate source data
  now = datetime.now(timezone.utc)
  app_state = mfsp.OldStorageAppState(
      app_name="app1", state={"akey": 1}, update_time=now
  )
  user_state = mfsp.OldStorageUserState(
      app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now
  )
  session = mfsp.OldStorageSession(
      app_name="app1",
      user_id="user1",
      id="session1",
      state={"skey": 3},
      create_time=now,
      update_time=now,
  )
  event = mfsp.OldStorageEvent(
      id="event1",
      app_name="app1",
      user_id="user1",
      session_id="session1",
      invocation_id="invoke1",
      author="user",
      actions=EventActions(state_delta={"skey": 4}),
      timestamp=now,
  )
  source_session.add_all([app_state, user_state, session, event])
  source_session.commit()
  source_session.close()

  mfsp.migrate(source_db_url, dest_db_url)

  # Verify destination DB
  dest_engine = create_engine(dest_db_url)
  DestSession = sessionmaker(bind=dest_engine)
  dest_session = DestSession()

  metadata = dest_session.query(dss.StorageMetadata).first()
  assert metadata is not None
  assert metadata.key == _schema_check.SCHEMA_VERSION_KEY
  assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON

  app_state_res = dest_session.query(dss.StorageAppState).first()
  assert app_state_res is not None
  assert app_state_res.app_name == "app1"
  assert app_state_res.state == {"akey": 1}

  user_state_res = dest_session.query(dss.StorageUserState).first()
  assert user_state_res is not None
  assert user_state_res.user_id == "user1"
  assert user_state_res.state == {"ukey": 2}

  session_res = dest_session.query(dss.StorageSession).first()
  assert session_res is not None
  assert session_res.id == "session1"
  assert session_res.state == {"skey": 3}

  event_res = dest_session.query(dss.StorageEvent).first()
  assert event_res is not None
  assert event_res.id == "event1"
  assert "state_delta" in event_res.event_data["actions"]
  assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}

  dest_session.close()
@@ -1,181 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import pickle
from unittest import mock

from google.adk.sessions.database_session_service import DynamicPickleType
import pytest
from sqlalchemy import create_engine
from sqlalchemy.dialects import mysql


@pytest.fixture
def pickle_type():
  """Fixture for DynamicPickleType instance."""
  return DynamicPickleType()


def test_load_dialect_impl_mysql(pickle_type):
  """Test that MySQL dialect uses LONGBLOB."""
  # Mock the MySQL dialect
  mock_dialect = mock.Mock()
  mock_dialect.name = "mysql"

  # Mock the return value of type_descriptor
  mock_longblob_type = mock.Mock()
  mock_dialect.type_descriptor.return_value = mock_longblob_type

  impl = pickle_type.load_dialect_impl(mock_dialect)

  # Verify type_descriptor was called once with mysql.LONGBLOB
  mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB)
  # Verify the return value is what we expect
  assert impl == mock_longblob_type


def test_load_dialect_impl_spanner(pickle_type):
  """Test that Spanner dialect uses SpannerPickleType."""
  # Mock the spanner dialect
  mock_dialect = mock.Mock()
  mock_dialect.name = "spanner+spanner"

  with mock.patch(
      "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType"
  ) as mock_spanner_type:
    pickle_type.load_dialect_impl(mock_dialect)
    mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type)


def test_load_dialect_impl_default(pickle_type):
  """Test that other dialects use default PickleType."""
  engine = create_engine("sqlite:///:memory:")
  dialect = engine.dialect
  impl = pickle_type.load_dialect_impl(dialect)
  # Should return the default impl (PickleType)
  assert impl == pickle_type.impl


@pytest.mark.parametrize(
    "dialect_name",
    [
        pytest.param("mysql", id="mysql"),
        pytest.param("spanner+spanner", id="spanner"),
    ],
)
def test_process_bind_param_pickle_dialects(pickle_type, dialect_name):
  """Test that MySQL and Spanner dialects pickle the value."""
  mock_dialect = mock.Mock()
  mock_dialect.name = dialect_name

  test_data = {"key": "value", "nested": [1, 2, 3]}
  result = pickle_type.process_bind_param(test_data, mock_dialect)

  # Should be pickled bytes
  assert isinstance(result, bytes)
  # Should be able to unpickle back to original
  assert pickle.loads(result) == test_data


def test_process_bind_param_default(pickle_type):
  """Test that other dialects return value as-is."""
  mock_dialect = mock.Mock()
  mock_dialect.name = "sqlite"

  test_data = {"key": "value"}
  result = pickle_type.process_bind_param(test_data, mock_dialect)

  # Should return value unchanged (SQLAlchemy's PickleType handles it)
  assert result == test_data


def test_process_bind_param_none(pickle_type):
  """Test that None values are handled correctly."""
  mock_dialect = mock.Mock()
  mock_dialect.name = "mysql"

  result = pickle_type.process_bind_param(None, mock_dialect)
  assert result is None


@pytest.mark.parametrize(
    "dialect_name",
    [
        pytest.param("mysql", id="mysql"),
        pytest.param("spanner+spanner", id="spanner"),
    ],
)
def test_process_result_value_pickle_dialects(pickle_type, dialect_name):
  """Test that MySQL and Spanner dialects unpickle the value."""
  mock_dialect = mock.Mock()
  mock_dialect.name = dialect_name

  test_data = {"key": "value", "nested": [1, 2, 3]}
  pickled_data = pickle.dumps(test_data)

  result = pickle_type.process_result_value(pickled_data, mock_dialect)

  # Should be unpickled back to original
  assert result == test_data


def test_process_result_value_default(pickle_type):
  """Test that other dialects return value as-is."""
  mock_dialect = mock.Mock()
  mock_dialect.name = "sqlite"

  test_data = {"key": "value"}
  result = pickle_type.process_result_value(test_data, mock_dialect)

  # Should return value unchanged (SQLAlchemy's PickleType handles it)
  assert result == test_data


def test_process_result_value_none(pickle_type):
  """Test that None values are handled correctly."""
  mock_dialect = mock.Mock()
  mock_dialect.name = "mysql"

  result = pickle_type.process_result_value(None, mock_dialect)
  assert result is None


@pytest.mark.parametrize(
    "dialect_name",
    [
        pytest.param("mysql", id="mysql"),
        pytest.param("spanner+spanner", id="spanner"),
    ],
)
def test_roundtrip_pickle_dialects(pickle_type, dialect_name):
  """Test full roundtrip for MySQL and Spanner: bind -> result."""
  mock_dialect = mock.Mock()
  mock_dialect.name = dialect_name

  original_data = {
      "string": "test",
      "number": 42,
      "list": [1, 2, 3],
      "nested": {"a": 1, "b": 2},
  }

  # Simulate bind (Python -> DB)
  bound_value = pickle_type.process_bind_param(original_data, mock_dialect)
  assert isinstance(bound_value, bytes)

  # Simulate result (DB -> Python)
  result_value = pickle_type.process_result_value(bound_value, mock_dialect)
  assert result_value == original_data