From df05ed6b3b7b218d85fddc1acd6617802cdf6f2a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 13 Oct 2025 14:05:09 -0700 Subject: [PATCH] feat: migrate invocation_context to callback_context Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context. PiperOrigin-RevId: 818822267 --- src/google/adk/agents/readonly_context.py | 21 +-- .../adk/cli/plugins/recordings_plugin.py | 9 +- src/google/adk/cli/plugins/replay_plugin.py | 9 +- src/google/adk/plugins/base_plugin.py | 122 ++++++------------ .../adk/plugins/global_instruction_plugin.py | 2 +- src/google/adk/plugins/logging_plugin.py | 43 +++--- src/google/adk/plugins/plugin_manager.py | 101 +-------------- tests/unittests/plugins/test_base_plugin.py | 24 ++-- .../plugins/test_global_instruction_plugin.py | 13 +- tests/unittests/test_runners.py | 6 +- 10 files changed, 110 insertions(+), 240 deletions(-) diff --git a/src/google/adk/agents/readonly_context.py b/src/google/adk/agents/readonly_context.py index 12470be0..c067fd11 100644 --- a/src/google/adk/agents/readonly_context.py +++ b/src/google/adk/agents/readonly_context.py @@ -55,21 +55,6 @@ class ReadonlyContext: return MappingProxyType(self._invocation_context.session.state) @property - def user_id(self) -> str: - """The user ID for the current invocation.""" - return self._invocation_context.user_id - - @property - def app_name(self) -> str: - """The application name for the current invocation.""" - return self._invocation_context.app_name - - @property - def session_id(self) -> str: - """The session ID for the current invocation.""" - return self._invocation_context.session.id - - @property - def branch(self) -> Optional[str]: - """The branch name for the current invocation, if any.""" - return self._invocation_context.branch + def session(self) -> Session: + """The current session for this invocation.""" + return self._invocation_context.session diff --git a/src/google/adk/cli/plugins/recordings_plugin.py b/src/google/adk/cli/plugins/recordings_plugin.py index 0383e0b9..8ee36892 100644 --- a/src/google/adk/cli/plugins/recordings_plugin.py +++ b/src/google/adk/cli/plugins/recordings_plugin.py @@ -39,6 +39,7 @@ from .recordings_schema import Recordings from .recordings_schema import ToolRecording if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext @@ -74,10 +75,10 @@ class RecordingsPlugin(BasePlugin): @override async def before_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: """Always create fresh per-invocation recording state when enabled.""" - ctx = callback_context + ctx = CallbackContext(invocation_context) if self._is_record_mode_on(ctx): # Always create/overwrite the state for this invocation self._create_invocation_state(ctx) @@ -279,10 +280,10 @@ class RecordingsPlugin(BasePlugin): @override async def after_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> None: """Finalize and persist recordings, then clean per-invocation state.""" - ctx = callback_context + ctx = CallbackContext(invocation_context) if not self._is_record_mode_on(ctx): return None diff --git a/src/google/adk/cli/plugins/replay_plugin.py b/src/google/adk/cli/plugins/replay_plugin.py index f0d97242..1ca63f6d 100644 --- a/src/google/adk/cli/plugins/replay_plugin.py +++ b/src/google/adk/cli/plugins/replay_plugin.py @@ -38,6 +38,7 @@ from .recordings_schema import Recordings from .recordings_schema import ToolRecording if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext @@ -80,10 +81,10 @@ class ReplayPlugin(BasePlugin): @override async def before_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: """Load replay recordings when enabled.""" - ctx = callback_context + ctx = CallbackContext(invocation_context) if self._is_replay_mode_on(ctx): # Load the replay state for this invocation self._load_invocation_state(ctx) @@ -155,10 +156,10 @@ class ReplayPlugin(BasePlugin): @override async def after_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> None: """Clean up replay state after invocation completes.""" - ctx = callback_context + ctx = CallbackContext(invocation_context) if not self._is_replay_mode_on(ctx): return None diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 99115433..fb3e3c00 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -111,61 +111,11 @@ class BasePlugin(ABC): super().__init__() self.name = name - if TYPE_CHECKING: - - async def on_user_message_callback( - self, - *, - callback_context: Optional[CallbackContext] = None, - user_message: types.Content, - invocation_context: Optional[InvocationContext] = None, - ) -> Optional[types.Content]: - """Callback executed when a user message is received before an invocation starts. - - Plugins can implement this with either callback_context (new) or - invocation_context (deprecated) or both. - """ - - async def before_run_callback( - self, - *, - callback_context: Optional[CallbackContext] = None, - invocation_context: Optional[InvocationContext] = None, - ) -> Optional[types.Content]: - """Callback executed before the ADK runner runs. - - Plugins can implement this with either callback_context (new) or - invocation_context (deprecated) or both. - """ - - async def on_event_callback( - self, - *, - callback_context: Optional[CallbackContext] = None, - event: Event, - invocation_context: Optional[InvocationContext] = None, - ) -> Optional[Event]: - """Callback executed after an event is yielded from runner. - - Plugins can implement this with either callback_context (new) or - invocation_context (deprecated) or both. - """ - - async def after_run_callback( - self, - *, - callback_context: Optional[CallbackContext] = None, - invocation_context: Optional[InvocationContext] = None, - ) -> None: - """Callback executed after an ADK runner run has completed. - - Plugins can implement this with either callback_context (new) or - invocation_context (deprecated) or both. - """ - - # Runtime implementation accepts both via **kwargs async def on_user_message_callback( - self, **kwargs: Any + self, + *, + invocation_context: InvocationContext, + user_message: types.Content, ) -> Optional[types.Content]: """Callback executed when a user message is received before an invocation starts. @@ -173,63 +123,69 @@ class BasePlugin(ABC): runner starts the invocation. Args: - callback_context: The context for the callback execution. + invocation_context: The context for the entire invocation. user_message: The message content input by user. - invocation_context: DEPRECATED. Use callback_context instead. The context - for the entire invocation. This parameter is maintained for backward - compatibility and will be removed in a future version. Returns: - The modified user message or None if no modification is needed. + An optional `types.Content` to be returned to the ADK. Returning a + value to replace the user message. Returning `None` to proceed + normally. """ - return None + pass - async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]: + async def before_run_callback( + self, *, invocation_context: InvocationContext + ) -> Optional[types.Content]: """Callback executed before the ADK runner runs. - This is the first lifecycle hook and is ideal for global setup, logging, - or checks that may stop the invocation from running. + This is the first callback to be called in the lifecycle, ideal for global + setup or initialization tasks. Args: - callback_context: The context for the callback execution. - invocation_context: DEPRECATED. Use callback_context instead. The context - for the entire invocation. This parameter is maintained for backward - compatibility and will be removed in a future version. + invocation_context: The context for the entire invocation, containing + session information, the root agent, etc. Returns: - Optional `types.Content` to halt execution and return the value to the - caller. Return `None` to proceed normally. + An optional `Event` to be returned to the ADK. Returning a value to + halt execution of the runner and ends the runner with that event. Return + `None` to proceed normally. """ - return None + pass - async def on_event_callback(self, **kwargs: Any) -> Optional[Event]: + async def on_event_callback( + self, *, invocation_context: InvocationContext, event: Event + ) -> Optional[Event]: """Callback executed after an event is yielded from runner. + This is the ideal place to make modification to the event before the event + is handled by the underlying agent app. + Args: - callback_context: The context for the callback execution. + invocation_context: The context for the entire invocation. event: The event raised by the runner. - invocation_context: DEPRECATED. Use callback_context instead. The context - for the entire invocation. This parameter is maintained for backward - compatibility and will be removed in a future version. Returns: - The modified event or None if no modification is needed. + An optional value. A non-`None` return may be used by the framework to + modify or replace the response. Returning `None` allows the original + response to be used. """ - return None + pass - async def after_run_callback(self, **kwargs: Any) -> None: + async def after_run_callback( + self, *, invocation_context: InvocationContext + ) -> None: """Callback executed after an ADK runner run has completed. + This is the final callback in the ADK lifecycle, suitable for cleanup, final + logging, or reporting tasks. + Args: - callback_context: The context for the callback execution. - invocation_context: DEPRECATED. Use callback_context instead. The context - for the entire invocation. This parameter is maintained for backward - compatibility and will be removed in a future version. + invocation_context: The context for the entire invocation. Returns: None """ - return None + pass async def before_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext diff --git a/src/google/adk/plugins/global_instruction_plugin.py b/src/google/adk/plugins/global_instruction_plugin.py index 5cdf489f..4251f691 100644 --- a/src/google/adk/plugins/global_instruction_plugin.py +++ b/src/google/adk/plugins/global_instruction_plugin.py @@ -79,7 +79,7 @@ class GlobalInstructionPlugin(BasePlugin): return None # Resolve the global instruction (handle both string and InstructionProvider) - readonly_context = callback_context + readonly_context = ReadonlyContext(callback_context.invocation_context) final_global_instruction = await self._resolve_global_instruction( readonly_context ) diff --git a/src/google/adk/plugins/logging_plugin.py b/src/google/adk/plugins/logging_plugin.py index 11773461..72d1ca83 100644 --- a/src/google/adk/plugins/logging_plugin.py +++ b/src/google/adk/plugins/logging_plugin.py @@ -69,32 +69,38 @@ class LoggingPlugin(BasePlugin): async def on_user_message_callback( self, *, - callback_context: CallbackContext, + invocation_context: InvocationContext, user_message: types.Content, ) -> Optional[types.Content]: """Log user message and invocation start.""" self._log(f"🚀 USER MESSAGE RECEIVED") - self._log(f" Invocation ID: {callback_context.invocation_id}") - self._log(f" Session ID: {callback_context.session_id}") - self._log(f" User ID: {callback_context.user_id}") - self._log(f" App Name: {callback_context.app_name}") - self._log(f" Root Agent: {callback_context.agent_name}") + self._log(f" Invocation ID: {invocation_context.invocation_id}") + self._log(f" Session ID: {invocation_context.session.id}") + self._log(f" User ID: {invocation_context.user_id}") + self._log(f" App Name: {invocation_context.app_name}") + self._log( + " Root Agent:" + f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}" + ) self._log(f" User Content: {self._format_content(user_message)}") - if callback_context.branch: - self._log(f" Branch: {callback_context.branch}") + if invocation_context.branch: + self._log(f" Branch: {invocation_context.branch}") return None async def before_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: """Log invocation start.""" self._log(f"🏃 INVOCATION STARTING") - self._log(f" Invocation ID: {callback_context.invocation_id}") - self._log(f" Starting Agent: {callback_context.agent_name}") + self._log(f" Invocation ID: {invocation_context.invocation_id}") + self._log( + " Starting Agent:" + f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}" + ) return None async def on_event_callback( - self, *, callback_context: CallbackContext, event: Event + self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: """Log events yielded from the runner.""" self._log(f"📢 EVENT YIELDED") @@ -117,12 +123,15 @@ class LoggingPlugin(BasePlugin): return None async def after_run_callback( - self, *, callback_context: CallbackContext + self, *, invocation_context: InvocationContext ) -> Optional[None]: """Log invocation completion.""" self._log(f"✅ INVOCATION COMPLETED") - self._log(f" Invocation ID: {callback_context.invocation_id}") - self._log(f" Final Agent: {callback_context.agent_name}") + self._log(f" Invocation ID: {invocation_context.invocation_id}") + self._log( + " Final Agent:" + f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}" + ) return None async def before_agent_callback( @@ -132,8 +141,8 @@ class LoggingPlugin(BasePlugin): self._log(f"🤖 AGENT STARTING") self._log(f" Agent Name: {callback_context.agent_name}") self._log(f" Invocation ID: {callback_context.invocation_id}") - if callback_context.branch: - self._log(f" Branch: {callback_context.branch}") + if callback_context._invocation_context.branch: + self._log(f" Branch: {callback_context._invocation_context.branch}") return None async def after_agent_callback( diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 23708f15..217dbb8b 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -14,22 +14,20 @@ from __future__ import annotations -import inspect import logging from typing import Any from typing import List from typing import Literal from typing import Optional from typing import TYPE_CHECKING -import warnings from google.genai import types -from ..agents.callback_context import CallbackContext from .base_plugin import BasePlugin if TYPE_CHECKING: from ..agents.base_agent import BaseAgent + from ..agents.callback_context import CallbackContext from ..agents.invocation_context import InvocationContext from ..events.event import Event from ..models.llm_request import LlmRequest @@ -115,39 +113,35 @@ class PluginManager: invocation_context: InvocationContext, ) -> Optional[types.Content]: """Runs the `on_user_message_callback` for all plugins.""" - callback_context = CallbackContext(invocation_context) return await self._run_callbacks( "on_user_message_callback", user_message=user_message, - callback_context=callback_context, + invocation_context=invocation_context, ) async def run_before_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: """Runs the `before_run_callback` for all plugins.""" - callback_context = CallbackContext(invocation_context) return await self._run_callbacks( - "before_run_callback", callback_context=callback_context + "before_run_callback", invocation_context=invocation_context ) async def run_after_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[None]: """Runs the `after_run_callback` for all plugins.""" - callback_context = CallbackContext(invocation_context) return await self._run_callbacks( - "after_run_callback", callback_context=callback_context + "after_run_callback", invocation_context=invocation_context ) async def run_on_event_callback( self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: """Runs the `on_event_callback` for all plugins.""" - callback_context = CallbackContext(invocation_context) return await self._run_callbacks( "on_event_callback", - callback_context=callback_context, + invocation_context=invocation_context, event=event, ) @@ -283,14 +277,8 @@ class PluginManager: # Each plugin might not implement all callbacks. The base class provides # default `pass` implementations, so `getattr` will always succeed. callback_method = getattr(plugin, callback_name) - - # Backward compatibility: Support both callback_context and invocation_context - adapted_kwargs = self._adapt_kwargs_for_plugin( - plugin, callback_method, kwargs - ) - try: - result = await callback_method(**adapted_kwargs) + result = await callback_method(**kwargs) if result is not None: # Early exit: A plugin has returned a value. We stop # processing further plugins and return this value immediately. @@ -309,80 +297,3 @@ class PluginManager: raise RuntimeError(error_message) from e return None - - def _adapt_kwargs_for_plugin( - self, plugin: BasePlugin, callback_method: Any, kwargs: dict[str, Any] - ) -> dict[str, Any]: - """Adapts keyword arguments for backward compatibility with legacy plugins. - - This method handles the migration from invocation_context to - callback_context - by inspecting the plugin's callback method signature and providing the - appropriate parameter name. For maximum compatibility, it may pass both - parameters when the signature is ambiguous. - - Args: - plugin: The plugin instance. - callback_method: The callback method to be invoked. - kwargs: The original keyword arguments. - - Returns: - Adapted keyword arguments that match the plugin's expected signature. - """ - # If no callback_context in kwargs, no adaptation needed - if "callback_context" not in kwargs: - return kwargs.copy() - - callback_context = kwargs["callback_context"] - - try: - # Inspect the callback method signature - sig = inspect.signature(callback_method) - params = sig.parameters - - # Case 1: Method explicitly wants only invocation_context - if "invocation_context" in params and "callback_context" not in params: - # Legacy plugin - pass only invocation_context - warnings.warn( - f"Plugin '{plugin.name}' uses deprecated 'invocation_context' " - "parameter in callback methods. Please update to use " - "'callback_context' instead. Support for 'invocation_context' " - "will be removed in a future version.", - DeprecationWarning, - stacklevel=3, - ) - adapted_kwargs = kwargs.copy() - adapted_kwargs["invocation_context"] = ( - callback_context._invocation_context - ) - del adapted_kwargs["callback_context"] - return adapted_kwargs - - # Case 2: Method explicitly wants only callback_context - elif "callback_context" in params and "invocation_context" not in params: - # Modern plugin - pass only callback_context - return kwargs.copy() - - # Case 3: Method wants both, uses **kwargs, or signature is unclear - else: - # Pass both parameters for maximum compatibility - # This handles: **kwargs, both parameters explicitly, or unknown cases - adapted_kwargs = kwargs.copy() - adapted_kwargs["invocation_context"] = ( - callback_context._invocation_context - ) - return adapted_kwargs - - except (ValueError, TypeError) as e: - # Fallback: Pass both parameters for safety - logger.debug( - "Failed to inspect plugin '%s' callback signature: %s. " - "Passing both callback_context and invocation_context for safety.", - plugin.name, - e, - ) - adapted_kwargs = kwargs.copy() - adapted_kwargs["invocation_context"] = ( - callback_context._invocation_context - ) - return adapted_kwargs diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index 2af514fc..3a2de943 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -103,15 +103,19 @@ async def test_base_plugin_default_callbacks_return_none(): assert ( await plugin.on_user_message_callback( user_message=mock_user_message, - callback_context=mock_context, + invocation_context=mock_context, ) is None ) - assert await plugin.before_run_callback(callback_context=mock_context) is None - assert await plugin.after_run_callback(callback_context=mock_context) is None + assert ( + await plugin.before_run_callback(invocation_context=mock_context) is None + ) + assert ( + await plugin.after_run_callback(invocation_context=mock_context) is None + ) assert ( await plugin.on_event_callback( - callback_context=mock_context, event=mock_context + invocation_context=mock_context, event=mock_context ) is None ) @@ -196,21 +200,25 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): assert ( await plugin.on_user_message_callback( user_message=mock_user_message, - callback_context=mock_callback_context, + invocation_context=mock_invocation_context, ) == "overridden_on_user_message" ) assert ( - await plugin.before_run_callback(callback_context=mock_callback_context) + await plugin.before_run_callback( + invocation_context=mock_invocation_context + ) == "overridden_before_run" ) assert ( - await plugin.after_run_callback(callback_context=mock_callback_context) + await plugin.after_run_callback( + invocation_context=mock_invocation_context + ) == "overridden_after_run" ) assert ( await plugin.on_event_callback( - callback_context=mock_callback_context, event=mock_event + invocation_context=mock_invocation_context, event=mock_event ) == "overridden_on_event" ) diff --git a/tests/unittests/plugins/test_global_instruction_plugin.py b/tests/unittests/plugins/test_global_instruction_plugin.py index 4571dd39..2253b1fb 100644 --- a/tests/unittests/plugins/test_global_instruction_plugin.py +++ b/tests/unittests/plugins/test_global_instruction_plugin.py @@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context._invocation_context = mock_invocation_context + mock_callback_context.invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -70,7 +70,7 @@ async def test_global_instruction_plugin_with_instruction_provider(): """Test GlobalInstructionPlugin with an InstructionProvider function.""" async def build_global_instruction(readonly_context: ReadonlyContext) -> str: - return f"You are assistant for user {readonly_context.user_id}." + return f"You are assistant for user {readonly_context.session.user_id}." plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction) @@ -83,8 +83,7 @@ async def test_global_instruction_plugin_with_instruction_provider(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context._invocation_context = mock_invocation_context - mock_callback_context.user_id = "alice" + mock_callback_context.invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -120,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context._invocation_context = mock_invocation_context + mock_callback_context.invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -157,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context._invocation_context = mock_invocation_context + mock_callback_context.invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -192,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context._invocation_context = mock_invocation_context + mock_callback_context.invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 8c003b36..e6578aab 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -15,8 +15,8 @@ from typing import Optional from google.adk.agents.base_agent import BaseAgent -from google.adk.agents.callback_context import CallbackContext from google.adk.agents.context_cache_config import ContextCacheConfig +from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig @@ -98,7 +98,7 @@ class MockPlugin(BasePlugin): async def on_user_message_callback( self, *, - callback_context: CallbackContext, + invocation_context: InvocationContext, user_message: types.Content, ) -> Optional[types.Content]: if not self.enable_user_message_callback: @@ -109,7 +109,7 @@ class MockPlugin(BasePlugin): ) async def on_event_callback( - self, *, callback_context: CallbackContext, event: Event + self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: if not self.enable_event_callback: return None