fix: Add usage_metadata and citation_metadata to DatabaseSessionService

PiperOrigin-RevId: 819900773
This commit is contained in:
Shangjie Chen
2025-10-15 13:54:52 -07:00
committed by Copybara-Service
parent 2424d6a3b1
commit 6ab1498aa0
4 changed files with 44 additions and 24 deletions
+9 -16
View File
@@ -16,23 +16,16 @@ from __future__ import annotations
from typing import Any
from typing import Optional
from typing import Type
from typing import TypeVar
from google.genai import types
M = TypeVar("M")
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
def decode_model(
data: Optional[dict[str, Any]], model_cls: Type[M]
) -> Optional[M]:
"""Decodes a pydantic model object from a JSON dictionary."""
if data is None:
return None
return types.Content.model_validate(content)
def decode_grounding_metadata(
grounding_metadata: Optional[dict[str, Any]],
) -> Optional[types.GroundingMetadata]:
"""Decodes a grounding metadata object from a JSON dictionary."""
if not grounding_metadata:
return None
return types.GroundingMetadata.model_validate(grounding_metadata)
return model_cls.model_validate(data)
@@ -23,6 +23,7 @@ from typing import Any
from typing import Optional
import uuid
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
@@ -252,6 +253,12 @@ class StorageEvent(Base):
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
@@ -318,6 +325,14 @@ class StorageEvent(Base):
)
if event.custom_metadata:
storage_event.custom_metadata = event.custom_metadata
if event.usage_metadata:
storage_event.usage_metadata = event.usage_metadata.model_dump(
exclude_none=True, mode="json"
)
if event.citation_metadata:
storage_event.citation_metadata = event.citation_metadata.model_dump(
exclude_none=True, mode="json"
)
return storage_event
def to_event(self) -> Event:
@@ -328,17 +343,23 @@ class StorageEvent(Base):
branch=self.branch,
actions=self.actions,
timestamp=self.timestamp.timestamp(),
content=_session_util.decode_content(self.content),
long_running_tool_ids=self.long_running_tool_ids,
partial=self.partial,
turn_complete=self.turn_complete,
error_code=self.error_code,
error_message=self.error_message,
interrupted=self.interrupted,
grounding_metadata=_session_util.decode_grounding_metadata(
self.grounding_metadata
),
custom_metadata=self.custom_metadata,
content=_session_util.decode_model(self.content, types.Content),
grounding_metadata=_session_util.decode_model(
self.grounding_metadata, types.GroundingMetadata
),
usage_metadata=_session_util.decode_model(
self.usage_metadata, types.GenerateContentResponseUsageMetadata
),
citation_metadata=_session_util.decode_model(
self.citation_metadata, types.CitationMetadata
),
)
@@ -376,8 +376,9 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
interrupted = getattr(event_metadata, 'interrupted', None)
branch = getattr(event_metadata, 'branch', None)
custom_metadata = getattr(event_metadata, 'custom_metadata', None)
grounding_metadata = _session_util.decode_grounding_metadata(
getattr(event_metadata, 'grounding_metadata', None)
grounding_metadata = _session_util.decode_model(
getattr(event_metadata, 'grounding_metadata', None),
types.GroundingMetadata,
)
else:
long_running_tool_ids = None
@@ -393,8 +394,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
invocation_id=api_event_obj.invocation_id,
author=api_event_obj.author,
actions=event_actions,
content=_session_util.decode_content(
getattr(api_event_obj, 'content', None)
content=_session_util.decode_model(
getattr(api_event_obj, 'content', None), types.Content
),
timestamp=api_event_obj.timestamp.timestamp(),
error_code=getattr(api_event_obj, 'error_code', None),
@@ -390,6 +390,11 @@ async def test_append_event_complete(service_type):
error_code='error_code',
error_message='error_message',
interrupted=True,
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=1, candidates_token_count=1, total_token_count=2
),
citation_metadata=types.CitationMetadata(),
custom_metadata={'custom_key': 'custom_value'},
)
await session_service.append_event(session=session, event=event)