diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 17737297..8723ea2e 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -84,6 +84,19 @@ def _has_non_empty_transcription_text(transcription) -> bool: ) +def _apply_run_config_custom_metadata( + event: Event, run_config: RunConfig | None +) -> None: + """Merges run-level custom metadata into the event, if present.""" + if not run_config or not run_config.custom_metadata: + return + + event.custom_metadata = { + **run_config.custom_metadata, + **(event.custom_metadata or {}), + } + + class Runner: """The Runner class is used to run agents. @@ -695,6 +708,9 @@ class Runner: author='model', content=early_exit_result, ) + _apply_run_config_custom_metadata( + early_exit_event, invocation_context.run_config + ) if self._should_append_event(early_exit_event, is_live_call): await self.session_service.append_event( session=session, @@ -721,6 +737,9 @@ class Runner: async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) if is_live_call: if event.partial and _is_transcription(event): is_transcribing = True @@ -775,7 +794,13 @@ class Runner: modified_event = await plugin_manager.run_on_event_callback( invocation_context=invocation_context, event=event ) - yield (modified_event if modified_event else event) + if modified_event: + _apply_run_config_custom_metadata( + modified_event, invocation_context.run_config + ) + yield modified_event + else: + yield event # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. @@ -846,6 +871,7 @@ class Runner: author='user', content=new_message, ) + _apply_run_config_custom_metadata(event, invocation_context.run_config) # If new_message is a function response, find the matching function call # and use its branch as the new event's branch. if function_call := invocation_context._find_matching_function_call(event): diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index d692f7e3..c347a789 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -16,6 +16,7 @@ import importlib from pathlib import Path import sys import textwrap +from typing import AsyncGenerator from typing import Optional from unittest.mock import AsyncMock @@ -23,6 +24,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -54,7 +56,9 @@ class MockAgent(BaseAgent): if parent_agent: self.parent_agent = parent_agent - async def _run_async_impl(self, invocation_context): + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: yield Event( invocation_id=invocation_context.invocation_id, author=self.name, @@ -78,7 +82,9 @@ class MockLlmAgent(LlmAgent): self.disallow_transfer_to_parent = disallow_transfer_to_parent self.parent_agent = parent_agent - async def _run_async_impl(self, invocation_context): + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: yield Event( invocation_id=invocation_context.invocation_id, author=self.name, @@ -88,6 +94,25 @@ class MockLlmAgent(LlmAgent): ) +class MockAgentWithMetadata(BaseAgent): + """Mock agent that returns event-level custom metadata.""" + + def __init__(self, name: str): + super().__init__(name=name, sub_agents=[]) + + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ), + custom_metadata={"event_key": "event_value"}, + ) + + class MockPlugin(BasePlugin): """Mock plugin for unit testing.""" @@ -495,6 +520,41 @@ async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch): assert result is False +@pytest.mark.asyncio +async def test_run_config_custom_metadata_propagates_to_events(): + session_service = InMemorySessionService() + runner = Runner( + app_name=TEST_APP_ID, + agent=MockAgentWithMetadata("metadata_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + await session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + run_config = RunConfig(custom_metadata={"request_id": "req-1"}) + events = [ + event + async for event in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content(role="user", parts=[types.Part(text="hi")]), + run_config=run_config, + ) + ] + + assert events[0].custom_metadata is not None + assert events[0].custom_metadata["request_id"] == "req-1" + assert events[0].custom_metadata["event_key"] == "event_value" + + session = await session_service.get_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + user_event = next(event for event in session.events if event.author == "user") + assert user_event.custom_metadata == {"request_id": "req-1"} + + class TestRunnerWithPlugins: """Tests for Runner with plugins."""