fix: Propagate RunConfig custom metadata to all events

Adds a method to merge custom metadata from the RunConfig into each Event. This metadata is applied to events generated by the agent, early exit events, and the initial user message event.

Close #3953

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 852433171
This commit is contained in:
George Weale
2026-01-05 13:29:36 -08:00
committed by Copybara-Service
parent 4ddb2cb2a8
commit e3db2d0d83
2 changed files with 89 additions and 3 deletions
+27 -1
View File
@@ -84,6 +84,19 @@ def _has_non_empty_transcription_text(transcription) -> bool:
)
def _apply_run_config_custom_metadata(
event: Event, run_config: RunConfig | None
) -> None:
"""Merges run-level custom metadata into the event, if present."""
if not run_config or not run_config.custom_metadata:
return
event.custom_metadata = {
**run_config.custom_metadata,
**(event.custom_metadata or {}),
}
class Runner:
"""The Runner class is used to run agents.
@@ -695,6 +708,9 @@ class Runner:
author='model',
content=early_exit_result,
)
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=session,
@@ -721,6 +737,9 @@ class Runner:
async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
if is_live_call:
if event.partial and _is_transcription(event):
is_transcribing = True
@@ -775,7 +794,13 @@ class Runner:
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
yield (modified_event if modified_event else event)
if modified_event:
_apply_run_config_custom_metadata(
modified_event, invocation_context.run_config
)
yield modified_event
else:
yield event
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
@@ -846,6 +871,7 @@ class Runner:
author='user',
content=new_message,
)
_apply_run_config_custom_metadata(event, invocation_context.run_config)
# If new_message is a function response, find the matching function call
# and use its branch as the new event's branch.
if function_call := invocation_context._find_matching_function_call(event):
+62 -2
View File
@@ -16,6 +16,7 @@ import importlib
from pathlib import Path
import sys
import textwrap
from typing import AsyncGenerator
from typing import Optional
from unittest.mock import AsyncMock
@@ -23,6 +24,7 @@ from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -54,7 +56,9 @@ class MockAgent(BaseAgent):
if parent_agent:
self.parent_agent = parent_agent
async def _run_async_impl(self, invocation_context):
async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
@@ -78,7 +82,9 @@ class MockLlmAgent(LlmAgent):
self.disallow_transfer_to_parent = disallow_transfer_to_parent
self.parent_agent = parent_agent
async def _run_async_impl(self, invocation_context):
async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
@@ -88,6 +94,25 @@ class MockLlmAgent(LlmAgent):
)
class MockAgentWithMetadata(BaseAgent):
"""Mock agent that returns event-level custom metadata."""
def __init__(self, name: str):
super().__init__(name=name, sub_agents=[])
async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
content=types.Content(
role="model", parts=[types.Part(text="Test response")]
),
custom_metadata={"event_key": "event_value"},
)
class MockPlugin(BasePlugin):
"""Mock plugin for unit testing."""
@@ -495,6 +520,41 @@ async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
assert result is False
@pytest.mark.asyncio
async def test_run_config_custom_metadata_propagates_to_events():
session_service = InMemorySessionService()
runner = Runner(
app_name=TEST_APP_ID,
agent=MockAgentWithMetadata("metadata_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
await session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
run_config = RunConfig(custom_metadata={"request_id": "req-1"})
events = [
event
async for event in runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
run_config=run_config,
)
]
assert events[0].custom_metadata is not None
assert events[0].custom_metadata["request_id"] == "req-1"
assert events[0].custom_metadata["event_key"] == "event_value"
session = await session_service.get_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
user_event = next(event for event in session.events if event.author == "user")
assert user_event.custom_metadata == {"request_id": "req-1"}
class TestRunnerWithPlugins:
"""Tests for Runner with plugins."""