diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index 2cc65949..68340c8b 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -16,23 +16,16 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import Type +from typing import TypeVar -from google.genai import types +M = TypeVar("M") -def decode_content( - content: Optional[dict[str, Any]], -) -> Optional[types.Content]: - """Decodes a content object from a JSON dictionary.""" - if not content: +def decode_model( + data: Optional[dict[str, Any]], model_cls: Type[M] +) -> Optional[M]: + """Decodes a pydantic model object from a JSON dictionary.""" + if data is None: return None - return types.Content.model_validate(content) - - -def decode_grounding_metadata( - grounding_metadata: Optional[dict[str, Any]], -) -> Optional[types.GroundingMetadata]: - """Decodes a grounding metadata object from a JSON dictionary.""" - if not grounding_metadata: - return None - return types.GroundingMetadata.model_validate(grounding_metadata) + return model_cls.model_validate(data) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 9eca2753..04af695b 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -23,6 +23,7 @@ from typing import Any from typing import Optional import uuid +from google.genai import types from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect @@ -252,6 +253,12 @@ class StorageEvent(Base): custom_metadata: Mapped[dict[str, Any]] = mapped_column( DynamicJSON, nullable=True ) + usage_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + citation_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) partial: Mapped[bool] = mapped_column(Boolean, nullable=True) turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) @@ -318,6 +325,14 @@ class StorageEvent(Base): ) if event.custom_metadata: storage_event.custom_metadata = event.custom_metadata + if event.usage_metadata: + storage_event.usage_metadata = event.usage_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.citation_metadata: + storage_event.citation_metadata = event.citation_metadata.model_dump( + exclude_none=True, mode="json" + ) return storage_event def to_event(self) -> Event: @@ -328,17 +343,23 @@ class StorageEvent(Base): branch=self.branch, actions=self.actions, timestamp=self.timestamp.timestamp(), - content=_session_util.decode_content(self.content), long_running_tool_ids=self.long_running_tool_ids, partial=self.partial, turn_complete=self.turn_complete, error_code=self.error_code, error_message=self.error_message, interrupted=self.interrupted, - grounding_metadata=_session_util.decode_grounding_metadata( - self.grounding_metadata - ), custom_metadata=self.custom_metadata, + content=_session_util.decode_model(self.content, types.Content), + grounding_metadata=_session_util.decode_model( + self.grounding_metadata, types.GroundingMetadata + ), + usage_metadata=_session_util.decode_model( + self.usage_metadata, types.GenerateContentResponseUsageMetadata + ), + citation_metadata=_session_util.decode_model( + self.citation_metadata, types.CitationMetadata + ), ) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 72ff0d6c..6901f4cb 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -376,8 +376,9 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: interrupted = getattr(event_metadata, 'interrupted', None) branch = getattr(event_metadata, 'branch', None) custom_metadata = getattr(event_metadata, 'custom_metadata', None) - grounding_metadata = _session_util.decode_grounding_metadata( - getattr(event_metadata, 'grounding_metadata', None) + grounding_metadata = _session_util.decode_model( + getattr(event_metadata, 'grounding_metadata', None), + types.GroundingMetadata, ) else: long_running_tool_ids = None @@ -393,8 +394,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: invocation_id=api_event_obj.invocation_id, author=api_event_obj.author, actions=event_actions, - content=_session_util.decode_content( - getattr(api_event_obj, 'content', None) + content=_session_util.decode_model( + getattr(api_event_obj, 'content', None), types.Content ), timestamp=api_event_obj.timestamp.timestamp(), error_code=getattr(api_event_obj, 'error_code', None), diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 0c005d68..0dd4162e 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -390,6 +390,11 @@ async def test_append_event_complete(service_type): error_code='error_code', error_message='error_message', interrupted=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, candidates_token_count=1, total_token_count=2 + ), + citation_metadata=types.CitationMetadata(), + custom_metadata={'custom_key': 'custom_value'}, ) await session_service.append_event(session=session, event=event)