From fb009d8ea672bbbef4753e4cd25229dbebd0ff8d Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 2 Sep 2025 20:09:10 -0700 Subject: [PATCH] fix: Add `custom_metadata` to DatabaseSessionService Resolve https://github.com/google/adk-python/issues/2677 PiperOrigin-RevId: 802375768 --- .../adk/sessions/database_session_service.py | 18 ++++++--- .../sessions/test_session_service.py | 39 +++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index e37f91d0..d2bb71d0 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -231,21 +231,26 @@ class StorageEvent(Base): invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) + long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( + Text, nullable=True + ) branch: Mapped[str] = mapped_column( String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True ) timestamp: Mapped[PreciseTimestamp] = mapped_column( PreciseTimestamp, default=func.now() ) - content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) - long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( - Text, nullable=True - ) + # === Fileds from llm_response.py === + content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) grounding_metadata: Mapped[dict[str, Any]] = mapped_column( DynamicJSON, nullable=True ) + custom_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + partial: Mapped[bool] = mapped_column(Boolean, nullable=True) turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) error_code: Mapped[str] = mapped_column( @@ -309,6 +314,8 @@ class StorageEvent(Base): storage_event.grounding_metadata = event.grounding_metadata.model_dump( exclude_none=True, mode="json" ) + if event.custom_metadata: + storage_event.custom_metadata = event.custom_metadata return storage_event def to_event(self) -> Event: @@ -329,6 +336,7 @@ class StorageEvent(Base): grounding_metadata=_session_util.decode_grounding_metadata( self.grounding_metadata ), + custom_metadata=self.custom_metadata, ) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 4acfd265..c2a3a1d9 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -402,3 +402,42 @@ async def test_get_session_with_config(service_type): ) events = session.events assert len(events) == num_test_events - after_timestamp + 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_append_event_with_fields(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'test_user' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, state={} + ) + + event = Event( + invocation_id='invocation', + author='user', + content=types.Content(role='user', parts=[types.Part(text='text')]), + long_running_tool_ids={'tool1', 'tool2'}, + partial=False, + turn_complete=True, + error_code='ERROR_CODE', + error_message='error message', + interrupted=True, + grounding_metadata=types.GroundingMetadata( + web_search_queries=['query1'], + ), + custom_metadata={'custom_key': 'custom_value'}, + ) + await session_service.append_event(session, event) + + retrieved_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert retrieved_session + assert len(retrieved_session.events) == 1 + retrieved_event = retrieved_session.events[0] + + assert retrieved_event == event