You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: migrate invocation_context to callback_context
Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context. PiperOrigin-RevId: 818822267
This commit is contained in:
committed by
Copybara-Service
parent
2158b3c915
commit
df05ed6b3b
@@ -55,21 +55,6 @@ class ReadonlyContext:
|
||||
return MappingProxyType(self._invocation_context.session.state)
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
"""The user ID for the current invocation."""
|
||||
return self._invocation_context.user_id
|
||||
|
||||
@property
|
||||
def app_name(self) -> str:
|
||||
"""The application name for the current invocation."""
|
||||
return self._invocation_context.app_name
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
"""The session ID for the current invocation."""
|
||||
return self._invocation_context.session.id
|
||||
|
||||
@property
|
||||
def branch(self) -> Optional[str]:
|
||||
"""The branch name for the current invocation, if any."""
|
||||
return self._invocation_context.branch
|
||||
def session(self) -> Session:
|
||||
"""The current session for this invocation."""
|
||||
return self._invocation_context.session
|
||||
|
||||
@@ -39,6 +39,7 @@ from .recordings_schema import Recordings
|
||||
from .recordings_schema import ToolRecording
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
|
||||
@@ -74,10 +75,10 @@ class RecordingsPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
async def before_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Always create fresh per-invocation recording state when enabled."""
|
||||
ctx = callback_context
|
||||
ctx = CallbackContext(invocation_context)
|
||||
if self._is_record_mode_on(ctx):
|
||||
# Always create/overwrite the state for this invocation
|
||||
self._create_invocation_state(ctx)
|
||||
@@ -279,10 +280,10 @@ class RecordingsPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
async def after_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> None:
|
||||
"""Finalize and persist recordings, then clean per-invocation state."""
|
||||
ctx = callback_context
|
||||
ctx = CallbackContext(invocation_context)
|
||||
if not self._is_record_mode_on(ctx):
|
||||
return None
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from .recordings_schema import Recordings
|
||||
from .recordings_schema import ToolRecording
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
|
||||
@@ -80,10 +81,10 @@ class ReplayPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
async def before_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Load replay recordings when enabled."""
|
||||
ctx = callback_context
|
||||
ctx = CallbackContext(invocation_context)
|
||||
if self._is_replay_mode_on(ctx):
|
||||
# Load the replay state for this invocation
|
||||
self._load_invocation_state(ctx)
|
||||
@@ -155,10 +156,10 @@ class ReplayPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
async def after_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> None:
|
||||
"""Clean up replay state after invocation completes."""
|
||||
ctx = callback_context
|
||||
ctx = CallbackContext(invocation_context)
|
||||
if not self._is_replay_mode_on(ctx):
|
||||
return None
|
||||
|
||||
|
||||
@@ -111,61 +111,11 @@ class BasePlugin(ABC):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
async def on_user_message_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: Optional[CallbackContext] = None,
|
||||
user_message: types.Content,
|
||||
invocation_context: Optional[InvocationContext] = None,
|
||||
) -> Optional[types.Content]:
|
||||
"""Callback executed when a user message is received before an invocation starts.
|
||||
|
||||
Plugins can implement this with either callback_context (new) or
|
||||
invocation_context (deprecated) or both.
|
||||
"""
|
||||
|
||||
async def before_run_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: Optional[CallbackContext] = None,
|
||||
invocation_context: Optional[InvocationContext] = None,
|
||||
) -> Optional[types.Content]:
|
||||
"""Callback executed before the ADK runner runs.
|
||||
|
||||
Plugins can implement this with either callback_context (new) or
|
||||
invocation_context (deprecated) or both.
|
||||
"""
|
||||
|
||||
async def on_event_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: Optional[CallbackContext] = None,
|
||||
event: Event,
|
||||
invocation_context: Optional[InvocationContext] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Callback executed after an event is yielded from runner.
|
||||
|
||||
Plugins can implement this with either callback_context (new) or
|
||||
invocation_context (deprecated) or both.
|
||||
"""
|
||||
|
||||
async def after_run_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: Optional[CallbackContext] = None,
|
||||
invocation_context: Optional[InvocationContext] = None,
|
||||
) -> None:
|
||||
"""Callback executed after an ADK runner run has completed.
|
||||
|
||||
Plugins can implement this with either callback_context (new) or
|
||||
invocation_context (deprecated) or both.
|
||||
"""
|
||||
|
||||
# Runtime implementation accepts both via **kwargs
|
||||
async def on_user_message_callback(
|
||||
self, **kwargs: Any
|
||||
self,
|
||||
*,
|
||||
invocation_context: InvocationContext,
|
||||
user_message: types.Content,
|
||||
) -> Optional[types.Content]:
|
||||
"""Callback executed when a user message is received before an invocation starts.
|
||||
|
||||
@@ -173,63 +123,69 @@ class BasePlugin(ABC):
|
||||
runner starts the invocation.
|
||||
|
||||
Args:
|
||||
callback_context: The context for the callback execution.
|
||||
invocation_context: The context for the entire invocation.
|
||||
user_message: The message content input by user.
|
||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
||||
for the entire invocation. This parameter is maintained for backward
|
||||
compatibility and will be removed in a future version.
|
||||
|
||||
Returns:
|
||||
The modified user message or None if no modification is needed.
|
||||
An optional `types.Content` to be returned to the ADK. Returning a
|
||||
value to replace the user message. Returning `None` to proceed
|
||||
normally.
|
||||
"""
|
||||
return None
|
||||
pass
|
||||
|
||||
async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]:
|
||||
async def before_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Callback executed before the ADK runner runs.
|
||||
|
||||
This is the first lifecycle hook and is ideal for global setup, logging,
|
||||
or checks that may stop the invocation from running.
|
||||
This is the first callback to be called in the lifecycle, ideal for global
|
||||
setup or initialization tasks.
|
||||
|
||||
Args:
|
||||
callback_context: The context for the callback execution.
|
||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
||||
for the entire invocation. This parameter is maintained for backward
|
||||
compatibility and will be removed in a future version.
|
||||
invocation_context: The context for the entire invocation, containing
|
||||
session information, the root agent, etc.
|
||||
|
||||
Returns:
|
||||
Optional `types.Content` to halt execution and return the value to the
|
||||
caller. Return `None` to proceed normally.
|
||||
An optional `Event` to be returned to the ADK. Returning a value to
|
||||
halt execution of the runner and ends the runner with that event. Return
|
||||
`None` to proceed normally.
|
||||
"""
|
||||
return None
|
||||
pass
|
||||
|
||||
async def on_event_callback(self, **kwargs: Any) -> Optional[Event]:
|
||||
async def on_event_callback(
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> Optional[Event]:
|
||||
"""Callback executed after an event is yielded from runner.
|
||||
|
||||
This is the ideal place to make modification to the event before the event
|
||||
is handled by the underlying agent app.
|
||||
|
||||
Args:
|
||||
callback_context: The context for the callback execution.
|
||||
invocation_context: The context for the entire invocation.
|
||||
event: The event raised by the runner.
|
||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
||||
for the entire invocation. This parameter is maintained for backward
|
||||
compatibility and will be removed in a future version.
|
||||
|
||||
Returns:
|
||||
The modified event or None if no modification is needed.
|
||||
An optional value. A non-`None` return may be used by the framework to
|
||||
modify or replace the response. Returning `None` allows the original
|
||||
response to be used.
|
||||
"""
|
||||
return None
|
||||
pass
|
||||
|
||||
async def after_run_callback(self, **kwargs: Any) -> None:
|
||||
async def after_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> None:
|
||||
"""Callback executed after an ADK runner run has completed.
|
||||
|
||||
This is the final callback in the ADK lifecycle, suitable for cleanup, final
|
||||
logging, or reporting tasks.
|
||||
|
||||
Args:
|
||||
callback_context: The context for the callback execution.
|
||||
invocation_context: DEPRECATED. Use callback_context instead. The context
|
||||
for the entire invocation. This parameter is maintained for backward
|
||||
compatibility and will be removed in a future version.
|
||||
invocation_context: The context for the entire invocation.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return None
|
||||
pass
|
||||
|
||||
async def before_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
|
||||
@@ -79,7 +79,7 @@ class GlobalInstructionPlugin(BasePlugin):
|
||||
return None
|
||||
|
||||
# Resolve the global instruction (handle both string and InstructionProvider)
|
||||
readonly_context = callback_context
|
||||
readonly_context = ReadonlyContext(callback_context.invocation_context)
|
||||
final_global_instruction = await self._resolve_global_instruction(
|
||||
readonly_context
|
||||
)
|
||||
|
||||
@@ -69,32 +69,38 @@ class LoggingPlugin(BasePlugin):
|
||||
async def on_user_message_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: CallbackContext,
|
||||
invocation_context: InvocationContext,
|
||||
user_message: types.Content,
|
||||
) -> Optional[types.Content]:
|
||||
"""Log user message and invocation start."""
|
||||
self._log(f"🚀 USER MESSAGE RECEIVED")
|
||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
||||
self._log(f" Session ID: {callback_context.session_id}")
|
||||
self._log(f" User ID: {callback_context.user_id}")
|
||||
self._log(f" App Name: {callback_context.app_name}")
|
||||
self._log(f" Root Agent: {callback_context.agent_name}")
|
||||
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||
self._log(f" Session ID: {invocation_context.session.id}")
|
||||
self._log(f" User ID: {invocation_context.user_id}")
|
||||
self._log(f" App Name: {invocation_context.app_name}")
|
||||
self._log(
|
||||
" Root Agent:"
|
||||
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||
)
|
||||
self._log(f" User Content: {self._format_content(user_message)}")
|
||||
if callback_context.branch:
|
||||
self._log(f" Branch: {callback_context.branch}")
|
||||
if invocation_context.branch:
|
||||
self._log(f" Branch: {invocation_context.branch}")
|
||||
return None
|
||||
|
||||
async def before_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Log invocation start."""
|
||||
self._log(f"🏃 INVOCATION STARTING")
|
||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
||||
self._log(f" Starting Agent: {callback_context.agent_name}")
|
||||
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||
self._log(
|
||||
" Starting Agent:"
|
||||
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def on_event_callback(
|
||||
self, *, callback_context: CallbackContext, event: Event
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> Optional[Event]:
|
||||
"""Log events yielded from the runner."""
|
||||
self._log(f"📢 EVENT YIELDED")
|
||||
@@ -117,12 +123,15 @@ class LoggingPlugin(BasePlugin):
|
||||
return None
|
||||
|
||||
async def after_run_callback(
|
||||
self, *, callback_context: CallbackContext
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[None]:
|
||||
"""Log invocation completion."""
|
||||
self._log(f"✅ INVOCATION COMPLETED")
|
||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
||||
self._log(f" Final Agent: {callback_context.agent_name}")
|
||||
self._log(f" Invocation ID: {invocation_context.invocation_id}")
|
||||
self._log(
|
||||
" Final Agent:"
|
||||
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def before_agent_callback(
|
||||
@@ -132,8 +141,8 @@ class LoggingPlugin(BasePlugin):
|
||||
self._log(f"🤖 AGENT STARTING")
|
||||
self._log(f" Agent Name: {callback_context.agent_name}")
|
||||
self._log(f" Invocation ID: {callback_context.invocation_id}")
|
||||
if callback_context.branch:
|
||||
self._log(f" Branch: {callback_context.branch}")
|
||||
if callback_context._invocation_context.branch:
|
||||
self._log(f" Branch: {callback_context._invocation_context.branch}")
|
||||
return None
|
||||
|
||||
async def after_agent_callback(
|
||||
|
||||
@@ -14,22 +14,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from .base_plugin import BasePlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from ..models.llm_request import LlmRequest
|
||||
@@ -115,39 +113,35 @@ class PluginManager:
|
||||
invocation_context: InvocationContext,
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `on_user_message_callback` for all plugins."""
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
return await self._run_callbacks(
|
||||
"on_user_message_callback",
|
||||
user_message=user_message,
|
||||
callback_context=callback_context,
|
||||
invocation_context=invocation_context,
|
||||
)
|
||||
|
||||
async def run_before_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `before_run_callback` for all plugins."""
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
return await self._run_callbacks(
|
||||
"before_run_callback", callback_context=callback_context
|
||||
"before_run_callback", invocation_context=invocation_context
|
||||
)
|
||||
|
||||
async def run_after_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[None]:
|
||||
"""Runs the `after_run_callback` for all plugins."""
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
return await self._run_callbacks(
|
||||
"after_run_callback", callback_context=callback_context
|
||||
"after_run_callback", invocation_context=invocation_context
|
||||
)
|
||||
|
||||
async def run_on_event_callback(
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> Optional[Event]:
|
||||
"""Runs the `on_event_callback` for all plugins."""
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
return await self._run_callbacks(
|
||||
"on_event_callback",
|
||||
callback_context=callback_context,
|
||||
invocation_context=invocation_context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
@@ -283,14 +277,8 @@ class PluginManager:
|
||||
# Each plugin might not implement all callbacks. The base class provides
|
||||
# default `pass` implementations, so `getattr` will always succeed.
|
||||
callback_method = getattr(plugin, callback_name)
|
||||
|
||||
# Backward compatibility: Support both callback_context and invocation_context
|
||||
adapted_kwargs = self._adapt_kwargs_for_plugin(
|
||||
plugin, callback_method, kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
result = await callback_method(**adapted_kwargs)
|
||||
result = await callback_method(**kwargs)
|
||||
if result is not None:
|
||||
# Early exit: A plugin has returned a value. We stop
|
||||
# processing further plugins and return this value immediately.
|
||||
@@ -309,80 +297,3 @@ class PluginManager:
|
||||
raise RuntimeError(error_message) from e
|
||||
|
||||
return None
|
||||
|
||||
def _adapt_kwargs_for_plugin(
|
||||
self, plugin: BasePlugin, callback_method: Any, kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Adapts keyword arguments for backward compatibility with legacy plugins.
|
||||
|
||||
This method handles the migration from invocation_context to
|
||||
callback_context
|
||||
by inspecting the plugin's callback method signature and providing the
|
||||
appropriate parameter name. For maximum compatibility, it may pass both
|
||||
parameters when the signature is ambiguous.
|
||||
|
||||
Args:
|
||||
plugin: The plugin instance.
|
||||
callback_method: The callback method to be invoked.
|
||||
kwargs: The original keyword arguments.
|
||||
|
||||
Returns:
|
||||
Adapted keyword arguments that match the plugin's expected signature.
|
||||
"""
|
||||
# If no callback_context in kwargs, no adaptation needed
|
||||
if "callback_context" not in kwargs:
|
||||
return kwargs.copy()
|
||||
|
||||
callback_context = kwargs["callback_context"]
|
||||
|
||||
try:
|
||||
# Inspect the callback method signature
|
||||
sig = inspect.signature(callback_method)
|
||||
params = sig.parameters
|
||||
|
||||
# Case 1: Method explicitly wants only invocation_context
|
||||
if "invocation_context" in params and "callback_context" not in params:
|
||||
# Legacy plugin - pass only invocation_context
|
||||
warnings.warn(
|
||||
f"Plugin '{plugin.name}' uses deprecated 'invocation_context' "
|
||||
"parameter in callback methods. Please update to use "
|
||||
"'callback_context' instead. Support for 'invocation_context' "
|
||||
"will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
adapted_kwargs = kwargs.copy()
|
||||
adapted_kwargs["invocation_context"] = (
|
||||
callback_context._invocation_context
|
||||
)
|
||||
del adapted_kwargs["callback_context"]
|
||||
return adapted_kwargs
|
||||
|
||||
# Case 2: Method explicitly wants only callback_context
|
||||
elif "callback_context" in params and "invocation_context" not in params:
|
||||
# Modern plugin - pass only callback_context
|
||||
return kwargs.copy()
|
||||
|
||||
# Case 3: Method wants both, uses **kwargs, or signature is unclear
|
||||
else:
|
||||
# Pass both parameters for maximum compatibility
|
||||
# This handles: **kwargs, both parameters explicitly, or unknown cases
|
||||
adapted_kwargs = kwargs.copy()
|
||||
adapted_kwargs["invocation_context"] = (
|
||||
callback_context._invocation_context
|
||||
)
|
||||
return adapted_kwargs
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
# Fallback: Pass both parameters for safety
|
||||
logger.debug(
|
||||
"Failed to inspect plugin '%s' callback signature: %s. "
|
||||
"Passing both callback_context and invocation_context for safety.",
|
||||
plugin.name,
|
||||
e,
|
||||
)
|
||||
adapted_kwargs = kwargs.copy()
|
||||
adapted_kwargs["invocation_context"] = (
|
||||
callback_context._invocation_context
|
||||
)
|
||||
return adapted_kwargs
|
||||
|
||||
@@ -103,15 +103,19 @@ async def test_base_plugin_default_callbacks_return_none():
|
||||
assert (
|
||||
await plugin.on_user_message_callback(
|
||||
user_message=mock_user_message,
|
||||
callback_context=mock_context,
|
||||
invocation_context=mock_context,
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert await plugin.before_run_callback(callback_context=mock_context) is None
|
||||
assert await plugin.after_run_callback(callback_context=mock_context) is None
|
||||
assert (
|
||||
await plugin.before_run_callback(invocation_context=mock_context) is None
|
||||
)
|
||||
assert (
|
||||
await plugin.after_run_callback(invocation_context=mock_context) is None
|
||||
)
|
||||
assert (
|
||||
await plugin.on_event_callback(
|
||||
callback_context=mock_context, event=mock_context
|
||||
invocation_context=mock_context, event=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
@@ -196,21 +200,25 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
|
||||
assert (
|
||||
await plugin.on_user_message_callback(
|
||||
user_message=mock_user_message,
|
||||
callback_context=mock_callback_context,
|
||||
invocation_context=mock_invocation_context,
|
||||
)
|
||||
== "overridden_on_user_message"
|
||||
)
|
||||
assert (
|
||||
await plugin.before_run_callback(callback_context=mock_callback_context)
|
||||
await plugin.before_run_callback(
|
||||
invocation_context=mock_invocation_context
|
||||
)
|
||||
== "overridden_before_run"
|
||||
)
|
||||
assert (
|
||||
await plugin.after_run_callback(callback_context=mock_callback_context)
|
||||
await plugin.after_run_callback(
|
||||
invocation_context=mock_invocation_context
|
||||
)
|
||||
== "overridden_after_run"
|
||||
)
|
||||
assert (
|
||||
await plugin.on_event_callback(
|
||||
callback_context=mock_callback_context, event=mock_event
|
||||
invocation_context=mock_invocation_context, event=mock_event
|
||||
)
|
||||
== "overridden_on_event"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string():
|
||||
mock_invocation_context.session = mock_session
|
||||
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_callback_context._invocation_context = mock_invocation_context
|
||||
mock_callback_context.invocation_context = mock_invocation_context
|
||||
|
||||
llm_request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
@@ -70,7 +70,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
|
||||
"""Test GlobalInstructionPlugin with an InstructionProvider function."""
|
||||
|
||||
async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
|
||||
return f"You are assistant for user {readonly_context.user_id}."
|
||||
return f"You are assistant for user {readonly_context.session.user_id}."
|
||||
|
||||
plugin = GlobalInstructionPlugin(global_instruction=build_global_instruction)
|
||||
|
||||
@@ -83,8 +83,7 @@ async def test_global_instruction_plugin_with_instruction_provider():
|
||||
mock_invocation_context.session = mock_session
|
||||
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_callback_context._invocation_context = mock_invocation_context
|
||||
mock_callback_context.user_id = "alice"
|
||||
mock_callback_context.invocation_context = mock_invocation_context
|
||||
|
||||
llm_request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
@@ -120,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction():
|
||||
mock_invocation_context.session = mock_session
|
||||
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_callback_context._invocation_context = mock_invocation_context
|
||||
mock_callback_context.invocation_context = mock_invocation_context
|
||||
|
||||
llm_request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
@@ -157,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing():
|
||||
mock_invocation_context.session = mock_session
|
||||
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_callback_context._invocation_context = mock_invocation_context
|
||||
mock_callback_context.invocation_context = mock_invocation_context
|
||||
|
||||
llm_request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
@@ -192,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list():
|
||||
mock_invocation_context.session = mock_session
|
||||
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_callback_context._invocation_context = mock_invocation_context
|
||||
mock_callback_context.invocation_context = mock_invocation_context
|
||||
|
||||
llm_request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.context_cache_config import ContextCacheConfig
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
@@ -98,7 +98,7 @@ class MockPlugin(BasePlugin):
|
||||
async def on_user_message_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: CallbackContext,
|
||||
invocation_context: InvocationContext,
|
||||
user_message: types.Content,
|
||||
) -> Optional[types.Content]:
|
||||
if not self.enable_user_message_callback:
|
||||
@@ -109,7 +109,7 @@ class MockPlugin(BasePlugin):
|
||||
)
|
||||
|
||||
async def on_event_callback(
|
||||
self, *, callback_context: CallbackContext, event: Event
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> Optional[Event]:
|
||||
if not self.enable_event_callback:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user