You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: migrate invocation_context to callback_context
Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context. PiperOrigin-RevId: 818822267
This commit is contained in:
committed by
Copybara-Service
parent
2158b3c915
commit
df05ed6b3b
@@ -55,21 +55,6 @@ class ReadonlyContext:
|
|||||||
return MappingProxyType(self._invocation_context.session.state)
|
return MappingProxyType(self._invocation_context.session.state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_id(self) -> str:
|
def session(self) -> Session:
|
||||||
"""The user ID for the current invocation."""
|
"""The current session for this invocation."""
|
||||||
return self._invocation_context.user_id
|
return self._invocation_context.session
|
||||||
|
|
||||||
@property
|
|
||||||
def app_name(self) -> str:
|
|
||||||
"""The application name for the current invocation."""
|
|
||||||
return self._invocation_context.app_name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session_id(self) -> str:
|
|
||||||
"""The session ID for the current invocation."""
|
|
||||||
return self._invocation_context.session.id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def branch(self) -> Optional[str]:
|
|
||||||
"""The branch name for the current invocation, if any."""
|
|
||||||
return self._invocation_context.branch
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .recordings_schema import Recordings
|
|||||||
from .recordings_schema import ToolRecording
|
from .recordings_schema import ToolRecording
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from ...agents.invocation_context import InvocationContext
|
||||||
from ...tools.base_tool import BaseTool
|
from ...tools.base_tool import BaseTool
|
||||||
from ...tools.tool_context import ToolContext
|
from ...tools.tool_context import ToolContext
|
||||||
|
|
||||||
@@ -74,10 +75,10 @@ class RecordingsPlugin(BasePlugin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def before_run_callback(
|
async def before_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Always create fresh per-invocation recording state when enabled."""
|
"""Always create fresh per-invocation recording state when enabled."""
|
||||||
ctx = callback_context
|
ctx = CallbackContext(invocation_context)
|
||||||
if self._is_record_mode_on(ctx):
|
if self._is_record_mode_on(ctx):
|
||||||
# Always create/overwrite the state for this invocation
|
# Always create/overwrite the state for this invocation
|
||||||
self._create_invocation_state(ctx)
|
self._create_invocation_state(ctx)
|
||||||
@@ -279,10 +280,10 @@ class RecordingsPlugin(BasePlugin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def after_run_callback(
|
async def after_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Finalize and persist recordings, then clean per-invocation state."""
|
"""Finalize and persist recordings, then clean per-invocation state."""
|
||||||
ctx = callback_context
|
ctx = CallbackContext(invocation_context)
|
||||||
if not self._is_record_mode_on(ctx):
|
if not self._is_record_mode_on(ctx):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from .recordings_schema import Recordings
|
|||||||
from .recordings_schema import ToolRecording
|
from .recordings_schema import ToolRecording
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from ...agents.invocation_context import InvocationContext
|
||||||
from ...tools.base_tool import BaseTool
|
from ...tools.base_tool import BaseTool
|
||||||
from ...tools.tool_context import ToolContext
|
from ...tools.tool_context import ToolContext
|
||||||
|
|
||||||
@@ -80,10 +81,10 @@ class ReplayPlugin(BasePlugin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def before_run_callback(
|
async def before_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Load replay recordings when enabled."""
|
"""Load replay recordings when enabled."""
|
||||||
ctx = callback_context
|
ctx = CallbackContext(invocation_context)
|
||||||
if self._is_replay_mode_on(ctx):
|
if self._is_replay_mode_on(ctx):
|
||||||
# Load the replay state for this invocation
|
# Load the replay state for this invocation
|
||||||
self._load_invocation_state(ctx)
|
self._load_invocation_state(ctx)
|
||||||
@@ -155,10 +156,10 @@ class ReplayPlugin(BasePlugin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def after_run_callback(
|
async def after_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Clean up replay state after invocation completes."""
|
"""Clean up replay state after invocation completes."""
|
||||||
ctx = callback_context
|
ctx = CallbackContext(invocation_context)
|
||||||
if not self._is_replay_mode_on(ctx):
|
if not self._is_replay_mode_on(ctx):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -111,61 +111,11 @@ class BasePlugin(ABC):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
async def on_user_message_callback(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
callback_context: Optional[CallbackContext] = None,
|
|
||||||
user_message: types.Content,
|
|
||||||
invocation_context: Optional[InvocationContext] = None,
|
|
||||||
) -> Optional[types.Content]:
|
|
||||||
"""Callback executed when a user message is received before an invocation starts.
|
|
||||||
|
|
||||||
Plugins can implement this with either callback_context (new) or
|
|
||||||
invocation_context (deprecated) or both.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def before_run_callback(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
callback_context: Optional[CallbackContext] = None,
|
|
||||||
invocation_context: Optional[InvocationContext] = None,
|
|
||||||
) -> Optional[types.Content]:
|
|
||||||
"""Callback executed before the ADK runner runs.
|
|
||||||
|
|
||||||
Plugins can implement this with either callback_context (new) or
|
|
||||||
invocation_context (deprecated) or both.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def on_event_callback(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
callback_context: Optional[CallbackContext] = None,
|
|
||||||
event: Event,
|
|
||||||
invocation_context: Optional[InvocationContext] = None,
|
|
||||||
) -> Optional[Event]:
|
|
||||||
"""Callback executed after an event is yielded from runner.
|
|
||||||
|
|
||||||
Plugins can implement this with either callback_context (new) or
|
|
||||||
invocation_context (deprecated) or both.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def after_run_callback(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
callback_context: Optional[CallbackContext] = None,
|
|
||||||
invocation_context: Optional[InvocationContext] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Callback executed after an ADK runner run has completed.
|
|
||||||
|
|
||||||
Plugins can implement this with either callback_context (new) or
|
|
||||||
invocation_context (deprecated) or both.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Runtime implementation accepts both via **kwargs
|
|
||||||
async def on_user_message_callback(
|
async def on_user_message_callback(
|
||||||
self, **kwargs: Any
|
self,
|
||||||
|
*,
|
||||||
|
invocation_context: InvocationContext,
|
||||||
|
user_message: types.Content,
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Callback executed when a user message is received before an invocation starts.
|
"""Callback executed when a user message is received before an invocation starts.
|
||||||
|
|
||||||
@@ -173,63 +123,69 @@ class BasePlugin(ABC):
|
|||||||
runner starts the invocation.
|
runner starts the invocation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: The context for the callback execution.
|
invocation_context: The context for the entire invocation.
|
||||||
user_message: The message content input by user.
|
user_message: The message content input by user.
|
||||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
|
||||||
for the entire invocation. This parameter is maintained for backward
|
|
||||||
compatibility and will be removed in a future version.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The modified user message or None if no modification is needed.
|
An optional `types.Content` to be returned to the ADK. Returning a
|
||||||
|
value to replace the user message. Returning `None` to proceed
|
||||||
|
normally.
|
||||||
"""
|
"""
|
||||||
return None
|
pass
|
||||||
|
|
||||||
async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]:
|
async def before_run_callback(
|
||||||
|
self, *, invocation_context: InvocationContext
|
||||||
|
) -> Optional[types.Content]:
|
||||||
"""Callback executed before the ADK runner runs.
|
"""Callback executed before the ADK runner runs.
|
||||||
|
|
||||||
This is the first lifecycle hook and is ideal for global setup, logging,
|
This is the first callback to be called in the lifecycle, ideal for global
|
||||||
or checks that may stop the invocation from running.
|
setup or initialization tasks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: The context for the callback execution.
|
invocation_context: The context for the entire invocation, containing
|
||||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
session information, the root agent, etc.
|
||||||
for the entire invocation. This parameter is maintained for backward
|
|
||||||
compatibility and will be removed in a future version.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional `types.Content` to halt execution and return the value to the
|
An optional `Event` to be returned to the ADK. Returning a value to
|
||||||
caller. Return `None` to proceed normally.
|
halt execution of the runner and ends the runner with that event. Return
|
||||||
|
`None` to proceed normally.
|
||||||
"""
|
"""
|
||||||
return None
|
pass
|
||||||
|
|
||||||
async def on_event_callback(self, **kwargs: Any) -> Optional[Event]:
|
async def on_event_callback(
|
||||||
|
self, *, invocation_context: InvocationContext, event: Event
|
||||||
|
) -> Optional[Event]:
|
||||||
"""Callback executed after an event is yielded from runner.
|
"""Callback executed after an event is yielded from runner.
|
||||||
|
|
||||||
|
This is the ideal place to make modification to the event before the event
|
||||||
|
is handled by the underlying agent app.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: The context for the callback execution.
|
invocation_context: The context for the entire invocation.
|
||||||
event: The event raised by the runner.
|
event: The event raised by the runner.
|
||||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
|
||||||
for the entire invocation. This parameter is maintained for backward
|
|
||||||
compatibility and will be removed in a future version.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The modified event or None if no modification is needed.
|
An optional value. A non-`None` return may be used by the framework to
|
||||||
|
modify or replace the response. Returning `None` allows the original
|
||||||
|
response to be used.
|
||||||
"""
|
"""
|
||||||
return None
|
pass
|
||||||
|
|
||||||
async def after_run_callback(self, **kwargs: Any) -> None:
|
async def after_run_callback(
|
||||||
|
self, *, invocation_context: InvocationContext
|
||||||
|
) -> None:
|
||||||
"""Callback executed after an ADK runner run has completed.
|
"""Callback executed after an ADK runner run has completed.
|
||||||
|
|
||||||
|
This is the final callback in the ADK lifecycle, suitable for cleanup, final
|
||||||
|
logging, or reporting tasks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: The context for the callback execution.
|
invocation_context: The context for the entire invocation.
|
||||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
|
||||||
for the entire invocation. This parameter is maintained for backward
|
|
||||||
compatibility and will be removed in a future version.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
return None
|
pass
|
||||||
|
|
||||||
async def before_agent_callback(
|
async def before_agent_callback(
|
||||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class GlobalInstructionPlugin(BasePlugin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Resolve the global instruction (handle both string and InstructionProvider)
|
# Resolve the global instruction (handle both string and InstructionProvider)
|
||||||
readonly_context = callback_context
|
readonly_context = ReadonlyContext(callback_context.invocation_context)
|
||||||
final_global_instruction = await self._resolve_global_instruction(
|
final_global_instruction = await self._resolve_global_instruction(
|
||||||
readonly_context
|
readonly_context
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -69,32 +69,38 @@ class LoggingPlugin(BasePlugin):
|
|||||||
async def on_user_message_callback(
|
async def on_user_message_callback(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
callback_context: CallbackContext,
|
invocation_context: InvocationContext,
|
||||||
user_message: types.Content,
|
user_message: types.Content,
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Log user message and invocation start."""
|
"""Log user message and invocation start."""
|
||||||
self._log(f"🚀 USER MESSAGE RECEIVED")
|
self._log(f"🚀 USER MESSAGE RECEIVED")
|
||||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||||
self._log(f" Session ID: {callback_context.session_id}")
|
self._log(f" Session ID: {invocation_context.session.id}")
|
||||||
self._log(f" User ID: {callback_context.user_id}")
|
self._log(f" User ID: {invocation_context.user_id}")
|
||||||
self._log(f" App Name: {callback_context.app_name}")
|
self._log(f" App Name: {invocation_context.app_name}")
|
||||||
self._log(f" Root Agent: {callback_context.agent_name}")
|
self._log(
|
||||||
|
" Root Agent:"
|
||||||
|
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||||
|
)
|
||||||
self._log(f" User Content: {self._format_content(user_message)}")
|
self._log(f" User Content: {self._format_content(user_message)}")
|
||||||
if callback_context.branch:
|
if invocation_context.branch:
|
||||||
self._log(f" Branch: {callback_context.branch}")
|
self._log(f" Branch: {invocation_context.branch}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def before_run_callback(
|
async def before_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Log invocation start."""
|
"""Log invocation start."""
|
||||||
self._log(f"🏃 INVOCATION STARTING")
|
self._log(f"🏃 INVOCATION STARTING")
|
||||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||||
self._log(f" Starting Agent: {callback_context.agent_name}")
|
self._log(
|
||||||
|
" Starting Agent:"
|
||||||
|
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def on_event_callback(
|
async def on_event_callback(
|
||||||
self, *, callback_context: CallbackContext, event: Event
|
self, *, invocation_context: InvocationContext, event: Event
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
"""Log events yielded from the runner."""
|
"""Log events yielded from the runner."""
|
||||||
self._log(f"📢 EVENT YIELDED")
|
self._log(f"📢 EVENT YIELDED")
|
||||||
@@ -117,12 +123,15 @@ class LoggingPlugin(BasePlugin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def after_run_callback(
|
async def after_run_callback(
|
||||||
self, *, callback_context: CallbackContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[None]:
|
) -> Optional[None]:
|
||||||
"""Log invocation completion."""
|
"""Log invocation completion."""
|
||||||
self._log(f"✅ INVOCATION COMPLETED")
|
self._log(f"✅ INVOCATION COMPLETED")
|
||||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||||
self._log(f" Final Agent: {callback_context.agent_name}")
|
self._log(
|
||||||
|
" Final Agent:"
|
||||||
|
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def before_agent_callback(
|
async def before_agent_callback(
|
||||||
@@ -132,8 +141,8 @@ class LoggingPlugin(BasePlugin):
|
|||||||
self._log(f"🤖 AGENT STARTING")
|
self._log(f"🤖 AGENT STARTING")
|
||||||
self._log(f" Agent Name: {callback_context.agent_name}")
|
self._log(f" Agent Name: {callback_context.agent_name}")
|
||||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
||||||
if callback_context.branch:
|
if callback_context._invocation_context.branch:
|
||||||
self._log(f" Branch: {callback_context.branch}")
|
self._log(f" Branch: {callback_context._invocation_context.branch}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def after_agent_callback(
|
async def after_agent_callback(
|
||||||
|
|||||||
@@ -14,22 +14,20 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
import warnings
|
|
||||||
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from ..agents.callback_context import CallbackContext
|
|
||||||
from .base_plugin import BasePlugin
|
from .base_plugin import BasePlugin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..agents.base_agent import BaseAgent
|
from ..agents.base_agent import BaseAgent
|
||||||
|
from ..agents.callback_context import CallbackContext
|
||||||
from ..agents.invocation_context import InvocationContext
|
from ..agents.invocation_context import InvocationContext
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
from ..models.llm_request import LlmRequest
|
from ..models.llm_request import LlmRequest
|
||||||
@@ -115,39 +113,35 @@ class PluginManager:
|
|||||||
invocation_context: InvocationContext,
|
invocation_context: InvocationContext,
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Runs the `on_user_message_callback` for all plugins."""
|
"""Runs the `on_user_message_callback` for all plugins."""
|
||||||
callback_context = CallbackContext(invocation_context)
|
|
||||||
return await self._run_callbacks(
|
return await self._run_callbacks(
|
||||||
"on_user_message_callback",
|
"on_user_message_callback",
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
callback_context=callback_context,
|
invocation_context=invocation_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_before_run_callback(
|
async def run_before_run_callback(
|
||||||
self, *, invocation_context: InvocationContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
"""Runs the `before_run_callback` for all plugins."""
|
"""Runs the `before_run_callback` for all plugins."""
|
||||||
callback_context = CallbackContext(invocation_context)
|
|
||||||
return await self._run_callbacks(
|
return await self._run_callbacks(
|
||||||
"before_run_callback", callback_context=callback_context
|
"before_run_callback", invocation_context=invocation_context
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_after_run_callback(
|
async def run_after_run_callback(
|
||||||
self, *, invocation_context: InvocationContext
|
self, *, invocation_context: InvocationContext
|
||||||
) -> Optional[None]:
|
) -> Optional[None]:
|
||||||
"""Runs the `after_run_callback` for all plugins."""
|
"""Runs the `after_run_callback` for all plugins."""
|
||||||
callback_context = CallbackContext(invocation_context)
|
|
||||||
return await self._run_callbacks(
|
return await self._run_callbacks(
|
||||||
"after_run_callback", callback_context=callback_context
|
"after_run_callback", invocation_context=invocation_context
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_on_event_callback(
|
async def run_on_event_callback(
|
||||||
self, *, invocation_context: InvocationContext, event: Event
|
self, *, invocation_context: InvocationContext, event: Event
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
"""Runs the `on_event_callback` for all plugins."""
|
"""Runs the `on_event_callback` for all plugins."""
|
||||||
callback_context = CallbackContext(invocation_context)
|
|
||||||
return await self._run_callbacks(
|
return await self._run_callbacks(
|
||||||
"on_event_callback",
|
"on_event_callback",
|
||||||
callback_context=callback_context,
|
invocation_context=invocation_context,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -283,14 +277,8 @@ class PluginManager:
|
|||||||
# Each plugin might not implement all callbacks. The base class provides
|
# Each plugin might not implement all callbacks. The base class provides
|
||||||
# default `pass` implementations, so `getattr` will always succeed.
|
# default `pass` implementations, so `getattr` will always succeed.
|
||||||
callback_method = getattr(plugin, callback_name)
|
callback_method = getattr(plugin, callback_name)
|
||||||
|
|
||||||
# Backward compatibility: Support both callback_context and invocation_context
|
|
||||||
adapted_kwargs = self._adapt_kwargs_for_plugin(
|
|
||||||
plugin, callback_method, kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await callback_method(**adapted_kwargs)
|
result = await callback_method(**kwargs)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
# Early exit: A plugin has returned a value. We stop
|
# Early exit: A plugin has returned a value. We stop
|
||||||
# processing further plugins and return this value immediately.
|
# processing further plugins and return this value immediately.
|
||||||
@@ -309,80 +297,3 @@ class PluginManager:
|
|||||||
raise RuntimeError(error_message) from e
|
raise RuntimeError(error_message) from e
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _adapt_kwargs_for_plugin(
|
|
||||||
self, plugin: BasePlugin, callback_method: Any, kwargs: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Adapts keyword arguments for backward compatibility with legacy plugins.
|
|
||||||
|
|
||||||
This method handles the migration from invocation_context to
|
|
||||||
callback_context
|
|
||||||
by inspecting the plugin's callback method signature and providing the
|
|
||||||
appropriate parameter name. For maximum compatibility, it may pass both
|
|
||||||
parameters when the signature is ambiguous.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin: The plugin instance.
|
|
||||||
callback_method: The callback method to be invoked.
|
|
||||||
kwargs: The original keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Adapted keyword arguments that match the plugin's expected signature.
|
|
||||||
"""
|
|
||||||
# If no callback_context in kwargs, no adaptation needed
|
|
||||||
if "callback_context" not in kwargs:
|
|
||||||
return kwargs.copy()
|
|
||||||
|
|
||||||
callback_context = kwargs["callback_context"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Inspect the callback method signature
|
|
||||||
sig = inspect.signature(callback_method)
|
|
||||||
params = sig.parameters
|
|
||||||
|
|
||||||
# Case 1: Method explicitly wants only invocation_context
|
|
||||||
if "invocation_context" in params and "callback_context" not in params:
|
|
||||||
# Legacy plugin - pass only invocation_context
|
|
||||||
warnings.warn(
|
|
||||||
f"Plugin '{plugin.name}' uses deprecated 'invocation_context' "
|
|
||||||
"parameter in callback methods. Please update to use "
|
|
||||||
"'callback_context' instead. Support for 'invocation_context' "
|
|
||||||
"will be removed in a future version.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=3,
|
|
||||||
)
|
|
||||||
adapted_kwargs = kwargs.copy()
|
|
||||||
adapted_kwargs["invocation_context"] = (
|
|
||||||
callback_context._invocation_context
|
|
||||||
)
|
|
||||||
del adapted_kwargs["callback_context"]
|
|
||||||
return adapted_kwargs
|
|
||||||
|
|
||||||
# Case 2: Method explicitly wants only callback_context
|
|
||||||
elif "callback_context" in params and "invocation_context" not in params:
|
|
||||||
# Modern plugin - pass only callback_context
|
|
||||||
return kwargs.copy()
|
|
||||||
|
|
||||||
# Case 3: Method wants both, uses **kwargs, or signature is unclear
|
|
||||||
else:
|
|
||||||
# Pass both parameters for maximum compatibility
|
|
||||||
# This handles: **kwargs, both parameters explicitly, or unknown cases
|
|
||||||
adapted_kwargs = kwargs.copy()
|
|
||||||
adapted_kwargs["invocation_context"] = (
|
|
||||||
callback_context._invocation_context
|
|
||||||
)
|
|
||||||
return adapted_kwargs
|
|
||||||
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
# Fallback: Pass both parameters for safety
|
|
||||||
logger.debug(
|
|
||||||
"Failed to inspect plugin '%s' callback signature: %s. "
|
|
||||||
"Passing both callback_context and invocation_context for safety.",
|
|
||||||
plugin.name,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
adapted_kwargs = kwargs.copy()
|
|
||||||
adapted_kwargs["invocation_context"] = (
|
|
||||||
callback_context._invocation_context
|
|
||||||
)
|
|
||||||
return adapted_kwargs
|
|
||||||
|
|||||||
@@ -103,15 +103,19 @@ async def test_base_plugin_default_callbacks_return_none():
|
|||||||
assert (
|
assert (
|
||||||
await plugin.on_user_message_callback(
|
await plugin.on_user_message_callback(
|
||||||
user_message=mock_user_message,
|
user_message=mock_user_message,
|
||||||
callback_context=mock_context,
|
invocation_context=mock_context,
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
)
|
)
|
||||||
assert await plugin.before_run_callback(callback_context=mock_context) is None
|
assert (
|
||||||
assert await plugin.after_run_callback(callback_context=mock_context) is None
|
await plugin.before_run_callback(invocation_context=mock_context) is None
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
await plugin.after_run_callback(invocation_context=mock_context) is None
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
await plugin.on_event_callback(
|
await plugin.on_event_callback(
|
||||||
callback_context=mock_context, event=mock_context
|
invocation_context=mock_context, event=mock_context
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
)
|
)
|
||||||
@@ -196,21 +200,25 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
|
|||||||
assert (
|
assert (
|
||||||
await plugin.on_user_message_callback(
|
await plugin.on_user_message_callback(
|
||||||
user_message=mock_user_message,
|
user_message=mock_user_message,
|
||||||
callback_context=mock_callback_context,
|
invocation_context=mock_invocation_context,
|
||||||
)
|
)
|
||||||
== "overridden_on_user_message"
|
== "overridden_on_user_message"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
await plugin.before_run_callback(callback_context=mock_callback_context)
|
await plugin.before_run_callback(
|
||||||
|
invocation_context=mock_invocation_context
|
||||||
|
)
|
||||||
== "overridden_before_run"
|
== "overridden_before_run"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
await plugin.after_run_callback(callback_context=mock_callback_context)
|
await plugin.after_run_callback(
|
||||||
|
invocation_context=mock_invocation_context
|
||||||
|
)
|
||||||
== "overridden_after_run"
|
== "overridden_after_run"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
await plugin.on_event_callback(
|
await plugin.on_event_callback(
|
||||||
callback_context=mock_callback_context, event=mock_event
|
invocation_context=mock_invocation_context, event=mock_event
|
||||||
)
|
)
|
||||||
== "overridden_on_event"
|
== "overridden_on_event"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string():
|
|||||||
mock_invocation_context.session = mock_session
|
mock_invocation_context.session = mock_session
|
||||||
|
|
||||||
mock_callback_context = Mock(spec=CallbackContext)
|
mock_callback_context = Mock(spec=CallbackContext)
|
||||||
mock_callback_context._invocation_context = mock_invocation_context
|
mock_callback_context.invocation_context = mock_invocation_context
|
||||||
|
|
||||||
llm_request = LlmRequest(
|
llm_request = LlmRequest(
|
||||||
model="gemini-1.5-flash",
|
model="gemini-1.5-flash",
|
||||||
@@ -70,7 +70,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
|
|||||||
"""Test GlobalInstructionPlugin with an InstructionProvider function."""
|
"""Test GlobalInstructionPlugin with an InstructionProvider function."""
|
||||||
|
|
||||||
async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
|
async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
|
||||||
return f"You are assistant for user {readonly_context.user_id}."
|
return f"You are assistant for user {readonly_context.session.user_id}."
|
||||||
|
|
||||||
plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction)
|
plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction)
|
||||||
|
|
||||||
@@ -83,8 +83,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
|
|||||||
mock_invocation_context.session = mock_session
|
mock_invocation_context.session = mock_session
|
||||||
|
|
||||||
mock_callback_context = Mock(spec=CallbackContext)
|
mock_callback_context = Mock(spec=CallbackContext)
|
||||||
mock_callback_context._invocation_context = mock_invocation_context
|
mock_callback_context.invocation_context = mock_invocation_context
|
||||||
mock_callback_context.user_id = "alice"
|
|
||||||
|
|
||||||
llm_request = LlmRequest(
|
llm_request = LlmRequest(
|
||||||
model="gemini-1.5-flash",
|
model="gemini-1.5-flash",
|
||||||
@@ -120,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction():
|
|||||||
mock_invocation_context.session = mock_session
|
mock_invocation_context.session = mock_session
|
||||||
|
|
||||||
mock_callback_context = Mock(spec=CallbackContext)
|
mock_callback_context = Mock(spec=CallbackContext)
|
||||||
mock_callback_context._invocation_context = mock_invocation_context
|
mock_callback_context.invocation_context = mock_invocation_context
|
||||||
|
|
||||||
llm_request = LlmRequest(
|
llm_request = LlmRequest(
|
||||||
model="gemini-1.5-flash",
|
model="gemini-1.5-flash",
|
||||||
@@ -157,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing():
|
|||||||
mock_invocation_context.session = mock_session
|
mock_invocation_context.session = mock_session
|
||||||
|
|
||||||
mock_callback_context = Mock(spec=CallbackContext)
|
mock_callback_context = Mock(spec=CallbackContext)
|
||||||
mock_callback_context._invocation_context = mock_invocation_context
|
mock_callback_context.invocation_context = mock_invocation_context
|
||||||
|
|
||||||
llm_request = LlmRequest(
|
llm_request = LlmRequest(
|
||||||
model="gemini-1.5-flash",
|
model="gemini-1.5-flash",
|
||||||
@@ -192,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list():
|
|||||||
mock_invocation_context.session = mock_session
|
mock_invocation_context.session = mock_session
|
||||||
|
|
||||||
mock_callback_context = Mock(spec=CallbackContext)
|
mock_callback_context = Mock(spec=CallbackContext)
|
||||||
mock_callback_context._invocation_context = mock_invocation_context
|
mock_callback_context.invocation_context = mock_invocation_context
|
||||||
|
|
||||||
llm_request = LlmRequest(
|
llm_request = LlmRequest(
|
||||||
model="gemini-1.5-flash",
|
model="gemini-1.5-flash",
|
||||||
|
|||||||
@@ -15,8 +15,8 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from google.adk.agents.base_agent import BaseAgent
|
from google.adk.agents.base_agent import BaseAgent
|
||||||
from google.adk.agents.callback_context import CallbackContext
|
|
||||||
from google.adk.agents.context_cache_config import ContextCacheConfig
|
from google.adk.agents.context_cache_config import ContextCacheConfig
|
||||||
|
from google.adk.agents.invocation_context import InvocationContext
|
||||||
from google.adk.agents.llm_agent import LlmAgent
|
from google.adk.agents.llm_agent import LlmAgent
|
||||||
from google.adk.apps.app import App
|
from google.adk.apps.app import App
|
||||||
from google.adk.apps.app import ResumabilityConfig
|
from google.adk.apps.app import ResumabilityConfig
|
||||||
@@ -98,7 +98,7 @@ class MockPlugin(BasePlugin):
|
|||||||
async def on_user_message_callback(
|
async def on_user_message_callback(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
callback_context: CallbackContext,
|
invocation_context: InvocationContext,
|
||||||
user_message: types.Content,
|
user_message: types.Content,
|
||||||
) -> Optional[types.Content]:
|
) -> Optional[types.Content]:
|
||||||
if not self.enable_user_message_callback:
|
if not self.enable_user_message_callback:
|
||||||
@@ -109,7 +109,7 @@ class MockPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_event_callback(
|
async def on_event_callback(
|
||||||
self, *, callback_context: CallbackContext, event: Event
|
self, *, invocation_context: InvocationContext, event: Event
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
if not self.enable_event_callback:
|
if not self.enable_event_callback:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user