You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Propagate RunConfig custom metadata to all events
Adds a method to merge custom metadata from the RunConfig into each Event. This metadata is applied to events generated by the agent, early exit events, and the initial user message event. Close #3953 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 852433171
This commit is contained in:
committed by
Copybara-Service
parent
4ddb2cb2a8
commit
e3db2d0d83
@@ -84,6 +84,19 @@ def _has_non_empty_transcription_text(transcription) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _apply_run_config_custom_metadata(
|
||||
event: Event, run_config: RunConfig | None
|
||||
) -> None:
|
||||
"""Merges run-level custom metadata into the event, if present."""
|
||||
if not run_config or not run_config.custom_metadata:
|
||||
return
|
||||
|
||||
event.custom_metadata = {
|
||||
**run_config.custom_metadata,
|
||||
**(event.custom_metadata or {}),
|
||||
}
|
||||
|
||||
|
||||
class Runner:
|
||||
"""The Runner class is used to run agents.
|
||||
|
||||
@@ -695,6 +708,9 @@ class Runner:
|
||||
author='model',
|
||||
content=early_exit_result,
|
||||
)
|
||||
_apply_run_config_custom_metadata(
|
||||
early_exit_event, invocation_context.run_config
|
||||
)
|
||||
if self._should_append_event(early_exit_event, is_live_call):
|
||||
await self.session_service.append_event(
|
||||
session=session,
|
||||
@@ -721,6 +737,9 @@ class Runner:
|
||||
|
||||
async with Aclosing(execute_fn(invocation_context)) as agen:
|
||||
async for event in agen:
|
||||
_apply_run_config_custom_metadata(
|
||||
event, invocation_context.run_config
|
||||
)
|
||||
if is_live_call:
|
||||
if event.partial and _is_transcription(event):
|
||||
is_transcribing = True
|
||||
@@ -775,7 +794,13 @@ class Runner:
|
||||
modified_event = await plugin_manager.run_on_event_callback(
|
||||
invocation_context=invocation_context, event=event
|
||||
)
|
||||
yield (modified_event if modified_event else event)
|
||||
if modified_event:
|
||||
_apply_run_config_custom_metadata(
|
||||
modified_event, invocation_context.run_config
|
||||
)
|
||||
yield modified_event
|
||||
else:
|
||||
yield event
|
||||
|
||||
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
|
||||
# finalizing logs and metrics data.
|
||||
@@ -846,6 +871,7 @@ class Runner:
|
||||
author='user',
|
||||
content=new_message,
|
||||
)
|
||||
_apply_run_config_custom_metadata(event, invocation_context.run_config)
|
||||
# If new_message is a function response, find the matching function call
|
||||
# and use its branch as the new event's branch.
|
||||
if function_call := invocation_context._find_matching_function_call(event):
|
||||
|
||||
@@ -16,6 +16,7 @@ import importlib
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@@ -23,6 +24,7 @@ from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.context_cache_config import ContextCacheConfig
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
@@ -54,7 +56,9 @@ class MockAgent(BaseAgent):
|
||||
if parent_agent:
|
||||
self.parent_agent = parent_agent
|
||||
|
||||
async def _run_async_impl(self, invocation_context):
|
||||
async def _run_async_impl(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
@@ -78,7 +82,9 @@ class MockLlmAgent(LlmAgent):
|
||||
self.disallow_transfer_to_parent = disallow_transfer_to_parent
|
||||
self.parent_agent = parent_agent
|
||||
|
||||
async def _run_async_impl(self, invocation_context):
|
||||
async def _run_async_impl(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
@@ -88,6 +94,25 @@ class MockLlmAgent(LlmAgent):
|
||||
)
|
||||
|
||||
|
||||
class MockAgentWithMetadata(BaseAgent):
|
||||
"""Mock agent that returns event-level custom metadata."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name=name, sub_agents=[])
|
||||
|
||||
async def _run_async_impl(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(text="Test response")]
|
||||
),
|
||||
custom_metadata={"event_key": "event_value"},
|
||||
)
|
||||
|
||||
|
||||
class MockPlugin(BasePlugin):
|
||||
"""Mock plugin for unit testing."""
|
||||
|
||||
@@ -495,6 +520,41 @@ async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_config_custom_metadata_propagates_to_events():
|
||||
session_service = InMemorySessionService()
|
||||
runner = Runner(
|
||||
app_name=TEST_APP_ID,
|
||||
agent=MockAgentWithMetadata("metadata_agent"),
|
||||
session_service=session_service,
|
||||
artifact_service=InMemoryArtifactService(),
|
||||
)
|
||||
await session_service.create_session(
|
||||
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
|
||||
)
|
||||
|
||||
run_config = RunConfig(custom_metadata={"request_id": "req-1"})
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
user_id=TEST_USER_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
|
||||
run_config=run_config,
|
||||
)
|
||||
]
|
||||
|
||||
assert events[0].custom_metadata is not None
|
||||
assert events[0].custom_metadata["request_id"] == "req-1"
|
||||
assert events[0].custom_metadata["event_key"] == "event_value"
|
||||
|
||||
session = await session_service.get_session(
|
||||
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
|
||||
)
|
||||
user_event = next(event for event in session.events if event.author == "user")
|
||||
assert user_event.custom_metadata == {"request_id": "req-1"}
|
||||
|
||||
|
||||
class TestRunnerWithPlugins:
|
||||
"""Tests for Runner with plugins."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user