From 0094eea3cadf5fe2e960cc558e467dd2131de1b7 Mon Sep 17 00:00:00 2001
From: Shangjie Chen
Date: Wed, 26 Nov 2025 19:48:34 -0800
Subject: [PATCH] feat!: Migrate DatabaseSessionService to use JSON serialization schema

Also provide a command line tool `adk migrate session` for DB migration

Addresses https://github.com/google/adk-python/discussions/3605
Addresses https://github.com/google/adk-python/issues/3681

To verify:

```
# Start a postgres DB
docker run --name my-postgres -d -e POSTGRES_DB=agent -e POSTGRES_USER=agent -e POSTGRES_PASSWORD=agent -e PGDATA=/var/lib/postgresql/data/pgdata -v pgvolume:/var/lib/postgresql/data -p 5532:5432 postgres

# Connect with an old version of ADK and produce some session data
adk web --session_service_uri=postgresql://agent:agent@localhost:5532/agent

# Check out the latest branch and restart ADK web.
# You should see an error log asking you to migrate the DB.

# Start a new DB
docker run --name migration-test-db \
  -d \
  --rm \
  -e POSTGRES_DB=agent \
  -e POSTGRES_USER=agent \
  -e POSTGRES_PASSWORD=agent \
  -e PGDATA=/var/lib/postgresql/data/pgdata \
  -v migration_test_vol:/var/lib/postgresql/data \
  -p 5533:5432 \
  postgres

# DB migration
adk migrate session \
  --source_db_url="postgresql://agent:agent@localhost:5532/agent" \
  --dest_db_url="postgresql://agent:agent@localhost:5533/agent"

# Run ADK web with the new DB
adk web --session_service_uri=postgresql+asyncpg://agent:agent@localhost:5533/agent

# You should see that the data from the old DB has been migrated
```

Co-authored-by: Shangjie Chen
PiperOrigin-RevId: 837341139
---
 src/google/adk/cli/cli_tools_click.py         |  36 ++
 .../adk/sessions/database_session_service.py  | 222 +++----
 .../adk/sessions/migration/_schema_check.py   | 114 ++++
 .../migrate_from_sqlalchemy_pickle.py         | 492 ++++++++++++++++++
 .../migrate_from_sqlalchemy_sqlite.py         |   0
 .../sessions/migration/migration_runner.py    | 128 +++++
 .../adk/sessions/sqlite_session_service.py    |   2 +-
 .../sessions/migration/test_migrations.py     | 106 ++++
 .../sessions/test_dynamic_pickle_type.py      | 181 -------
 9 files changed, 939 insertions(+), 342 deletions(-)
 create mode 100644 src/google/adk/sessions/migration/_schema_check.py
 create mode 100644 src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py
 rename src/google/adk/sessions/{ => migration}/migrate_from_sqlalchemy_sqlite.py (100%)
 create mode 100644 src/google/adk/sessions/migration/migration_runner.py
 create mode 100644 tests/unittests/sessions/migration/test_migrations.py
 delete mode 100644 tests/unittests/sessions/test_dynamic_pickle_type.py

diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
index c4a13dd1..e5194272 100644
--- a/src/google/adk/cli/cli_tools_click.py
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -36,6 +36,7 @@ from . import cli_create
 from . import cli_deploy
 from ..
import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..sessions.migration import migration_runner from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs @@ -1485,6 +1486,41 @@ def cli_deploy_cloud_run( click.secho(f"Deploy failed: {e}", fg="red", err=True) +@main.group() +def migrate(): + """Migrate ADK database schemas.""" + pass + + +@migrate.command("session", cls=HelpfulCommand) +@click.option( + "--source_db_url", + required=True, + help="SQLAlchemy URL of source database.", +) +@click.option( + "--dest_db_url", + required=True, + help="SQLAlchemy URL of destination database.", +) +@click.option( + "--log_level", + type=LOG_LEVELS, + default="INFO", + help="Optional. Set the logging level", +) +def cli_migrate_session( + *, source_db_url: str, dest_db_url: str, log_level: str +): + """Migrates a session database to the latest schema version.""" + logs.setup_adk_logger(getattr(logging, log_level.upper())) + try: + migration_runner.upgrade(source_db_url, dest_db_url) + click.secho("Migration check and upgrade process finished.", fg="green") + except Exception as e: + click.secho(f"Migration failed: {e}", fg="red", err=True) + + @deploy.command("agent_engine") @click.option( "--api_key", diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a3529182..1576151f 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -19,18 +19,16 @@ from datetime import datetime from datetime import timezone import json import logging -import pickle from typing import Any from typing import Optional import uuid -from google.genai import types -from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql @@ -41,14 +39,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime -from sqlalchemy.types import PickleType from sqlalchemy.types import String from sqlalchemy.types import TypeDecorator from typing_extensions import override @@ -57,10 +52,10 @@ from tzlocal import get_localzone from . 
import _session_util from ..errors.already_exists_error import AlreadyExistsError from ..events.event import Event -from ..events.event_actions import EventActions from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse +from .migration import _schema_check from .session import Session from .state import State @@ -111,41 +106,22 @@ class PreciseTimestamp(TypeDecorator): return self.impl -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) - return value - - class Base(DeclarativeBase): """Base class for database tables.""" pass +class StorageMetadata(Base): + """Represents internal metadata stored in the database.""" + + __tablename__ = "adk_internal_metadata" + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + + class StorageSession(Base): """Represents a session stored in the database.""" @@ -237,46 +213,10 @@ class StorageEvent(Base): ) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) - long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( - Text, nullable=True - ) - branch: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) timestamp: Mapped[PreciseTimestamp] = mapped_column( PreciseTimestamp, default=func.now() ) - - # === Fields from llm_response.py === - content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - grounding_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - custom_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - usage_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - citation_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - - partial: Mapped[bool] = mapped_column(Boolean, nullable=True) - turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - error_message: Mapped[str] = mapped_column(String(1024), nullable=True) - interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) - input_transcription: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - output_transcription: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) + event_data: 
Mapped[dict[str, Any]] = mapped_column(DynamicJSON) storage_session: Mapped[StorageSession] = relationship( "StorageSession", @@ -291,102 +231,27 @@ class StorageEvent(Base): ), ) - @property - def long_running_tool_ids(self) -> set[str]: - return ( - set(json.loads(self.long_running_tool_ids_json)) - if self.long_running_tool_ids_json - else set() - ) - - @long_running_tool_ids.setter - def long_running_tool_ids(self, value: set[str]): - if value is None: - self.long_running_tool_ids_json = None - else: - self.long_running_tool_ids_json = json.dumps(list(value)) - @classmethod def from_event(cls, session: Session, event: Event) -> StorageEvent: - storage_event = StorageEvent( + """Creates a StorageEvent from an Event.""" + return StorageEvent( id=event.id, invocation_id=event.invocation_id, - author=event.author, - branch=event.branch, - actions=event.actions, session_id=session.id, app_name=session.app_name, user_id=session.user_id, timestamp=datetime.fromtimestamp(event.timestamp), - long_running_tool_ids=event.long_running_tool_ids, - partial=event.partial, - turn_complete=event.turn_complete, - error_code=event.error_code, - error_message=event.error_message, - interrupted=event.interrupted, + event_data=event.model_dump(exclude_none=True, mode="json"), ) - if event.content: - storage_event.content = event.content.model_dump( - exclude_none=True, mode="json" - ) - if event.grounding_metadata: - storage_event.grounding_metadata = event.grounding_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.custom_metadata: - storage_event.custom_metadata = event.custom_metadata - if event.usage_metadata: - storage_event.usage_metadata = event.usage_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.citation_metadata: - storage_event.citation_metadata = event.citation_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.input_transcription: - storage_event.input_transcription = event.input_transcription.model_dump( - exclude_none=True, mode="json" - ) - if event.output_transcription: - storage_event.output_transcription = ( - event.output_transcription.model_dump(exclude_none=True, mode="json") - ) - return storage_event def to_event(self) -> Event: - return Event( - id=self.id, - invocation_id=self.invocation_id, - author=self.author, - branch=self.branch, - # This is needed as previous ADK version pickled actions might not have - # value defined in the current version of the EventActions model. 
- actions=EventActions().model_copy(update=self.actions.model_dump()), - timestamp=self.timestamp.timestamp(), - long_running_tool_ids=self.long_running_tool_ids, - partial=self.partial, - turn_complete=self.turn_complete, - error_code=self.error_code, - error_message=self.error_message, - interrupted=self.interrupted, - custom_metadata=self.custom_metadata, - content=_session_util.decode_model(self.content, types.Content), - grounding_metadata=_session_util.decode_model( - self.grounding_metadata, types.GroundingMetadata - ), - usage_metadata=_session_util.decode_model( - self.usage_metadata, types.GenerateContentResponseUsageMetadata - ), - citation_metadata=_session_util.decode_model( - self.citation_metadata, types.CitationMetadata - ), - input_transcription=_session_util.decode_model( - self.input_transcription, types.Transcription - ), - output_transcription=_session_util.decode_model( - self.output_transcription, types.Transcription - ), - ) + """Converts the StorageEvent to an Event.""" + return Event.model_validate({ + **self.event_data, + "id": self.id, + "invocation_id": self.invocation_id, + "timestamp": self.timestamp.timestamp(), + }) class StorageAppState(Base): @@ -463,7 +328,6 @@ class DatabaseSessionService(BaseSessionService): logger.info("Local timezone: %s", local_timezone) self.db_engine: AsyncEngine = db_engine - self.metadata: MetaData = MetaData() # DB session factory method self.database_session_factory: async_sessionmaker[ @@ -483,10 +347,46 @@ class DatabaseSessionService(BaseSessionService): async with self._table_creation_lock: # Double-check after acquiring the lock if not self._tables_created: + # Check schema version BEFORE creating tables. + # This prevents creating metadata table on a v0.1 DB. + async with self.database_session_factory() as sql_session: + version, is_v01 = await sql_session.run_sync( + _schema_check.get_version_and_v01_status_sync + ) + + if is_v01: + raise RuntimeError( + "Database schema appears to be v0.1, but" + f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please" + " migrate the database using 'adk migrate session'." + ) + elif version and version < _schema_check.CURRENT_SCHEMA_VERSION: + raise RuntimeError( + f"Database schema version is {version}, but current version is" + f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate" + " the database to the latest version using 'adk migrate" + " session'." + ) + async with self.db_engine.begin() as conn: # Uncomment to recreate DB every time # await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) + + # If we are here, DB is either new or >= current version. + # If new or without metadata row, stamp it as current version. 
+ async with self.database_session_factory() as sql_session: + metadata = await sql_session.get( + StorageMetadata, _schema_check.SCHEMA_VERSION_KEY + ) + if not metadata: + sql_session.add( + StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value=_schema_check.CURRENT_SCHEMA_VERSION, + ) + ) + await sql_session.commit() self._tables_created = True @override @@ -723,7 +623,9 @@ class DatabaseSessionService(BaseSessionService): storage_session.state = storage_session.state | session_state_delta if storage_session._dialect_name == "sqlite": - update_time = datetime.utcfromtimestamp(event.timestamp) + update_time = datetime.fromtimestamp( + event.timestamp, timezone.utc + ).replace(tzinfo=None) else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time diff --git a/src/google/adk/sessions/migration/_schema_check.py b/src/google/adk/sessions/migration/_schema_check.py new file mode 100644 index 00000000..f6fdc599 --- /dev/null +++ b/src/google/adk/sessions/migration/_schema_check.py @@ -0,0 +1,114 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Database schema version check utility.""" + +from __future__ import annotations + +import logging + +import sqlalchemy +from sqlalchemy import create_engine as create_sync_engine +from sqlalchemy import inspect +from sqlalchemy import text + +logger = logging.getLogger("google_adk." + __name__) + +SCHEMA_VERSION_KEY = "schema_version" +SCHEMA_VERSION_0_1_PICKLE = "0.1" +SCHEMA_VERSION_1_0_JSON = "1.0" +CURRENT_SCHEMA_VERSION = "1.0" + + +def _to_sync_url(db_url: str) -> str: + """Removes +driver from SQLAlchemy URL.""" + if "://" in db_url: + scheme, _, rest = db_url.partition("://") + if "+" in scheme: + dialect = scheme.split("+", 1)[0] + return f"{dialect}://{rest}" + return db_url + + +def get_version_and_v01_status_sync( + sess: sqlalchemy.orm.Session, +) -> tuple[str | None, bool]: + """Returns (version, is_v01) inspecting the database.""" + inspector = sqlalchemy.inspect(sess.get_bind()) + if inspector.has_table("adk_internal_metadata"): + try: + result = sess.execute( + text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + {"key": SCHEMA_VERSION_KEY}, + ).fetchone() + # If table exists, with or without key, it's 1.0 or newer. + return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False + except Exception as e: + logger.warning( + "Could not read from adk_internal_metadata: %s. Assuming v1.0.", + e, + ) + return SCHEMA_VERSION_1_0_JSON, False + + if inspector.has_table("events"): + try: + cols = {c["name"] for c in inspector.get_columns("events")} + if "actions" in cols and "event_data" not in cols: + return None, True # 0.1 schema + except Exception as e: + logger.warning("Could not inspect 'events' table columns: %s", e) + return None, False # New DB + + +def get_db_schema_version(db_url: str) -> str | None: + """Reads schema version from DB. + + Checks metadata table first, falls back to table structure for 0.1 vs 1.0. 
+ """ + engine = None + try: + engine = create_sync_engine(_to_sync_url(db_url)) + inspector = inspect(engine) + + if inspector.has_table("adk_internal_metadata"): + with engine.connect() as connection: + result = connection.execute( + text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + parameters={"key": SCHEMA_VERSION_KEY}, + ).fetchone() + # If table exists, with or without key, it's 1.0 or newer. + return result[0] if result else SCHEMA_VERSION_1_0_JSON + + # Metadata table doesn't exist, check for 0.1 schema. + # 0.1 schema has an 'events' table with an 'actions' column. + if inspector.has_table("events"): + try: + cols = {c["name"] for c in inspector.get_columns("events")} + if "actions" in cols and "event_data" not in cols: + return SCHEMA_VERSION_0_1_PICKLE + except Exception as e: + logger.warning("Could not inspect 'events' table columns: %s", e) + + # If no metadata table and not identifiable as 0.1, + # assume it is a new/empty DB requiring schema 1.0. + return SCHEMA_VERSION_1_0_JSON + except Exception as e: + logger.info( + "Could not determine schema version by inspecting database: %s." + " Assuming v1.0.", + e, + ) + return SCHEMA_VERSION_1_0_JSON + finally: + if engine: + engine.dispose() diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py new file mode 100644 index 00000000..f33ef3f5 --- /dev/null +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -0,0 +1,492 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema.""" + +from __future__ import annotations + +import argparse +from datetime import datetime +from datetime import timezone +import json +import logging +import pickle +import sys +from typing import Any +from typing import Optional + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions import _session_util +from google.adk.sessions import database_session_service as dss +from google.adk.sessions.migration import _schema_check +from google.genai import types +import sqlalchemy +from sqlalchemy import Boolean +from sqlalchemy import create_engine +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func +from sqlalchemy import text +from sqlalchemy import Text +from sqlalchemy.dialects import mysql +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import PickleType +from sqlalchemy.types import String +from sqlalchemy.types import TypeDecorator + +logger = logging.getLogger("google_adk." 
+ __name__) + + +# --- Old Schema Definitions --- +class DynamicPickleType(TypeDecorator): + """Represents a type that can be pickled.""" + + impl = PickleType + + def load_dialect_impl(self, dialect): + if dialect.name == "mysql": + return dialect.type_descriptor(mysql.LONGBLOB) + if dialect.name == "spanner+spanner": + from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType + + return dialect.type_descriptor(SpannerPickleType) + return self.impl + + def process_bind_param(self, value, dialect): + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.dumps(value) + return value + + def process_result_value(self, value, dialect): + """Ensures the raw bytes from the database are unpickled back into a Python object.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.loads(value) + return value + + +class OldBase(DeclarativeBase): + pass + + +class OldStorageSession(OldBase): + __tablename__ = "sessions" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + create_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now() + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +class OldStorageEvent(OldBase): + """Old storage event with pickle.""" + + __tablename__ = "events" + id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + invocation_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH) + ) + author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH)) + actions: Mapped[Any] = mapped_column(DynamicPickleType) + long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( + Text, nullable=True + ) + branch: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + timestamp: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now() + ) + content: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + grounding_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + custom_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + usage_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + citation_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + partial: Mapped[bool] = mapped_column(Boolean, nullable=True) + turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) + error_code: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) + interrupted: 
Mapped[bool] = mapped_column(Boolean, nullable=True) + input_transcription: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + output_transcription: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + __table_args__ = ( + ForeignKeyConstraint( + ["app_name", "user_id", "session_id"], + ["sessions.app_name", "sessions.user_id", "sessions.id"], + ondelete="CASCADE", + ), + ) + + @property + def long_running_tool_ids(self) -> set[str]: + return ( + set(json.loads(self.long_running_tool_ids_json)) + if self.long_running_tool_ids_json + else set() + ) + + +def _to_datetime_obj(val: Any) -> datetime | Any: + """Converts string to datetime if needed.""" + if isinstance(val, str): + try: + return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f") + except ValueError: + try: + return datetime.strptime(val, "%Y-%m-%d %H:%M:%S") + except ValueError: + pass # return as is if not matching format + return val + + +def _row_to_event(row: dict) -> Event: + """Converts event row (dict) to event object, handling missing columns and deserializing.""" + + actions_val = row.get("actions") + actions = None + if actions_val is not None: + try: + if isinstance(actions_val, bytes): + actions = pickle.loads(actions_val) + else: # for spanner - it might return object directly + actions = actions_val + except Exception as e: + logger.warning( + f"Failed to unpickle actions for event {row.get('id')}: {e}" + ) + actions = None + + if actions and hasattr(actions, "model_dump"): + actions = EventActions().model_copy(update=actions.model_dump()) + elif isinstance(actions, dict): + actions = EventActions(**actions) + else: + actions = EventActions() + + def _safe_json_load(val): + data = None + if isinstance(val, str): + try: + data = json.loads(val) + except json.JSONDecodeError: + logger.warning(f"Failed to decode JSON for event {row.get('id')}") + return None + elif isinstance(val, dict): + data = val # for postgres JSONB + return data + + content_dict = _safe_json_load(row.get("content")) + grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata")) + custom_metadata_dict = _safe_json_load(row.get("custom_metadata")) + usage_metadata_dict = _safe_json_load(row.get("usage_metadata")) + citation_metadata_dict = _safe_json_load(row.get("citation_metadata")) + input_transcription_dict = _safe_json_load(row.get("input_transcription")) + output_transcription_dict = _safe_json_load(row.get("output_transcription")) + + long_running_tool_ids_json = row.get("long_running_tool_ids_json") + long_running_tool_ids = set() + if long_running_tool_ids_json: + try: + long_running_tool_ids = set(json.loads(long_running_tool_ids_json)) + except json.JSONDecodeError: + logger.warning( + "Failed to decode long_running_tool_ids_json for event" + f" {row.get('id')}" + ) + long_running_tool_ids = set() + + event_id = row.get("id") + if not event_id: + raise ValueError("Event must have an id.") + timestamp = _to_datetime_obj(row.get("timestamp")) + if not timestamp: + raise ValueError(f"Event {event_id} must have a timestamp.") + + return Event( + id=event_id, + invocation_id=row.get("invocation_id", ""), + author=row.get("author", "agent"), + branch=row.get("branch"), + actions=actions, + timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(), + long_running_tool_ids=long_running_tool_ids, + partial=row.get("partial"), + turn_complete=row.get("turn_complete"), + error_code=row.get("error_code"), + error_message=row.get("error_message"), + interrupted=row.get("interrupted"), 
+ custom_metadata=custom_metadata_dict, + content=_session_util.decode_model(content_dict, types.Content), + grounding_metadata=_session_util.decode_model( + grounding_metadata_dict, types.GroundingMetadata + ), + usage_metadata=_session_util.decode_model( + usage_metadata_dict, types.GenerateContentResponseUsageMetadata + ), + citation_metadata=_session_util.decode_model( + citation_metadata_dict, types.CitationMetadata + ), + input_transcription=_session_util.decode_model( + input_transcription_dict, types.Transcription + ), + output_transcription=_session_util.decode_model( + output_transcription_dict, types.Transcription + ), + ) + + +def _get_state_dict(state_val: Any) -> dict: + """Safely load dict from JSON string or return dict if already dict.""" + if isinstance(state_val, dict): + return state_val + if isinstance(state_val, str): + try: + return json.loads(state_val) + except json.JSONDecodeError: + logger.warning( + "Failed to parse state JSON string, defaulting to empty dict." + ) + return {} + return {} + + +class OldStorageAppState(OldBase): + __tablename__ = "app_states" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +class OldStorageUserState(OldBase): + __tablename__ = "user_states" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +# --- Migration Logic --- +def migrate(source_db_url: str, dest_db_url: str): + """Migrates data from old pickle schema to new JSON schema.""" + logger.info(f"Connecting to source database: {source_db_url}") + try: + source_engine = create_engine(source_db_url) + SourceSession = sessionmaker(bind=source_engine) + except Exception as e: + logger.error(f"Failed to connect to source database: {e}") + raise RuntimeError(f"Failed to connect to source database: {e}") from e + + logger.info(f"Connecting to destination database: {dest_db_url}") + try: + dest_engine = create_engine(dest_db_url) + dss.Base.metadata.create_all(dest_engine) + DestSession = sessionmaker(bind=dest_engine) + except Exception as e: + logger.error(f"Failed to connect to destination database: {e}") + raise RuntimeError(f"Failed to connect to destination database: {e}") from e + + with SourceSession() as source_session, DestSession() as dest_session: + dest_session.merge( + dss.StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value=_schema_check.SCHEMA_VERSION_1_0_JSON, + ) + ) + dest_session.commit() + try: + inspector = sqlalchemy.inspect(source_engine) + + logger.info("Migrating app_states...") + if inspector.has_table("app_states"): + rows = ( + source_session.execute(text("SELECT * FROM app_states")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageAppState( + app_name=row["app_name"], + state=_get_state_dict(row.get("state")), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} app_states.") + else: + 
logger.info("No 'app_states' table found in source db.") + + logger.info("Migrating user_states...") + if inspector.has_table("user_states"): + rows = ( + source_session.execute(text("SELECT * FROM user_states")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageUserState( + app_name=row["app_name"], + user_id=row["user_id"], + state=_get_state_dict(row.get("state")), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} user_states.") + else: + logger.info("No 'user_states' table found in source db.") + + logger.info("Migrating sessions...") + if inspector.has_table("sessions"): + rows = ( + source_session.execute(text("SELECT * FROM sessions")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageSession( + app_name=row["app_name"], + user_id=row["user_id"], + id=row["id"], + state=_get_state_dict(row.get("state")), + create_time=_to_datetime_obj(row["create_time"]), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} sessions.") + else: + logger.info("No 'sessions' table found in source db.") + + logger.info("Migrating events...") + events = [] + if inspector.has_table("events"): + rows = ( + source_session.execute(text("SELECT * FROM events")) + .mappings() + .all() + ) + for row in rows: + try: + event_obj = _row_to_event(dict(row)) + new_event = dss.StorageEvent( + id=event_obj.id, + app_name=row["app_name"], + user_id=row["user_id"], + session_id=row["session_id"], + invocation_id=event_obj.invocation_id, + timestamp=datetime.fromtimestamp( + event_obj.timestamp, timezone.utc + ).replace(tzinfo=None), + event_data=event_obj.model_dump(mode="json", exclude_none=True), + ) + dest_session.merge(new_event) + events.append(new_event) + except Exception as e: + logger.warning( + f"Failed to migrate event row {row.get('id', 'N/A')}: {e}" + ) + dest_session.commit() + logger.info(f"Migrated {len(events)} events.") + else: + logger.info("No 'events' table found in source database.") + + logger.info("Migration completed successfully.") + except Exception as e: + logger.error(f"An error occurred during migration: {e}", exc_info=True) + dest_session.rollback() + raise RuntimeError(f"An error occurred during migration: {e}") from e + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format." 
+  )
+  parser.add_argument(
+      "--source_db_url", required=True, help="SQLAlchemy URL of source database."
+  )
+  parser.add_argument(
+      "--dest_db_url",
+      required=True,
+      help="SQLAlchemy URL of destination database.",
+  )
+  args = parser.parse_args()
+  try:
+    migrate(args.source_db_url, args.dest_db_url)
+  except Exception as e:
+    logger.error(f"Migration failed: {e}")
+    sys.exit(1)
diff --git a/src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
similarity index 100%
rename from src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py
rename to src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py
new file mode 100644
index 00000000..d7abbe41
--- /dev/null
+++ b/src/google/adk/sessions/migration/migration_runner.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Migration runner to upgrade schemas to the latest version."""
+
+from __future__ import annotations
+
+import logging
+import os
+import tempfile
+
+from google.adk.sessions.migration import _schema_check
+from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
+
+logger = logging.getLogger("google_adk." + __name__)
+
+# Migration map. Each key is a starting schema version, and its value is a
+# tuple of (the schema version AFTER this migration step, the migration
+# function to run). Each migration function accepts
+# (source_db_url, dest_db_url) as arguments.
+MIGRATIONS = {
+    _schema_check.SCHEMA_VERSION_0_1_PICKLE: (
+        _schema_check.SCHEMA_VERSION_1_0_JSON,
+        migrate_from_sqlalchemy_pickle.migrate,
+    ),
+}
+# The most recent schema version. The migration process stops once this
+# version is reached.
+LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION
+
+
+def upgrade(source_db_url: str, dest_db_url: str):
+  """Migrates a database from its current version to the latest version.
+
+  If the source database schema is older than the latest version, this
+  function applies migration scripts sequentially until the schema reaches
+  LATEST_VERSION.
+
+  If multiple migration steps are required, intermediate results are stored
+  in temporary SQLite database files. This means a multi-step migration
+  between other database types (e.g. PostgreSQL to PostgreSQL) will use
+  SQLite for intermediate steps.
+
+  In-place migration (source_db_url == dest_db_url) is not supported,
+  as migrations always read from a source and write to a destination.
+
+  Args:
+    source_db_url: The SQLAlchemy URL of the database to migrate from.
+    dest_db_url: The SQLAlchemy URL of the database to migrate to. This must
+      be different from source_db_url.
+
+  Raises:
+    RuntimeError: If source_db_url and dest_db_url are the same, or if no
+      migration path is found.
+ """ + current_version = _schema_check.get_db_schema_version(source_db_url) + + if current_version == LATEST_VERSION: + logger.info( + f"Database {source_db_url} is already at latest version" + f" {LATEST_VERSION}. No migration needed." + ) + return + + if source_db_url == dest_db_url: + raise RuntimeError( + "In-place migration is not supported. " + "Please provide a different file for dest_db_url." + ) + + # Build the list of migration steps required to reach LATEST_VERSION. + migrations_to_run = [] + ver = current_version + while ver in MIGRATIONS and ver != LATEST_VERSION: + migrations_to_run.append(MIGRATIONS[ver]) + ver = MIGRATIONS[ver][0] + + if not migrations_to_run: + raise RuntimeError( + "Could not find migration path for schema version" + f" {current_version} to {LATEST_VERSION}." + ) + + temp_files = [] + in_url = source_db_url + try: + for i, (end_version, migrate_func) in enumerate(migrations_to_run): + is_last_step = i == len(migrations_to_run) - 1 + + if is_last_step: + out_url = dest_db_url + else: + # For intermediate steps, create a temporary SQLite DB to store the + # result. + fd, temp_path = tempfile.mkstemp(suffix=".db") + os.close(fd) + out_url = f"sqlite:///{temp_path}" + temp_files.append(temp_path) + logger.debug(f"Created temp db {out_url} for step {i+1}") + + logger.info( + f"Migrating from {in_url} to {out_url} (schema {end_version})..." + ) + migrate_func(in_url, out_url) + logger.info(f"Finished migration step to schema {end_version}.") + # The output of this step becomes the input for the next step. + in_url = out_url + finally: + # Ensure temporary files are cleaned up even if migration fails. + # Cleanup temp files + for path in temp_files: + try: + os.remove(path) + logger.debug(f"Removed temp db {path}") + except OSError as e: + logger.warning(f"Failed to remove temp db file {path}: {e}") diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 8ba6531f..e0d44b38 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -107,7 +107,7 @@ class SqliteSessionService(BaseSessionService): f"Database {db_path} seems to use an old schema." " Please run the migration command to" " migrate it to the new schema. Example: `python -m" - " google.adk.sessions.migrate_from_sqlalchemy_sqlite" + " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" f" --source_db_path {db_path} --dest_db_path" f" {db_path}.new` then backup {db_path} and rename" f" {db_path}.new to {db_path}." diff --git a/tests/unittests/sessions/migration/test_migrations.py b/tests/unittests/sessions/migration/test_migrations.py new file mode 100644 index 00000000..938387d2 --- /dev/null +++ b/tests/unittests/sessions/migration/test_migrations.py @@ -0,0 +1,106 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for migration scripts.""" + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone + +from google.adk.events.event_actions import EventActions +from google.adk.sessions import database_session_service as dss +from google.adk.sessions.migration import _schema_check +from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +def test_migrate_from_sqlalchemy_pickle(tmp_path): + """Tests for migrate_from_sqlalchemy_pickle.""" + source_db_path = tmp_path / "source_pickle.db" + dest_db_path = tmp_path / "dest_json.db" + source_db_url = f"sqlite:///{source_db_path}" + dest_db_url = f"sqlite:///{dest_db_path}" + + # Setup source DB with old pickle schema + source_engine = create_engine(source_db_url) + mfsp.OldBase.metadata.create_all(source_engine) + SourceSession = sessionmaker(bind=source_engine) + source_session = SourceSession() + + # Populate source data + now = datetime.now(timezone.utc) + app_state = mfsp.OldStorageAppState( + app_name="app1", state={"akey": 1}, update_time=now + ) + user_state = mfsp.OldStorageUserState( + app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now + ) + session = mfsp.OldStorageSession( + app_name="app1", + user_id="user1", + id="session1", + state={"skey": 3}, + create_time=now, + update_time=now, + ) + event = mfsp.OldStorageEvent( + id="event1", + app_name="app1", + user_id="user1", + session_id="session1", + invocation_id="invoke1", + author="user", + actions=EventActions(state_delta={"skey": 4}), + timestamp=now, + ) + source_session.add_all([app_state, user_state, session, event]) + source_session.commit() + source_session.close() + + mfsp.migrate(source_db_url, dest_db_url) + + # Verify destination DB + dest_engine = create_engine(dest_db_url) + DestSession = sessionmaker(bind=dest_engine) + dest_session = DestSession() + + metadata = dest_session.query(dss.StorageMetadata).first() + assert metadata is not None + assert metadata.key == _schema_check.SCHEMA_VERSION_KEY + assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON + + app_state_res = dest_session.query(dss.StorageAppState).first() + assert app_state_res is not None + assert app_state_res.app_name == "app1" + assert app_state_res.state == {"akey": 1} + + user_state_res = dest_session.query(dss.StorageUserState).first() + assert user_state_res is not None + assert user_state_res.user_id == "user1" + assert user_state_res.state == {"ukey": 2} + + session_res = dest_session.query(dss.StorageSession).first() + assert session_res is not None + assert session_res.id == "session1" + assert session_res.state == {"skey": 3} + + event_res = dest_session.query(dss.StorageEvent).first() + assert event_res is not None + assert event_res.id == "event1" + assert "state_delta" in event_res.event_data["actions"] + assert event_res.event_data["actions"]["state_delta"] == {"skey": 4} + + dest_session.close() diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py deleted file mode 100644 index e4eb084f..00000000 --- a/tests/unittests/sessions/test_dynamic_pickle_type.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import pickle -from unittest import mock - -from google.adk.sessions.database_session_service import DynamicPickleType -import pytest -from sqlalchemy import create_engine -from sqlalchemy.dialects import mysql - - -@pytest.fixture -def pickle_type(): - """Fixture for DynamicPickleType instance.""" - return DynamicPickleType() - - -def test_load_dialect_impl_mysql(pickle_type): - """Test that MySQL dialect uses LONGBLOB.""" - # Mock the MySQL dialect - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - # Mock the return value of type_descriptor - mock_longblob_type = mock.Mock() - mock_dialect.type_descriptor.return_value = mock_longblob_type - - impl = pickle_type.load_dialect_impl(mock_dialect) - - # Verify type_descriptor was called once with mysql.LONGBLOB - mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB) - # Verify the return value is what we expect - assert impl == mock_longblob_type - - -def test_load_dialect_impl_spanner(pickle_type): - """Test that Spanner dialect uses SpannerPickleType.""" - # Mock the spanner dialect - mock_dialect = mock.Mock() - mock_dialect.name = "spanner+spanner" - - with mock.patch( - "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType" - ) as mock_spanner_type: - pickle_type.load_dialect_impl(mock_dialect) - mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type) - - -def test_load_dialect_impl_default(pickle_type): - """Test that other dialects use default PickleType.""" - engine = create_engine("sqlite:///:memory:") - dialect = engine.dialect - impl = pickle_type.load_dialect_impl(dialect) - # Should return the default impl (PickleType) - assert impl == pickle_type.impl - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def test_process_bind_param_pickle_dialects(pickle_type, dialect_name): - """Test that MySQL and Spanner dialects pickle the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - test_data = {"key": "value", "nested": [1, 2, 3]} - result = pickle_type.process_bind_param(test_data, mock_dialect) - - # Should be pickled bytes - assert isinstance(result, bytes) - # Should be able to unpickle back to original - assert pickle.loads(result) == test_data - - -def test_process_bind_param_default(pickle_type): - """Test that other dialects return value as-is.""" - mock_dialect = mock.Mock() - mock_dialect.name = "sqlite" - - test_data = {"key": "value"} - result = pickle_type.process_bind_param(test_data, mock_dialect) - - # Should return value unchanged (SQLAlchemy's PickleType handles it) - assert result == test_data - - -def test_process_bind_param_none(pickle_type): - """Test that None values are handled correctly.""" - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - result = pickle_type.process_bind_param(None, mock_dialect) - assert result is None - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def 
test_process_result_value_pickle_dialects(pickle_type, dialect_name): - """Test that MySQL and Spanner dialects unpickle the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - test_data = {"key": "value", "nested": [1, 2, 3]} - pickled_data = pickle.dumps(test_data) - - result = pickle_type.process_result_value(pickled_data, mock_dialect) - - # Should be unpickled back to original - assert result == test_data - - -def test_process_result_value_default(pickle_type): - """Test that other dialects return value as-is.""" - mock_dialect = mock.Mock() - mock_dialect.name = "sqlite" - - test_data = {"key": "value"} - result = pickle_type.process_result_value(test_data, mock_dialect) - - # Should return value unchanged (SQLAlchemy's PickleType handles it) - assert result == test_data - - -def test_process_result_value_none(pickle_type): - """Test that None values are handled correctly.""" - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - result = pickle_type.process_result_value(None, mock_dialect) - assert result is None - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def test_roundtrip_pickle_dialects(pickle_type, dialect_name): - """Test full roundtrip for MySQL and Spanner: bind -> result.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - original_data = { - "string": "test", - "number": 42, - "list": [1, 2, 3], - "nested": {"a": 1, "b": 2}, - } - - # Simulate bind (Python -> DB) - bound_value = pickle_type.process_bind_param(original_data, mock_dialect) - assert isinstance(bound_value, bytes) - - # Simulate result (DB -> Python) - result_value = pickle_type.process_result_value(bound_value, mock_dialect) - assert result_value == original_data
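
The heart of this change is that `StorageEvent` no longer maps one column per `Event` field; everything except the indexed keys now lives in a single `event_data` JSON column. A minimal sketch of the round-trip implemented by `from_event`/`to_event` above (illustrative only; assumes the `google-adk` package is importable):

```
from google.adk.events.event import Event

event = Event(author="user", invocation_id="inv-1")

# Write path (from_event): the whole event becomes one JSON-safe dict.
event_data = event.model_dump(exclude_none=True, mode="json")

# Read path (to_event): the indexed columns take precedence over the blob.
restored = Event.model_validate({
    **event_data,
    "id": event.id,
    "invocation_id": event.invocation_id,
    "timestamp": event.timestamp,
})
assert restored.author == event.author
```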
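Schema detection in `_schema_check` keys off table shape rather than a version history: an `adk_internal_metadata` table means v1.0 or newer, while an `events` table with an `actions` column but no `event_data` column marks the v0.1 pickle schema. A sketch of probing a database before upgrading (`old_sessions.db` is a placeholder path, and `_schema_check` is a private module, so this is for illustration only):

```
from google.adk.sessions.migration import _schema_check

version = _schema_check.get_db_schema_version("sqlite:///old_sessions.db")
if version != _schema_check.CURRENT_SCHEMA_VERSION:
  print(f"Schema {version} detected; run 'adk migrate session' to upgrade.")
```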
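`migration_runner.upgrade` treats `MIGRATIONS` as a chain: each entry maps a start version to its successor version and migration function, so a future schema bump only needs one new entry and the runner composes the steps. A simplified model of that walk (toy versions and a no-op migration function, not the real registry):

```
MIGRATIONS = {"0.1": ("1.0", lambda src, dst: None)}
LATEST_VERSION = "1.0"


def plan_steps(current: str) -> list[str]:
  # Mirrors the while-loop in upgrade(): follow the chain until the
  # latest version is reached.
  steps, ver = [], current
  while ver in MIGRATIONS and ver != LATEST_VERSION:
    end_version, _ = MIGRATIONS[ver]
    steps.append(f"{ver} -> {end_version}")
    ver = end_version
  return steps


print(plan_steps("0.1"))  # ['0.1 -> 1.0']
```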