feat: migrate invocation_context to callback_context

Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context.

PiperOrigin-RevId: 818822267
This commit is contained in:
Google Team Member
2025-10-13 14:05:09 -07:00
committed by Copybara-Service
parent 2158b3c915
commit df05ed6b3b
10 changed files with 110 additions and 240 deletions
+3 -18
View File
@@ -55,21 +55,6 @@ class ReadonlyContext:
return MappingProxyType(self._invocation_context.session.state)
@property
def user_id(self) -> str:
"""The user ID for the current invocation."""
return self._invocation_context.user_id
@property
def app_name(self) -> str:
"""The application name for the current invocation."""
return self._invocation_context.app_name
@property
def session_id(self) -> str:
"""The session ID for the current invocation."""
return self._invocation_context.session.id
@property
def branch(self) -> Optional[str]:
"""The branch name for the current invocation, if any."""
return self._invocation_context.branch
def session(self) -> Session:
"""The current session for this invocation."""
return self._invocation_context.session
@@ -39,6 +39,7 @@ from .recordings_schema import Recordings
from .recordings_schema import ToolRecording
if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
@@ -74,10 +75,10 @@ class RecordingsPlugin(BasePlugin):
@override
async def before_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Always create fresh per-invocation recording state when enabled."""
ctx = callback_context
ctx = CallbackContext(invocation_context)
if self._is_record_mode_on(ctx):
# Always create/overwrite the state for this invocation
self._create_invocation_state(ctx)
@@ -279,10 +280,10 @@ class RecordingsPlugin(BasePlugin):
@override
async def after_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> None:
"""Finalize and persist recordings, then clean per-invocation state."""
ctx = callback_context
ctx = CallbackContext(invocation_context)
if not self._is_record_mode_on(ctx):
return None
+5 -4
View File
@@ -38,6 +38,7 @@ from .recordings_schema import Recordings
from .recordings_schema import ToolRecording
if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
@@ -80,10 +81,10 @@ class ReplayPlugin(BasePlugin):
@override
async def before_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Load replay recordings when enabled."""
ctx = callback_context
ctx = CallbackContext(invocation_context)
if self._is_replay_mode_on(ctx):
# Load the replay state for this invocation
self._load_invocation_state(ctx)
@@ -155,10 +156,10 @@ class ReplayPlugin(BasePlugin):
@override
async def after_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> None:
"""Clean up replay state after invocation completes."""
ctx = callback_context
ctx = CallbackContext(invocation_context)
if not self._is_replay_mode_on(ctx):
return None
+39 -83
View File
@@ -111,61 +111,11 @@ class BasePlugin(ABC):
super().__init__()
self.name = name
if TYPE_CHECKING:
async def on_user_message_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
user_message: types.Content,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[types.Content]:
"""Callback executed when a user message is received before an invocation starts.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def before_run_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[types.Content]:
"""Callback executed before the ADK runner runs.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def on_event_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
event: Event,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[Event]:
"""Callback executed after an event is yielded from runner.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def after_run_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
invocation_context: Optional[InvocationContext] = None,
) -> None:
"""Callback executed after an ADK runner run has completed.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
# Runtime implementation accepts both via **kwargs
async def on_user_message_callback(
self, **kwargs: Any
self,
*,
invocation_context: InvocationContext,
user_message: types.Content,
) -> Optional[types.Content]:
"""Callback executed when a user message is received before an invocation starts.
@@ -173,63 +123,69 @@ class BasePlugin(ABC):
runner starts the invocation.
Args:
callback_context: The context for the callback execution.
invocation_context: The context for the entire invocation.
user_message: The message content input by user.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns:
The modified user message or None if no modification is needed.
An optional `types.Content` to be returned to the ADK. Returning a
value to replace the user message. Returning `None` to proceed
normally.
"""
return None
pass
async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]:
async def before_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Callback executed before the ADK runner runs.
This is the first lifecycle hook and is ideal for global setup, logging,
or checks that may stop the invocation from running.
This is the first callback to be called in the lifecycle, ideal for global
setup or initialization tasks.
Args:
callback_context: The context for the callback execution.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
invocation_context: The context for the entire invocation, containing
session information, the root agent, etc.
Returns:
Optional `types.Content` to halt execution and return the value to the
caller. Return `None` to proceed normally.
An optional `Event` to be returned to the ADK. Returning a value to
halt execution of the runner and ends the runner with that event. Return
`None` to proceed normally.
"""
return None
pass
async def on_event_callback(self, **kwargs: Any) -> Optional[Event]:
async def on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
"""Callback executed after an event is yielded from runner.
This is the ideal place to make modification to the event before the event
is handled by the underlying agent app.
Args:
callback_context: The context for the callback execution.
invocation_context: The context for the entire invocation.
event: The event raised by the runner.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns:
The modified event or None if no modification is needed.
An optional value. A non-`None` return may be used by the framework to
modify or replace the response. Returning `None` allows the original
response to be used.
"""
return None
pass
async def after_run_callback(self, **kwargs: Any) -> None:
async def after_run_callback(
self, *, invocation_context: InvocationContext
) -> None:
"""Callback executed after an ADK runner run has completed.
This is the final callback in the ADK lifecycle, suitable for cleanup, final
logging, or reporting tasks.
Args:
callback_context: The context for the callback execution.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
invocation_context: The context for the entire invocation.
Returns:
None
"""
return None
pass
async def before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
@@ -79,7 +79,7 @@ class GlobalInstructionPlugin(BasePlugin):
return None
# Resolve the global instruction (handle both string and InstructionProvider)
readonly_context = callback_context
readonly_context = ReadonlyContext(callback_context.invocation_context)
final_global_instruction = await self._resolve_global_instruction(
readonly_context
)
+26 -17
View File
@@ -69,32 +69,38 @@ class LoggingPlugin(BasePlugin):
async def on_user_message_callback(
self,
*,
callback_context: CallbackContext,
invocation_context: InvocationContext,
user_message: types.Content,
) -> Optional[types.Content]:
"""Log user message and invocation start."""
self._log(f"🚀 USER MESSAGE RECEIVED")
self._log(f" Invocation ID: {callback_context.invocation_id}")
self._log(f" Session ID: {callback_context.session_id}")
self._log(f" User ID: {callback_context.user_id}")
self._log(f" App Name: {callback_context.app_name}")
self._log(f" Root Agent: {callback_context.agent_name}")
self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(f" Session ID: {invocation_context.session.id}")
self._log(f" User ID: {invocation_context.user_id}")
self._log(f" App Name: {invocation_context.app_name}")
self._log(
" Root Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
self._log(f" User Content: {self._format_content(user_message)}")
if callback_context.branch:
self._log(f" Branch: {callback_context.branch}")
if invocation_context.branch:
self._log(f" Branch: {invocation_context.branch}")
return None
async def before_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Log invocation start."""
self._log(f"🏃 INVOCATION STARTING")
self._log(f" Invocation ID: {callback_context.invocation_id}")
self._log(f" Starting Agent: {callback_context.agent_name}")
self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(
" Starting Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
return None
async def on_event_callback(
self, *, callback_context: CallbackContext, event: Event
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
"""Log events yielded from the runner."""
self._log(f"📢 EVENT YIELDED")
@@ -117,12 +123,15 @@ class LoggingPlugin(BasePlugin):
return None
async def after_run_callback(
self, *, callback_context: CallbackContext
self, *, invocation_context: InvocationContext
) -> Optional[None]:
"""Log invocation completion."""
self._log(f"✅ INVOCATION COMPLETED")
self._log(f" Invocation ID: {callback_context.invocation_id}")
self._log(f" Final Agent: {callback_context.agent_name}")
self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(
" Final Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
return None
async def before_agent_callback(
@@ -132,8 +141,8 @@ class LoggingPlugin(BasePlugin):
self._log(f"🤖 AGENT STARTING")
self._log(f" Agent Name: {callback_context.agent_name}")
self._log(f" Invocation ID: {callback_context.invocation_id}")
if callback_context.branch:
self._log(f" Branch: {callback_context.branch}")
if callback_context._invocation_context.branch:
self._log(f" Branch: {callback_context._invocation_context.branch}")
return None
async def after_agent_callback(
+6 -95
View File
@@ -14,22 +14,20 @@
from __future__ import annotations
import inspect
import logging
from typing import Any
from typing import List
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING
import warnings
from google.genai import types
from ..agents.callback_context import CallbackContext
from .base_plugin import BasePlugin
if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
from ..agents.callback_context import CallbackContext
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from ..models.llm_request import LlmRequest
@@ -115,39 +113,35 @@ class PluginManager:
invocation_context: InvocationContext,
) -> Optional[types.Content]:
"""Runs the `on_user_message_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks(
"on_user_message_callback",
user_message=user_message,
callback_context=callback_context,
invocation_context=invocation_context,
)
async def run_before_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Runs the `before_run_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks(
"before_run_callback", callback_context=callback_context
"before_run_callback", invocation_context=invocation_context
)
async def run_after_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[None]:
"""Runs the `after_run_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks(
"after_run_callback", callback_context=callback_context
"after_run_callback", invocation_context=invocation_context
)
async def run_on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
"""Runs the `on_event_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks(
"on_event_callback",
callback_context=callback_context,
invocation_context=invocation_context,
event=event,
)
@@ -283,14 +277,8 @@ class PluginManager:
# Each plugin might not implement all callbacks. The base class provides
# default `pass` implementations, so `getattr` will always succeed.
callback_method = getattr(plugin, callback_name)
# Backward compatibility: Support both callback_context and invocation_context
adapted_kwargs = self._adapt_kwargs_for_plugin(
plugin, callback_method, kwargs
)
try:
result = await callback_method(**adapted_kwargs)
result = await callback_method(**kwargs)
if result is not None:
# Early exit: A plugin has returned a value. We stop
# processing further plugins and return this value immediately.
@@ -309,80 +297,3 @@ class PluginManager:
raise RuntimeError(error_message) from e
return None
def _adapt_kwargs_for_plugin(
self, plugin: BasePlugin, callback_method: Any, kwargs: dict[str, Any]
) -> dict[str, Any]:
"""Adapts keyword arguments for backward compatibility with legacy plugins.
This method handles the migration from invocation_context to
callback_context
by inspecting the plugin's callback method signature and providing the
appropriate parameter name. For maximum compatibility, it may pass both
parameters when the signature is ambiguous.
Args:
plugin: The plugin instance.
callback_method: The callback method to be invoked.
kwargs: The original keyword arguments.
Returns:
Adapted keyword arguments that match the plugin's expected signature.
"""
# If no callback_context in kwargs, no adaptation needed
if "callback_context" not in kwargs:
return kwargs.copy()
callback_context = kwargs["callback_context"]
try:
# Inspect the callback method signature
sig = inspect.signature(callback_method)
params = sig.parameters
# Case 1: Method explicitly wants only invocation_context
if "invocation_context" in params and "callback_context" not in params:
# Legacy plugin - pass only invocation_context
warnings.warn(
f"Plugin '{plugin.name}' uses deprecated 'invocation_context' "
"parameter in callback methods. Please update to use "
"'callback_context' instead. Support for 'invocation_context' "
"will be removed in a future version.",
DeprecationWarning,
stacklevel=3,
)
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
del adapted_kwargs["callback_context"]
return adapted_kwargs
# Case 2: Method explicitly wants only callback_context
elif "callback_context" in params and "invocation_context" not in params:
# Modern plugin - pass only callback_context
return kwargs.copy()
# Case 3: Method wants both, uses **kwargs, or signature is unclear
else:
# Pass both parameters for maximum compatibility
# This handles: **kwargs, both parameters explicitly, or unknown cases
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
return adapted_kwargs
except (ValueError, TypeError) as e:
# Fallback: Pass both parameters for safety
logger.debug(
"Failed to inspect plugin '%s' callback signature: %s. "
"Passing both callback_context and invocation_context for safety.",
plugin.name,
e,
)
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
return adapted_kwargs
+16 -8
View File
@@ -103,15 +103,19 @@ async def test_base_plugin_default_callbacks_return_none():
assert (
await plugin.on_user_message_callback(
user_message=mock_user_message,
callback_context=mock_context,
invocation_context=mock_context,
)
is None
)
assert await plugin.before_run_callback(callback_context=mock_context) is None
assert await plugin.after_run_callback(callback_context=mock_context) is None
assert (
await plugin.before_run_callback(invocation_context=mock_context) is None
)
assert (
await plugin.after_run_callback(invocation_context=mock_context) is None
)
assert (
await plugin.on_event_callback(
callback_context=mock_context, event=mock_context
invocation_context=mock_context, event=mock_context
)
is None
)
@@ -196,21 +200,25 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
assert (
await plugin.on_user_message_callback(
user_message=mock_user_message,
callback_context=mock_callback_context,
invocation_context=mock_invocation_context,
)
== "overridden_on_user_message"
)
assert (
await plugin.before_run_callback(callback_context=mock_callback_context)
await plugin.before_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_before_run"
)
assert (
await plugin.after_run_callback(callback_context=mock_callback_context)
await plugin.after_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_after_run"
)
assert (
await plugin.on_event_callback(
callback_context=mock_callback_context, event=mock_event
invocation_context=mock_invocation_context, event=mock_event
)
== "overridden_on_event"
)
@@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string():
mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest(
model="gemini-1.5-flash",
@@ -70,7 +70,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
"""Test GlobalInstructionPlugin with an InstructionProvider function."""
async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
return f"You are assistant for user {readonly_context.user_id}."
return f"You are assistant for user {readonly_context.session.user_id}."
plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction)
@@ -83,8 +83,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.user_id = "alice"
mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest(
model="gemini-1.5-flash",
@@ -120,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction():
mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest(
model="gemini-1.5-flash",
@@ -157,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing():
mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest(
model="gemini-1.5-flash",
@@ -192,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list():
mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest(
model="gemini-1.5-flash",
+3 -3
View File
@@ -15,8 +15,8 @@
from typing import Optional
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
@@ -98,7 +98,7 @@ class MockPlugin(BasePlugin):
async def on_user_message_callback(
self,
*,
callback_context: CallbackContext,
invocation_context: InvocationContext,
user_message: types.Content,
) -> Optional[types.Content]:
if not self.enable_user_message_callback:
@@ -109,7 +109,7 @@ class MockPlugin(BasePlugin):
)
async def on_event_callback(
self, *, callback_context: CallbackContext, event: Event
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
if not self.enable_event_callback:
return None