You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Introduce new database schema for DatabaseSessionService
Part 1 of https://github.com/google/adk-python/discussions/3605. This change adds a new schema that uses JSON serialization to store Events data in the database. A new "adk_internal_metadata" table is also added to store information like schema version. Since we want to keep supporting existing DB, we fork from the original schema and call it "v0", while the new one is called "v1". The change is no-op for existing users. In later change, the new schema will be used for new databases, and migration scripts will be provided for existing databases. Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 844986248
This commit is contained in:
committed by
Copybara-Service
parent
a0885064b0
commit
7e6ef71eec
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,8 @@ import logging
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from google.adk.sessions import database_session_service as dss
|
||||
from google.adk.sessions import sqlite_session_service as sss
|
||||
from google.adk.sessions.schemas import v0 as v0_schema
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@@ -35,7 +35,9 @@ def migrate(source_db_url: str, dest_db_path: str):
|
||||
logger.info(f"Connecting to source database: {source_db_url}")
|
||||
try:
|
||||
engine = create_engine(source_db_url)
|
||||
dss.Base.metadata.create_all(engine) # Ensure tables exist for inspection
|
||||
v0_schema.Base.metadata.create_all(
|
||||
engine
|
||||
) # Ensure tables exist for inspection
|
||||
SourceSession = sessionmaker(bind=engine)
|
||||
source_session = SourceSession()
|
||||
except Exception as e:
|
||||
@@ -55,7 +57,7 @@ def migrate(source_db_url: str, dest_db_path: str):
|
||||
try:
|
||||
# Migrate app_states
|
||||
logger.info("Migrating app_states...")
|
||||
app_states = source_session.query(dss.StorageAppState).all()
|
||||
app_states = source_session.query(v0_schema.StorageAppState).all()
|
||||
for item in app_states:
|
||||
dest_cursor.execute(
|
||||
"INSERT INTO app_states (app_name, state, update_time) VALUES (?,"
|
||||
@@ -70,7 +72,7 @@ def migrate(source_db_url: str, dest_db_path: str):
|
||||
|
||||
# Migrate user_states
|
||||
logger.info("Migrating user_states...")
|
||||
user_states = source_session.query(dss.StorageUserState).all()
|
||||
user_states = source_session.query(v0_schema.StorageUserState).all()
|
||||
for item in user_states:
|
||||
dest_cursor.execute(
|
||||
"INSERT INTO user_states (app_name, user_id, state, update_time)"
|
||||
@@ -86,7 +88,7 @@ def migrate(source_db_url: str, dest_db_path: str):
|
||||
|
||||
# Migrate sessions
|
||||
logger.info("Migrating sessions...")
|
||||
sessions = source_session.query(dss.StorageSession).all()
|
||||
sessions = source_session.query(v0_schema.StorageSession).all()
|
||||
for item in sessions:
|
||||
dest_cursor.execute(
|
||||
"INSERT INTO sessions (app_name, user_id, id, state, create_time,"
|
||||
@@ -104,7 +106,7 @@ def migrate(source_db_url: str, dest_db_path: str):
|
||||
|
||||
# Migrate events
|
||||
logger.info("Migrating events...")
|
||||
events = source_session.query(dss.StorageEvent).all()
|
||||
events = source_session.query(v0_schema.StorageEvent).all()
|
||||
for item in events:
|
||||
try:
|
||||
event_obj = item.to_event()
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import Dialect
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import mysql
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.types import DateTime
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
DEFAULT_MAX_KEY_LENGTH = 128
|
||||
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||
|
||||
|
||||
class DynamicJSON(TypeDecorator):
|
||||
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
|
||||
|
||||
impl = Text # Default implementation is TEXT
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(postgresql.JSONB)
|
||||
if dialect.name == "mysql":
|
||||
# Use LONGTEXT for MySQL to address the data too long issue
|
||||
return dialect.type_descriptor(mysql.LONGTEXT)
|
||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||
|
||||
def process_bind_param(self, value, dialect: Dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value # JSONB handles dict directly
|
||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect: Dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value # JSONB returns dict directly
|
||||
else:
|
||||
return json.loads(value) # Deserialize from JSON string for TEXT
|
||||
return value
|
||||
|
||||
|
||||
class PreciseTimestamp(TypeDecorator):
|
||||
"""Represents a timestamp precise to the microsecond."""
|
||||
|
||||
impl = DateTime
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "mysql":
|
||||
return dialect.type_descriptor(mysql.DATETIME(fsp=6))
|
||||
return self.impl
|
||||
@@ -0,0 +1,373 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""V0 database schema for ADK versions from 1.19.0 to 1.21.0.
|
||||
|
||||
This module defines SQLAlchemy models for storing session and event data
|
||||
in a relational database with the EventActions object using pickle
|
||||
serialization. To migrate from the schemas in earlier ADK versions to this
|
||||
v0 schema, see
|
||||
https://github.com/google/adk-python/blob/main/docs/upgrading_from_1_22_0.md.
|
||||
|
||||
The latest schema is defined in `v1.py`. That module uses JSON serialization
|
||||
for the EventActions data as well as other fields in the `events` table. See
|
||||
https://github.com/google/adk-python/discussions/3605 for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
import json
|
||||
import pickle
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import mysql
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import PickleType
|
||||
from sqlalchemy.types import String
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from .. import _session_util
|
||||
from ...events.event import Event
|
||||
from ...events.event_actions import EventActions
|
||||
from ..session import Session
|
||||
from .shared import DEFAULT_MAX_KEY_LENGTH
|
||||
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
|
||||
from .shared import DynamicJSON
|
||||
from .shared import PreciseTimestamp
|
||||
|
||||
|
||||
class DynamicPickleType(TypeDecorator):
|
||||
"""Represents a type that can be pickled."""
|
||||
|
||||
impl = PickleType
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "mysql":
|
||||
return dialect.type_descriptor(mysql.LONGBLOB)
|
||||
if dialect.name == "spanner+spanner":
|
||||
from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
|
||||
|
||||
return dialect.type_descriptor(SpannerPickleType)
|
||||
return self.impl
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Ensures the pickled value is a bytes object before passing it to the database dialect."""
|
||||
if value is not None:
|
||||
if dialect.name in ("spanner+spanner", "mysql"):
|
||||
return pickle.dumps(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""Ensures the raw bytes from the database are unpickled back into a Python object."""
|
||||
if value is not None:
|
||||
if dialect.name in ("spanner+spanner", "mysql"):
|
||||
return pickle.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for v0 database tables."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StorageSession(Base):
|
||||
"""Represents a session stored in the database."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
||||
create_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now()
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
storage_events: Mapped[list[StorageEvent]] = relationship(
|
||||
"StorageEvent",
|
||||
back_populates="storage_session",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
|
||||
|
||||
@property
|
||||
def _dialect_name(self) -> Optional[str]:
|
||||
session = inspect(self).session
|
||||
return session.bind.dialect.name if session else None
|
||||
|
||||
@property
|
||||
def update_timestamp_tz(self) -> datetime:
|
||||
"""Returns the time zone aware update timestamp."""
|
||||
if self._dialect_name == "sqlite":
|
||||
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
|
||||
# object without timezone information. We need to convert it to UTC
|
||||
# manually.
|
||||
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
|
||||
return self.update_time.timestamp()
|
||||
|
||||
def to_session(
|
||||
self,
|
||||
state: dict[str, Any] | None = None,
|
||||
events: list[Event] | None = None,
|
||||
) -> Session:
|
||||
"""Converts the storage session to a session object."""
|
||||
if state is None:
|
||||
state = {}
|
||||
if events is None:
|
||||
events = []
|
||||
|
||||
return Session(
|
||||
app_name=self.app_name,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
state=state,
|
||||
events=events,
|
||||
last_update_time=self.update_timestamp_tz,
|
||||
)
|
||||
|
||||
|
||||
class StorageEvent(Base):
|
||||
"""Represents an event stored in the database."""
|
||||
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
|
||||
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
branch: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||
)
|
||||
timestamp: Mapped[PreciseTimestamp] = mapped_column(
|
||||
PreciseTimestamp, default=func.now()
|
||||
)
|
||||
|
||||
# === Fields from llm_response.py ===
|
||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
|
||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
error_code: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||
)
|
||||
error_message: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
input_transcription: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
output_transcription: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
"StorageSession",
|
||||
back_populates="storage_events",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["app_name", "user_id", "session_id"],
|
||||
["sessions.app_name", "sessions.user_id", "sessions.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def long_running_tool_ids(self) -> set[str]:
|
||||
return (
|
||||
set(json.loads(self.long_running_tool_ids_json))
|
||||
if self.long_running_tool_ids_json
|
||||
else set()
|
||||
)
|
||||
|
||||
@long_running_tool_ids.setter
|
||||
def long_running_tool_ids(self, value: set[str]):
|
||||
if value is None:
|
||||
self.long_running_tool_ids_json = None
|
||||
else:
|
||||
self.long_running_tool_ids_json = json.dumps(list(value))
|
||||
|
||||
@classmethod
|
||||
def from_event(cls, session: Session, event: Event) -> StorageEvent:
|
||||
storage_event = StorageEvent(
|
||||
id=event.id,
|
||||
invocation_id=event.invocation_id,
|
||||
author=event.author,
|
||||
branch=event.branch,
|
||||
actions=event.actions,
|
||||
session_id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
long_running_tool_ids=event.long_running_tool_ids,
|
||||
partial=event.partial,
|
||||
turn_complete=event.turn_complete,
|
||||
error_code=event.error_code,
|
||||
error_message=event.error_message,
|
||||
interrupted=event.interrupted,
|
||||
)
|
||||
if event.content:
|
||||
storage_event.content = event.content.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.grounding_metadata:
|
||||
storage_event.grounding_metadata = event.grounding_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.custom_metadata:
|
||||
storage_event.custom_metadata = event.custom_metadata
|
||||
if event.usage_metadata:
|
||||
storage_event.usage_metadata = event.usage_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.citation_metadata:
|
||||
storage_event.citation_metadata = event.citation_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.input_transcription:
|
||||
storage_event.input_transcription = event.input_transcription.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.output_transcription:
|
||||
storage_event.output_transcription = (
|
||||
event.output_transcription.model_dump(exclude_none=True, mode="json")
|
||||
)
|
||||
return storage_event
|
||||
|
||||
def to_event(self) -> Event:
|
||||
return Event(
|
||||
id=self.id,
|
||||
invocation_id=self.invocation_id,
|
||||
author=self.author,
|
||||
branch=self.branch,
|
||||
# This is needed as previous ADK version pickled actions might not have
|
||||
# value defined in the current version of the EventActions model.
|
||||
actions=EventActions().model_copy(update=self.actions.model_dump()),
|
||||
timestamp=self.timestamp.timestamp(),
|
||||
long_running_tool_ids=self.long_running_tool_ids,
|
||||
partial=self.partial,
|
||||
turn_complete=self.turn_complete,
|
||||
error_code=self.error_code,
|
||||
error_message=self.error_message,
|
||||
interrupted=self.interrupted,
|
||||
custom_metadata=self.custom_metadata,
|
||||
content=_session_util.decode_model(self.content, types.Content),
|
||||
grounding_metadata=_session_util.decode_model(
|
||||
self.grounding_metadata, types.GroundingMetadata
|
||||
),
|
||||
usage_metadata=_session_util.decode_model(
|
||||
self.usage_metadata, types.GenerateContentResponseUsageMetadata
|
||||
),
|
||||
citation_metadata=_session_util.decode_model(
|
||||
self.citation_metadata, types.CitationMetadata
|
||||
),
|
||||
input_transcription=_session_util.decode_model(
|
||||
self.input_transcription, types.Transcription
|
||||
),
|
||||
output_transcription=_session_util.decode_model(
|
||||
self.output_transcription, types.Transcription
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an app state stored in the database."""
|
||||
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class StorageUserState(Base):
|
||||
"""Represents a user state stored in the database."""
|
||||
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The v1 database schema for the DatabaseSessionService.
|
||||
|
||||
This module defines SQLAlchemy models for storing session and event data
|
||||
in a relational database with the "events" table using JSON
|
||||
serialization for Event data.
|
||||
|
||||
See https://github.com/google/adk-python/discussions/3605 for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import String
|
||||
|
||||
from ...events.event import Event
|
||||
from ..session import Session
|
||||
from .shared import DEFAULT_MAX_KEY_LENGTH
|
||||
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
|
||||
from .shared import DynamicJSON
|
||||
from .shared import PreciseTimestamp
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for v1 database tables."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StorageMetadata(Base):
|
||||
"""Represents ADK internal metadata stored in the database.
|
||||
|
||||
This table is used to store internal information like the schema version.
|
||||
The DatabaseSessionService will populate and utilize this table to manage
|
||||
database compatibility and migrations.
|
||||
"""
|
||||
|
||||
__tablename__ = "adk_internal_metadata"
|
||||
key: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
|
||||
|
||||
class StorageSession(Base):
|
||||
"""Represents a session stored in the database."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
||||
create_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now()
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
storage_events: Mapped[list[StorageEvent]] = relationship(
|
||||
"StorageEvent",
|
||||
back_populates="storage_session",
|
||||
# Deleting a session will now automatically delete its associated events
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
|
||||
|
||||
@property
|
||||
def _dialect_name(self) -> Optional[str]:
|
||||
session = inspect(self).session
|
||||
return session.bind.dialect.name if session else None
|
||||
|
||||
@property
|
||||
def update_timestamp_tz(self) -> datetime:
|
||||
"""Returns the time zone aware update timestamp."""
|
||||
if self._dialect_name == "sqlite":
|
||||
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
|
||||
# object without timezone information. We need to convert it to UTC
|
||||
# manually.
|
||||
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
|
||||
return self.update_time.timestamp()
|
||||
|
||||
def to_session(
|
||||
self,
|
||||
state: dict[str, Any] | None = None,
|
||||
events: list[Event] | None = None,
|
||||
) -> Session:
|
||||
"""Converts the storage session to a session object."""
|
||||
if state is None:
|
||||
state = {}
|
||||
if events is None:
|
||||
events = []
|
||||
|
||||
return Session(
|
||||
app_name=self.app_name,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
state=state,
|
||||
events=events,
|
||||
last_update_time=self.update_timestamp_tz,
|
||||
)
|
||||
|
||||
|
||||
class StorageEvent(Base):
|
||||
"""Represents an event stored in the database."""
|
||||
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||
timestamp: Mapped[PreciseTimestamp] = mapped_column(
|
||||
PreciseTimestamp, default=func.now()
|
||||
)
|
||||
# The event_data uses JSON serialization to store the Event data, replacing
|
||||
# various fields previously used.
|
||||
event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
"StorageSession",
|
||||
back_populates="storage_events",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["app_name", "user_id", "session_id"],
|
||||
["sessions.app_name", "sessions.user_id", "sessions.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_event(cls, session: Session, event: Event) -> StorageEvent:
|
||||
"""Creates a StorageEvent from an Event."""
|
||||
return StorageEvent(
|
||||
id=event.id,
|
||||
invocation_id=event.invocation_id,
|
||||
session_id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
event_data=event.model_dump(exclude_none=True, mode="json"),
|
||||
)
|
||||
|
||||
def to_event(self) -> Event:
|
||||
"""Converts the StorageEvent to an Event."""
|
||||
return Event.model_validate({
|
||||
**self.event_data,
|
||||
"id": self.id,
|
||||
"invocation_id": self.invocation_id,
|
||||
"timestamp": self.timestamp.timestamp(),
|
||||
})
|
||||
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an app state stored in the database."""
|
||||
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class StorageUserState(Base):
|
||||
"""Represents a user state stored in the database."""
|
||||
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||
)
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[datetime] = mapped_column(
|
||||
PreciseTimestamp, default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import pickle
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.sessions.database_session_service import DynamicPickleType
|
||||
from google.adk.sessions.schemas.v0 import DynamicPickleType
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
Reference in New Issue
Block a user