You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Add usage_metadata and citation_metadata to DatabaseSessionService
PiperOrigin-RevId: 819900773
This commit is contained in:
committed by
Copybara-Service
parent
2424d6a3b1
commit
6ab1498aa0
@@ -16,23 +16,16 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from google.genai import types
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
def decode_content(
|
||||
content: Optional[dict[str, Any]],
|
||||
) -> Optional[types.Content]:
|
||||
"""Decodes a content object from a JSON dictionary."""
|
||||
if not content:
|
||||
def decode_model(
|
||||
data: Optional[dict[str, Any]], model_cls: Type[M]
|
||||
) -> Optional[M]:
|
||||
"""Decodes a pydantic model object from a JSON dictionary."""
|
||||
if data is None:
|
||||
return None
|
||||
return types.Content.model_validate(content)
|
||||
|
||||
|
||||
def decode_grounding_metadata(
|
||||
grounding_metadata: Optional[dict[str, Any]],
|
||||
) -> Optional[types.GroundingMetadata]:
|
||||
"""Decodes a grounding metadata object from a JSON dictionary."""
|
||||
if not grounding_metadata:
|
||||
return None
|
||||
return types.GroundingMetadata.model_validate(grounding_metadata)
|
||||
return model_cls.model_validate(data)
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import Dialect
|
||||
@@ -252,6 +253,12 @@ class StorageEvent(Base):
|
||||
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
|
||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
@@ -318,6 +325,14 @@ class StorageEvent(Base):
|
||||
)
|
||||
if event.custom_metadata:
|
||||
storage_event.custom_metadata = event.custom_metadata
|
||||
if event.usage_metadata:
|
||||
storage_event.usage_metadata = event.usage_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
if event.citation_metadata:
|
||||
storage_event.citation_metadata = event.citation_metadata.model_dump(
|
||||
exclude_none=True, mode="json"
|
||||
)
|
||||
return storage_event
|
||||
|
||||
def to_event(self) -> Event:
|
||||
@@ -328,17 +343,23 @@ class StorageEvent(Base):
|
||||
branch=self.branch,
|
||||
actions=self.actions,
|
||||
timestamp=self.timestamp.timestamp(),
|
||||
content=_session_util.decode_content(self.content),
|
||||
long_running_tool_ids=self.long_running_tool_ids,
|
||||
partial=self.partial,
|
||||
turn_complete=self.turn_complete,
|
||||
error_code=self.error_code,
|
||||
error_message=self.error_message,
|
||||
interrupted=self.interrupted,
|
||||
grounding_metadata=_session_util.decode_grounding_metadata(
|
||||
self.grounding_metadata
|
||||
),
|
||||
custom_metadata=self.custom_metadata,
|
||||
content=_session_util.decode_model(self.content, types.Content),
|
||||
grounding_metadata=_session_util.decode_model(
|
||||
self.grounding_metadata, types.GroundingMetadata
|
||||
),
|
||||
usage_metadata=_session_util.decode_model(
|
||||
self.usage_metadata, types.GenerateContentResponseUsageMetadata
|
||||
),
|
||||
citation_metadata=_session_util.decode_model(
|
||||
self.citation_metadata, types.CitationMetadata
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -376,8 +376,9 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
|
||||
interrupted = getattr(event_metadata, 'interrupted', None)
|
||||
branch = getattr(event_metadata, 'branch', None)
|
||||
custom_metadata = getattr(event_metadata, 'custom_metadata', None)
|
||||
grounding_metadata = _session_util.decode_grounding_metadata(
|
||||
getattr(event_metadata, 'grounding_metadata', None)
|
||||
grounding_metadata = _session_util.decode_model(
|
||||
getattr(event_metadata, 'grounding_metadata', None),
|
||||
types.GroundingMetadata,
|
||||
)
|
||||
else:
|
||||
long_running_tool_ids = None
|
||||
@@ -393,8 +394,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
|
||||
invocation_id=api_event_obj.invocation_id,
|
||||
author=api_event_obj.author,
|
||||
actions=event_actions,
|
||||
content=_session_util.decode_content(
|
||||
getattr(api_event_obj, 'content', None)
|
||||
content=_session_util.decode_model(
|
||||
getattr(api_event_obj, 'content', None), types.Content
|
||||
),
|
||||
timestamp=api_event_obj.timestamp.timestamp(),
|
||||
error_code=getattr(api_event_obj, 'error_code', None),
|
||||
|
||||
@@ -390,6 +390,11 @@ async def test_append_event_complete(service_type):
|
||||
error_code='error_code',
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
usage_metadata=types.GenerateContentResponseUsageMetadata(
|
||||
prompt_token_count=1, candidates_token_count=1, total_token_count=2
|
||||
),
|
||||
citation_metadata=types.CitationMetadata(),
|
||||
custom_metadata={'custom_key': 'custom_value'},
|
||||
)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user