feat: migrate invocation_context to callback_context

Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context.

PiperOrigin-RevId: 818822267
This commit is contained in:
Google Team Member
2025-10-13 14:05:09 -07:00
committed by Copybara-Service
parent 2158b3c915
commit df05ed6b3b
10 changed files with 110 additions and 240 deletions
+3 -18
View File
@@ -55,21 +55,6 @@ class ReadonlyContext:
return MappingProxyType(self._invocation_context.session.state) return MappingProxyType(self._invocation_context.session.state)
@property @property
def user_id(self) -> str: def session(self) -> Session:
"""The user ID for the current invocation.""" """The current session for this invocation."""
return self._invocation_context.user_id return self._invocation_context.session
@property
def app_name(self) -> str:
"""The application name for the current invocation."""
return self._invocation_context.app_name
@property
def session_id(self) -> str:
"""The session ID for the current invocation."""
return self._invocation_context.session.id
@property
def branch(self) -> Optional[str]:
"""The branch name for the current invocation, if any."""
return self._invocation_context.branch
@@ -39,6 +39,7 @@ from .recordings_schema import Recordings
from .recordings_schema import ToolRecording from .recordings_schema import ToolRecording
if TYPE_CHECKING: if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...tools.base_tool import BaseTool from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext from ...tools.tool_context import ToolContext
@@ -74,10 +75,10 @@ class RecordingsPlugin(BasePlugin):
@override @override
async def before_run_callback( async def before_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Always create fresh per-invocation recording state when enabled.""" """Always create fresh per-invocation recording state when enabled."""
ctx = callback_context ctx = CallbackContext(invocation_context)
if self._is_record_mode_on(ctx): if self._is_record_mode_on(ctx):
# Always create/overwrite the state for this invocation # Always create/overwrite the state for this invocation
self._create_invocation_state(ctx) self._create_invocation_state(ctx)
@@ -279,10 +280,10 @@ class RecordingsPlugin(BasePlugin):
@override @override
async def after_run_callback( async def after_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> None: ) -> None:
"""Finalize and persist recordings, then clean per-invocation state.""" """Finalize and persist recordings, then clean per-invocation state."""
ctx = callback_context ctx = CallbackContext(invocation_context)
if not self._is_record_mode_on(ctx): if not self._is_record_mode_on(ctx):
return None return None
+5 -4
View File
@@ -38,6 +38,7 @@ from .recordings_schema import Recordings
from .recordings_schema import ToolRecording from .recordings_schema import ToolRecording
if TYPE_CHECKING: if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...tools.base_tool import BaseTool from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext from ...tools.tool_context import ToolContext
@@ -80,10 +81,10 @@ class ReplayPlugin(BasePlugin):
@override @override
async def before_run_callback( async def before_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Load replay recordings when enabled.""" """Load replay recordings when enabled."""
ctx = callback_context ctx = CallbackContext(invocation_context)
if self._is_replay_mode_on(ctx): if self._is_replay_mode_on(ctx):
# Load the replay state for this invocation # Load the replay state for this invocation
self._load_invocation_state(ctx) self._load_invocation_state(ctx)
@@ -155,10 +156,10 @@ class ReplayPlugin(BasePlugin):
@override @override
async def after_run_callback( async def after_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> None: ) -> None:
"""Clean up replay state after invocation completes.""" """Clean up replay state after invocation completes."""
ctx = callback_context ctx = CallbackContext(invocation_context)
if not self._is_replay_mode_on(ctx): if not self._is_replay_mode_on(ctx):
return None return None
+39 -83
View File
@@ -111,61 +111,11 @@ class BasePlugin(ABC):
super().__init__() super().__init__()
self.name = name self.name = name
if TYPE_CHECKING:
async def on_user_message_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
user_message: types.Content,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[types.Content]:
"""Callback executed when a user message is received before an invocation starts.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def before_run_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[types.Content]:
"""Callback executed before the ADK runner runs.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def on_event_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
event: Event,
invocation_context: Optional[InvocationContext] = None,
) -> Optional[Event]:
"""Callback executed after an event is yielded from runner.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
async def after_run_callback(
self,
*,
callback_context: Optional[CallbackContext] = None,
invocation_context: Optional[InvocationContext] = None,
) -> None:
"""Callback executed after an ADK runner run has completed.
Plugins can implement this with either callback_context (new) or
invocation_context (deprecated) or both.
"""
# Runtime implementation accepts both via **kwargs
async def on_user_message_callback( async def on_user_message_callback(
self, **kwargs: Any self,
*,
invocation_context: InvocationContext,
user_message: types.Content,
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Callback executed when a user message is received before an invocation starts. """Callback executed when a user message is received before an invocation starts.
@@ -173,63 +123,69 @@ class BasePlugin(ABC):
runner starts the invocation. runner starts the invocation.
Args: Args:
callback_context: The context for the callback execution. invocation_context: The context for the entire invocation.
user_message: The message content input by user. user_message: The message content input by user.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns: Returns:
The modified user message or None if no modification is needed. An optional `types.Content` to be returned to the ADK. Returning a
value to replace the user message. Returning `None` to proceed
normally.
""" """
return None pass
async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]: async def before_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Callback executed before the ADK runner runs. """Callback executed before the ADK runner runs.
This is the first lifecycle hook and is ideal for global setup, logging, This is the first callback to be called in the lifecycle, ideal for global
or checks that may stop the invocation from running. setup or initialization tasks.
Args: Args:
callback_context: The context for the callback execution. invocation_context: The context for the entire invocation, containing
invocation_context: DEPRECATED. Use callback_context instead. The context session information, the root agent, etc.
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns: Returns:
Optional `types.Content` to halt execution and return the value to the An optional `Event` to be returned to the ADK. Returning a value to
caller. Return `None` to proceed normally. halt execution of the runner and ends the runner with that event. Return
`None` to proceed normally.
""" """
return None pass
async def on_event_callback(self, **kwargs: Any) -> Optional[Event]: async def on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
"""Callback executed after an event is yielded from runner. """Callback executed after an event is yielded from runner.
This is the ideal place to make modification to the event before the event
is handled by the underlying agent app.
Args: Args:
callback_context: The context for the callback execution. invocation_context: The context for the entire invocation.
event: The event raised by the runner. event: The event raised by the runner.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns: Returns:
The modified event or None if no modification is needed. An optional value. A non-`None` return may be used by the framework to
modify or replace the response. Returning `None` allows the original
response to be used.
""" """
return None pass
async def after_run_callback(self, **kwargs: Any) -> None: async def after_run_callback(
self, *, invocation_context: InvocationContext
) -> None:
"""Callback executed after an ADK runner run has completed. """Callback executed after an ADK runner run has completed.
This is the final callback in the ADK lifecycle, suitable for cleanup, final
logging, or reporting tasks.
Args: Args:
callback_context: The context for the callback execution. invocation_context: The context for the entire invocation.
invocation_context: DEPRECATED. Use callback_context instead. The context
for the entire invocation. This parameter is maintained for backward
compatibility and will be removed in a future version.
Returns: Returns:
None None
""" """
return None pass
async def before_agent_callback( async def before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext self, *, agent: BaseAgent, callback_context: CallbackContext
@@ -79,7 +79,7 @@ class GlobalInstructionPlugin(BasePlugin):
return None return None
# Resolve the global instruction (handle both string and InstructionProvider) # Resolve the global instruction (handle both string and InstructionProvider)
readonly_context = callback_context readonly_context = ReadonlyContext(callback_context.invocation_context)
final_global_instruction = await self._resolve_global_instruction( final_global_instruction = await self._resolve_global_instruction(
readonly_context readonly_context
) )
+26 -17
View File
@@ -69,32 +69,38 @@ class LoggingPlugin(BasePlugin):
async def on_user_message_callback( async def on_user_message_callback(
self, self,
*, *,
callback_context: CallbackContext, invocation_context: InvocationContext,
user_message: types.Content, user_message: types.Content,
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Log user message and invocation start.""" """Log user message and invocation start."""
self._log(f"🚀 USER MESSAGE RECEIVED") self._log(f"🚀 USER MESSAGE RECEIVED")
self._log(f" Invocation ID: {callback_context.invocation_id}") self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(f" Session ID: {callback_context.session_id}") self._log(f" Session ID: {invocation_context.session.id}")
self._log(f" User ID: {callback_context.user_id}") self._log(f" User ID: {invocation_context.user_id}")
self._log(f" App Name: {callback_context.app_name}") self._log(f" App Name: {invocation_context.app_name}")
self._log(f" Root Agent: {callback_context.agent_name}") self._log(
" Root Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
self._log(f" User Content: {self._format_content(user_message)}") self._log(f" User Content: {self._format_content(user_message)}")
if callback_context.branch: if invocation_context.branch:
self._log(f" Branch: {callback_context.branch}") self._log(f" Branch: {invocation_context.branch}")
return None return None
async def before_run_callback( async def before_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Log invocation start.""" """Log invocation start."""
self._log(f"🏃 INVOCATION STARTING") self._log(f"🏃 INVOCATION STARTING")
self._log(f" Invocation ID: {callback_context.invocation_id}") self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(f" Starting Agent: {callback_context.agent_name}") self._log(
" Starting Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
return None return None
async def on_event_callback( async def on_event_callback(
self, *, callback_context: CallbackContext, event: Event self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]: ) -> Optional[Event]:
"""Log events yielded from the runner.""" """Log events yielded from the runner."""
self._log(f"📢 EVENT YIELDED") self._log(f"📢 EVENT YIELDED")
@@ -117,12 +123,15 @@ class LoggingPlugin(BasePlugin):
return None return None
async def after_run_callback( async def after_run_callback(
self, *, callback_context: CallbackContext self, *, invocation_context: InvocationContext
) -> Optional[None]: ) -> Optional[None]:
"""Log invocation completion.""" """Log invocation completion."""
self._log(f"✅ INVOCATION COMPLETED") self._log(f"✅ INVOCATION COMPLETED")
self._log(f" Invocation ID: {callback_context.invocation_id}") self._log(f" Invocation ID: {invocation_context.invocation_id}")
self._log(f" Final Agent: {callback_context.agent_name}") self._log(
" Final Agent:"
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
)
return None return None
async def before_agent_callback( async def before_agent_callback(
@@ -132,8 +141,8 @@ class LoggingPlugin(BasePlugin):
self._log(f"🤖 AGENT STARTING") self._log(f"🤖 AGENT STARTING")
self._log(f" Agent Name: {callback_context.agent_name}") self._log(f" Agent Name: {callback_context.agent_name}")
self._log(f" Invocation ID: {callback_context.invocation_id}") self._log(f" Invocation ID: {callback_context.invocation_id}")
if callback_context.branch: if callback_context._invocation_context.branch:
self._log(f" Branch: {callback_context.branch}") self._log(f" Branch: {callback_context._invocation_context.branch}")
return None return None
async def after_agent_callback( async def after_agent_callback(
+6 -95
View File
@@ -14,22 +14,20 @@
from __future__ import annotations from __future__ import annotations
import inspect
import logging import logging
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Literal from typing import Literal
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import warnings
from google.genai import types from google.genai import types
from ..agents.callback_context import CallbackContext
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
if TYPE_CHECKING: if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent from ..agents.base_agent import BaseAgent
from ..agents.callback_context import CallbackContext
from ..agents.invocation_context import InvocationContext from ..agents.invocation_context import InvocationContext
from ..events.event import Event from ..events.event import Event
from ..models.llm_request import LlmRequest from ..models.llm_request import LlmRequest
@@ -115,39 +113,35 @@ class PluginManager:
invocation_context: InvocationContext, invocation_context: InvocationContext,
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Runs the `on_user_message_callback` for all plugins.""" """Runs the `on_user_message_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks( return await self._run_callbacks(
"on_user_message_callback", "on_user_message_callback",
user_message=user_message, user_message=user_message,
callback_context=callback_context, invocation_context=invocation_context,
) )
async def run_before_run_callback( async def run_before_run_callback(
self, *, invocation_context: InvocationContext self, *, invocation_context: InvocationContext
) -> Optional[types.Content]: ) -> Optional[types.Content]:
"""Runs the `before_run_callback` for all plugins.""" """Runs the `before_run_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks( return await self._run_callbacks(
"before_run_callback", callback_context=callback_context "before_run_callback", invocation_context=invocation_context
) )
async def run_after_run_callback( async def run_after_run_callback(
self, *, invocation_context: InvocationContext self, *, invocation_context: InvocationContext
) -> Optional[None]: ) -> Optional[None]:
"""Runs the `after_run_callback` for all plugins.""" """Runs the `after_run_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks( return await self._run_callbacks(
"after_run_callback", callback_context=callback_context "after_run_callback", invocation_context=invocation_context
) )
async def run_on_event_callback( async def run_on_event_callback(
self, *, invocation_context: InvocationContext, event: Event self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]: ) -> Optional[Event]:
"""Runs the `on_event_callback` for all plugins.""" """Runs the `on_event_callback` for all plugins."""
callback_context = CallbackContext(invocation_context)
return await self._run_callbacks( return await self._run_callbacks(
"on_event_callback", "on_event_callback",
callback_context=callback_context, invocation_context=invocation_context,
event=event, event=event,
) )
@@ -283,14 +277,8 @@ class PluginManager:
# Each plugin might not implement all callbacks. The base class provides # Each plugin might not implement all callbacks. The base class provides
# default `pass` implementations, so `getattr` will always succeed. # default `pass` implementations, so `getattr` will always succeed.
callback_method = getattr(plugin, callback_name) callback_method = getattr(plugin, callback_name)
# Backward compatibility: Support both callback_context and invocation_context
adapted_kwargs = self._adapt_kwargs_for_plugin(
plugin, callback_method, kwargs
)
try: try:
result = await callback_method(**adapted_kwargs) result = await callback_method(**kwargs)
if result is not None: if result is not None:
# Early exit: A plugin has returned a value. We stop # Early exit: A plugin has returned a value. We stop
# processing further plugins and return this value immediately. # processing further plugins and return this value immediately.
@@ -309,80 +297,3 @@ class PluginManager:
raise RuntimeError(error_message) from e raise RuntimeError(error_message) from e
return None return None
def _adapt_kwargs_for_plugin(
self, plugin: BasePlugin, callback_method: Any, kwargs: dict[str, Any]
) -> dict[str, Any]:
"""Adapts keyword arguments for backward compatibility with legacy plugins.
This method handles the migration from invocation_context to
callback_context
by inspecting the plugin's callback method signature and providing the
appropriate parameter name. For maximum compatibility, it may pass both
parameters when the signature is ambiguous.
Args:
plugin: The plugin instance.
callback_method: The callback method to be invoked.
kwargs: The original keyword arguments.
Returns:
Adapted keyword arguments that match the plugin's expected signature.
"""
# If no callback_context in kwargs, no adaptation needed
if "callback_context" not in kwargs:
return kwargs.copy()
callback_context = kwargs["callback_context"]
try:
# Inspect the callback method signature
sig = inspect.signature(callback_method)
params = sig.parameters
# Case 1: Method explicitly wants only invocation_context
if "invocation_context" in params and "callback_context" not in params:
# Legacy plugin - pass only invocation_context
warnings.warn(
f"Plugin '{plugin.name}' uses deprecated 'invocation_context' "
"parameter in callback methods. Please update to use "
"'callback_context' instead. Support for 'invocation_context' "
"will be removed in a future version.",
DeprecationWarning,
stacklevel=3,
)
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
del adapted_kwargs["callback_context"]
return adapted_kwargs
# Case 2: Method explicitly wants only callback_context
elif "callback_context" in params and "invocation_context" not in params:
# Modern plugin - pass only callback_context
return kwargs.copy()
# Case 3: Method wants both, uses **kwargs, or signature is unclear
else:
# Pass both parameters for maximum compatibility
# This handles: **kwargs, both parameters explicitly, or unknown cases
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
return adapted_kwargs
except (ValueError, TypeError) as e:
# Fallback: Pass both parameters for safety
logger.debug(
"Failed to inspect plugin '%s' callback signature: %s. "
"Passing both callback_context and invocation_context for safety.",
plugin.name,
e,
)
adapted_kwargs = kwargs.copy()
adapted_kwargs["invocation_context"] = (
callback_context._invocation_context
)
return adapted_kwargs
+16 -8
View File
@@ -103,15 +103,19 @@ async def test_base_plugin_default_callbacks_return_none():
assert ( assert (
await plugin.on_user_message_callback( await plugin.on_user_message_callback(
user_message=mock_user_message, user_message=mock_user_message,
callback_context=mock_context, invocation_context=mock_context,
) )
is None is None
) )
assert await plugin.before_run_callback(callback_context=mock_context) is None assert (
assert await plugin.after_run_callback(callback_context=mock_context) is None await plugin.before_run_callback(invocation_context=mock_context) is None
)
assert (
await plugin.after_run_callback(invocation_context=mock_context) is None
)
assert ( assert (
await plugin.on_event_callback( await plugin.on_event_callback(
callback_context=mock_context, event=mock_context invocation_context=mock_context, event=mock_context
) )
is None is None
) )
@@ -196,21 +200,25 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
assert ( assert (
await plugin.on_user_message_callback( await plugin.on_user_message_callback(
user_message=mock_user_message, user_message=mock_user_message,
callback_context=mock_callback_context, invocation_context=mock_invocation_context,
) )
== "overridden_on_user_message" == "overridden_on_user_message"
) )
assert ( assert (
await plugin.before_run_callback(callback_context=mock_callback_context) await plugin.before_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_before_run" == "overridden_before_run"
) )
assert ( assert (
await plugin.after_run_callback(callback_context=mock_callback_context) await plugin.after_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_after_run" == "overridden_after_run"
) )
assert ( assert (
await plugin.on_event_callback( await plugin.on_event_callback(
callback_context=mock_callback_context, event=mock_event invocation_context=mock_invocation_context, event=mock_event
) )
== "overridden_on_event" == "overridden_on_event"
) )
@@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string():
mock_invocation_context.session = mock_session mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext) mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest( llm_request = LlmRequest(
model="gemini-1.5-flash", model="gemini-1.5-flash",
@@ -70,7 +70,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
"""Test GlobalInstructionPlugin with an InstructionProvider function.""" """Test GlobalInstructionPlugin with an InstructionProvider function."""
async def build_global_instruction(readonly_context: ReadonlyContext) -> str: async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
return f"You are assistant for user {readonly_context.user_id}." return f"You are assistant for user {readonly_context.session.user_id}."
plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction) plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction)
@@ -83,8 +83,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
mock_invocation_context.session = mock_session mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext) mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context.user_id = "alice"
llm_request = LlmRequest( llm_request = LlmRequest(
model="gemini-1.5-flash", model="gemini-1.5-flash",
@@ -120,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction():
mock_invocation_context.session = mock_session mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext) mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest( llm_request = LlmRequest(
model="gemini-1.5-flash", model="gemini-1.5-flash",
@@ -157,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing():
mock_invocation_context.session = mock_session mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext) mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest( llm_request = LlmRequest(
model="gemini-1.5-flash", model="gemini-1.5-flash",
@@ -192,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list():
mock_invocation_context.session = mock_session mock_invocation_context.session = mock_session
mock_callback_context = Mock(spec=CallbackContext) mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context._invocation_context = mock_invocation_context mock_callback_context.invocation_context = mock_invocation_context
llm_request = LlmRequest( llm_request = LlmRequest(
model="gemini-1.5-flash", model="gemini-1.5-flash",
+3 -3
View File
@@ -15,8 +15,8 @@
from typing import Optional from typing import Optional
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.llm_agent import LlmAgent
from google.adk.apps.app import App from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig from google.adk.apps.app import ResumabilityConfig
@@ -98,7 +98,7 @@ class MockPlugin(BasePlugin):
async def on_user_message_callback( async def on_user_message_callback(
self, self,
*, *,
callback_context: CallbackContext, invocation_context: InvocationContext,
user_message: types.Content, user_message: types.Content,
) -> Optional[types.Content]: ) -> Optional[types.Content]:
if not self.enable_user_message_callback: if not self.enable_user_message_callback:
@@ -109,7 +109,7 @@ class MockPlugin(BasePlugin):
) )
async def on_event_callback( async def on_event_callback(
self, *, callback_context: CallbackContext, event: Event self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]: ) -> Optional[Event]:
if not self.enable_event_callback: if not self.enable_event_callback:
return None return None