You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Implement PluginService for registering and executing plugins
PluginService takes the registration of plugins, and provide the wrapper utilities to execute all plugins. PiperOrigin-RevId: 781745769
This commit is contained in:
committed by
Copybara-Service
parent
4dce9ef519
commit
16ba91cd01
@@ -0,0 +1,265 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from .base_plugin import BasePlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from ..models.llm_request import LlmRequest
|
||||
from ..models.llm_response import LlmResponse
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.tool_context import ToolContext
|
||||
|
||||
# A type alias for the names of the available plugin callbacks.
|
||||
# This helps with static analysis and prevents typos when calling run_callbacks.
|
||||
PluginCallbackName = Literal[
|
||||
"on_user_message_callback",
|
||||
"before_run_callback",
|
||||
"after_run_callback",
|
||||
"on_event_callback",
|
||||
"before_agent_callback",
|
||||
"after_agent_callback",
|
||||
"before_tool_callback",
|
||||
"after_tool_callback",
|
||||
"before_model_callback",
|
||||
"after_model_callback",
|
||||
]
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""Manages the registration and execution of plugins.
|
||||
|
||||
The PluginManager is an internal class that orchestrates the invocation of
|
||||
plugin callbacks at key points in the SDK's execution lifecycle. It maintains
|
||||
a list of registered plugins and ensures they are called in the order they
|
||||
were registered.
|
||||
|
||||
The core execution logic implements an "early exit" strategy: if any plugin
|
||||
callback returns a non-`None` value, the execution of subsequent plugins for
|
||||
that specific event is halted, and the returned value is propagated up the
|
||||
call stack. This allows plugins to short-circuit operations like agent runs,
|
||||
tool calls, or model requests.
|
||||
"""
|
||||
|
||||
def __init__(self, plugins: Optional[List[BasePlugin]] = None):
|
||||
"""Initializes the plugin service.
|
||||
|
||||
Args:
|
||||
plugins: An optional list of plugins to register upon initialization.
|
||||
"""
|
||||
self.plugins: List[BasePlugin] = []
|
||||
if plugins:
|
||||
for plugin in plugins:
|
||||
self.register_plugin(plugin)
|
||||
|
||||
def register_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Registers a new plugin.
|
||||
|
||||
Args:
|
||||
plugin: The plugin instance to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If a plugin with the same name is already registered.
|
||||
"""
|
||||
if any(p.name == plugin.name for p in self.plugins):
|
||||
raise ValueError(f"Plugin with name '{plugin.name}' already registered.")
|
||||
self.plugins.append(plugin)
|
||||
logger.info("Plugin '%s' registered.", plugin.name)
|
||||
|
||||
def get_plugin(self, plugin_name: str) -> Optional[BasePlugin]:
|
||||
"""Retrieves a registered plugin by its name.
|
||||
|
||||
Args:
|
||||
plugin_name: The name of the plugin to retrieve.
|
||||
|
||||
Returns:
|
||||
The plugin instance if found, otherwise `None`.
|
||||
"""
|
||||
return next((p for p in self.plugins if p.name == plugin_name), None)
|
||||
|
||||
async def run_on_user_message_callback(
|
||||
self,
|
||||
*,
|
||||
user_message: types.Content,
|
||||
invocation_context: InvocationContext,
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `on_user_message_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"on_user_message_callback",
|
||||
user_message=user_message,
|
||||
invocation_context=invocation_context,
|
||||
)
|
||||
|
||||
async def run_before_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `before_run_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"before_run_callback", invocation_context=invocation_context
|
||||
)
|
||||
|
||||
async def run_after_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> Optional[None]:
|
||||
"""Runs the `after_run_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"after_run_callback", invocation_context=invocation_context
|
||||
)
|
||||
|
||||
async def run_on_event_callback(
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> Optional[Event]:
|
||||
"""Runs the `on_event_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"on_event_callback",
|
||||
invocation_context=invocation_context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
async def run_before_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `before_agent_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"before_agent_callback",
|
||||
agent=agent,
|
||||
callback_context=callback_context,
|
||||
)
|
||||
|
||||
async def run_after_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
) -> Optional[types.Content]:
|
||||
"""Runs the `after_agent_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"after_agent_callback",
|
||||
agent=agent,
|
||||
callback_context=callback_context,
|
||||
)
|
||||
|
||||
async def run_before_tool_callback(
|
||||
self,
|
||||
*,
|
||||
tool: BaseTool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Optional[dict]:
|
||||
"""Runs the `before_tool_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"before_tool_callback",
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
async def run_after_tool_callback(
|
||||
self,
|
||||
*,
|
||||
tool: BaseTool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
result: dict,
|
||||
) -> Optional[dict]:
|
||||
"""Runs the `after_tool_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"after_tool_callback",
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
tool_context=tool_context,
|
||||
result=result,
|
||||
)
|
||||
|
||||
async def run_before_model_callback(
|
||||
self, *, callback_context: CallbackContext, llm_request: LlmRequest
|
||||
) -> Optional[LlmResponse]:
|
||||
"""Runs the `before_model_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"before_model_callback",
|
||||
callback_context=callback_context,
|
||||
llm_request=llm_request,
|
||||
)
|
||||
|
||||
async def run_after_model_callback(
|
||||
self, *, callback_context: CallbackContext, llm_response: LlmResponse
|
||||
) -> Optional[LlmResponse]:
|
||||
"""Runs the `after_model_callback` for all plugins."""
|
||||
return await self._run_callbacks(
|
||||
"after_model_callback",
|
||||
callback_context=callback_context,
|
||||
llm_response=llm_response,
|
||||
)
|
||||
|
||||
async def _run_callbacks(
|
||||
self, callback_name: PluginCallbackName, **kwargs: Any
|
||||
) -> Optional[Any]:
|
||||
"""Executes a specific callback for all registered plugins.
|
||||
|
||||
This private method iterates through the plugins and calls the specified
|
||||
callback method on each one, passing the provided keyword arguments.
|
||||
|
||||
The execution stops as soon as a plugin's callback returns a non-`None`
|
||||
value. This "early exit" value is then returned by this method. If all
|
||||
plugins are executed and all return `None`, this method also returns `None`.
|
||||
|
||||
Args:
|
||||
callback_name: The name of the callback method to execute.
|
||||
**kwargs: Keyword arguments to be passed to the callback method.
|
||||
|
||||
Returns:
|
||||
The first non-`None` value returned by a plugin callback, or `None` if
|
||||
all callbacks return `None`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If a plugin encounters an unhandled exception during
|
||||
execution. The original exception is chained.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
# Each plugin might not implement all callbacks. The base class provides
|
||||
# default `pass` implementations, so `getattr` will always succeed.
|
||||
callback_method = getattr(plugin, callback_name)
|
||||
try:
|
||||
result = await callback_method(**kwargs)
|
||||
if result is not None:
|
||||
# Early exit: A plugin has returned a value. We stop
|
||||
# processing further plugins and return this value immediately.
|
||||
logger.debug(
|
||||
"Plugin '%s' returned a value for callback '%s', exiting early.",
|
||||
plugin.name,
|
||||
callback_name,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
error_message = (
|
||||
f"Error in plugin '{plugin.name}' during '{callback_name}'"
|
||||
f" callback: {e}"
|
||||
)
|
||||
logger.error(error_message, exc_info=True)
|
||||
raise RuntimeError(error_message) from e
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class TestablePlugin(BasePlugin):
|
||||
__test__ = False
|
||||
"""A concrete implementation of BasePlugin for testing purposes."""
|
||||
pass
|
||||
|
||||
|
||||
class FullOverridePlugin(BasePlugin):
|
||||
__test__ = False
|
||||
|
||||
"""A plugin that overrides every single callback method for testing."""
|
||||
|
||||
def __init__(self, name: str = "full_override"):
|
||||
super().__init__(name)
|
||||
|
||||
async def on_user_message_callback(self, **kwargs) -> str:
|
||||
return "overridden_on_user_message"
|
||||
|
||||
async def before_run_callback(self, **kwargs) -> str:
|
||||
return "overridden_before_run"
|
||||
|
||||
async def after_run_callback(self, **kwargs) -> str:
|
||||
return "overridden_after_run"
|
||||
|
||||
async def on_event_callback(self, **kwargs) -> str:
|
||||
return "overridden_on_event"
|
||||
|
||||
async def before_agent_callback(self, **kwargs) -> str:
|
||||
return "overridden_before_agent"
|
||||
|
||||
async def after_agent_callback(self, **kwargs) -> str:
|
||||
return "overridden_after_agent"
|
||||
|
||||
async def before_tool_callback(self, **kwargs) -> str:
|
||||
return "overridden_before_tool"
|
||||
|
||||
async def after_tool_callback(self, **kwargs) -> str:
|
||||
return "overridden_after_tool"
|
||||
|
||||
async def before_model_callback(self, **kwargs) -> str:
|
||||
return "overridden_before_model"
|
||||
|
||||
async def after_model_callback(self, **kwargs) -> str:
|
||||
return "overridden_after_model"
|
||||
|
||||
|
||||
def test_base_plugin_initialization():
|
||||
"""Tests that a plugin is initialized with the correct name."""
|
||||
plugin_name = "my_test_plugin"
|
||||
plugin = TestablePlugin(name=plugin_name)
|
||||
assert plugin.name == plugin_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_plugin_default_callbacks_return_none():
|
||||
"""Tests that the default (non-overridden) callbacks in BasePlugin exist
|
||||
|
||||
and return None as expected.
|
||||
"""
|
||||
plugin = TestablePlugin(name="default_plugin")
|
||||
|
||||
# Mocking all necessary context objects
|
||||
mock_context = Mock()
|
||||
mock_user_message = Mock()
|
||||
|
||||
# The default implementations should do nothing and return None.
|
||||
assert (
|
||||
await plugin.on_user_message_callback(
|
||||
user_message=mock_user_message,
|
||||
invocation_context=mock_context,
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.before_run_callback(invocation_context=mock_context) is None
|
||||
)
|
||||
assert (
|
||||
await plugin.after_run_callback(invocation_context=mock_context) is None
|
||||
)
|
||||
assert (
|
||||
await plugin.on_event_callback(
|
||||
invocation_context=mock_context, event=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.before_agent_callback(
|
||||
agent=mock_context, callback_context=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.after_agent_callback(
|
||||
agent=mock_context, callback_context=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.before_tool_callback(
|
||||
tool=mock_context, tool_args={}, tool_context=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.after_tool_callback(
|
||||
tool=mock_context, tool_args={}, tool_context=mock_context, result={}
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.before_model_callback(
|
||||
callback_context=mock_context, llm_request=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
await plugin.after_model_callback(
|
||||
callback_context=mock_context, llm_response=mock_context
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_plugin_all_callbacks_can_be_overridden():
|
||||
"""Verifies that a user can create a subclass of BasePlugin and that all
|
||||
|
||||
overridden methods are correctly called.
|
||||
"""
|
||||
plugin = FullOverridePlugin()
|
||||
|
||||
# Create mock objects for all required arguments. We don't need real
|
||||
# objects, just placeholders to satisfy the method signatures.
|
||||
mock_user_message = Mock(spec=types.Content)
|
||||
mock_invocation_context = Mock(spec=InvocationContext)
|
||||
mock_callback_context = Mock(spec=CallbackContext)
|
||||
mock_agent = Mock(spec=BaseAgent)
|
||||
mock_tool = Mock(spec=BaseTool)
|
||||
mock_tool_context = Mock(spec=ToolContext)
|
||||
mock_llm_request = Mock(spec=LlmRequest)
|
||||
mock_llm_response = Mock(spec=LlmResponse)
|
||||
mock_event = Mock(spec=Event)
|
||||
|
||||
# Call each method and assert it returns the unique string from the override.
|
||||
# This proves that the subclass's method was executed.
|
||||
assert (
|
||||
await plugin.on_user_message_callback(
|
||||
user_message=mock_user_message,
|
||||
invocation_context=mock_invocation_context,
|
||||
)
|
||||
== "overridden_on_user_message"
|
||||
)
|
||||
assert (
|
||||
await plugin.before_run_callback(
|
||||
invocation_context=mock_invocation_context
|
||||
)
|
||||
== "overridden_before_run"
|
||||
)
|
||||
assert (
|
||||
await plugin.after_run_callback(
|
||||
invocation_context=mock_invocation_context
|
||||
)
|
||||
== "overridden_after_run"
|
||||
)
|
||||
assert (
|
||||
await plugin.on_event_callback(
|
||||
invocation_context=mock_invocation_context, event=mock_event
|
||||
)
|
||||
== "overridden_on_event"
|
||||
)
|
||||
assert (
|
||||
await plugin.before_agent_callback(
|
||||
agent=mock_agent, callback_context=mock_callback_context
|
||||
)
|
||||
== "overridden_before_agent"
|
||||
)
|
||||
assert (
|
||||
await plugin.after_agent_callback(
|
||||
agent=mock_agent, callback_context=mock_callback_context
|
||||
)
|
||||
== "overridden_after_agent"
|
||||
)
|
||||
assert (
|
||||
await plugin.before_model_callback(
|
||||
callback_context=mock_callback_context, llm_request=mock_llm_request
|
||||
)
|
||||
== "overridden_before_model"
|
||||
)
|
||||
assert (
|
||||
await plugin.after_model_callback(
|
||||
callback_context=mock_callback_context, llm_response=mock_llm_response
|
||||
)
|
||||
== "overridden_after_model"
|
||||
)
|
||||
assert (
|
||||
await plugin.before_tool_callback(
|
||||
tool=mock_tool, tool_args={}, tool_context=mock_tool_context
|
||||
)
|
||||
== "overridden_before_tool"
|
||||
)
|
||||
assert (
|
||||
await plugin.after_tool_callback(
|
||||
tool=mock_tool,
|
||||
tool_args={},
|
||||
tool_context=mock_tool_context,
|
||||
result={},
|
||||
)
|
||||
== "overridden_after_tool"
|
||||
)
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the PluginManager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
# Assume the following path to your modules
|
||||
# You might need to adjust this based on your project structure.
|
||||
from google.adk.plugins.plugin_manager import PluginCallbackName
|
||||
from google.adk.plugins.plugin_manager import PluginManager
|
||||
import pytest
|
||||
|
||||
|
||||
# A helper class to use in tests instead of mocks.
|
||||
# This makes tests more explicit and easier to debug.
|
||||
class TestPlugin(BasePlugin):
|
||||
__test__ = False
|
||||
"""
|
||||
A test plugin that can be configured to return specific values or raise
|
||||
exceptions for any callback, and it logs which callbacks were invoked.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
# A log to track the names of callbacks that have been called.
|
||||
self.call_log: list[PluginCallbackName] = []
|
||||
# A map to configure return values for specific callbacks.
|
||||
self.return_values: dict[PluginCallbackName, any] = {}
|
||||
# A map to configure exceptions to be raised by specific callbacks.
|
||||
self.exceptions_to_raise: dict[PluginCallbackName, Exception] = {}
|
||||
|
||||
async def _handle_callback(self, name: PluginCallbackName):
|
||||
"""Generic handler for all callback methods."""
|
||||
self.call_log.append(name)
|
||||
if name in self.exceptions_to_raise:
|
||||
raise self.exceptions_to_raise[name]
|
||||
return self.return_values.get(name)
|
||||
|
||||
# Implement all callback methods from the BasePlugin interface.
|
||||
async def on_user_message_callback(self, **kwargs):
|
||||
return await self._handle_callback("on_user_message_callback")
|
||||
|
||||
async def before_run_callback(self, **kwargs):
|
||||
return await self._handle_callback("before_run_callback")
|
||||
|
||||
async def after_run_callback(self, **kwargs):
|
||||
return await self._handle_callback("after_run_callback")
|
||||
|
||||
async def on_event_callback(self, **kwargs):
|
||||
return await self._handle_callback("on_event_callback")
|
||||
|
||||
async def before_agent_callback(self, **kwargs):
|
||||
return await self._handle_callback("before_agent_callback")
|
||||
|
||||
async def after_agent_callback(self, **kwargs):
|
||||
return await self._handle_callback("after_agent_callback")
|
||||
|
||||
async def before_tool_callback(self, **kwargs):
|
||||
return await self._handle_callback("before_tool_callback")
|
||||
|
||||
async def after_tool_callback(self, **kwargs):
|
||||
return await self._handle_callback("after_tool_callback")
|
||||
|
||||
async def before_model_callback(self, **kwargs):
|
||||
return await self._handle_callback("before_model_callback")
|
||||
|
||||
async def after_model_callback(self, **kwargs):
|
||||
return await self._handle_callback("after_model_callback")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> PluginManager:
|
||||
"""Provides a clean PluginManager instance for each test."""
|
||||
return PluginManager()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin1() -> TestPlugin:
|
||||
"""Provides a clean instance of our test plugin named 'plugin1'."""
|
||||
return TestPlugin(name="plugin1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin2() -> TestPlugin:
|
||||
"""Provides a clean instance of our test plugin named 'plugin2'."""
|
||||
return TestPlugin(name="plugin2")
|
||||
|
||||
|
||||
def test_register_and_get_plugin(service: PluginManager, plugin1: TestPlugin):
|
||||
"""Tests successful registration and retrieval of a plugin."""
|
||||
service.register_plugin(plugin1)
|
||||
|
||||
assert len(service.plugins) == 1
|
||||
assert service.plugins[0] is plugin1
|
||||
assert service.get_plugin("plugin1") is plugin1
|
||||
|
||||
|
||||
def test_register_duplicate_plugin_name_raises_value_error(
|
||||
service: PluginManager, plugin1: TestPlugin
|
||||
):
|
||||
"""Tests that registering a plugin with a duplicate name raises an error."""
|
||||
plugin1_duplicate = TestPlugin(name="plugin1")
|
||||
service.register_plugin(plugin1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Plugin with name 'plugin1' already registered."
|
||||
):
|
||||
service.register_plugin(plugin1_duplicate)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_early_exit_stops_subsequent_plugins(
|
||||
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
|
||||
):
|
||||
"""Tests the core "early exit" logic: if a plugin returns a value,
|
||||
|
||||
subsequent plugins for that callback should not be executed.
|
||||
"""
|
||||
# Configure plugin1 to return a value, simulating a cache hit.
|
||||
mock_response = Mock(spec=LlmResponse)
|
||||
plugin1.return_values["before_run_callback"] = mock_response
|
||||
|
||||
service.register_plugin(plugin1)
|
||||
service.register_plugin(plugin2)
|
||||
|
||||
# Execute the callback chain.
|
||||
result = await service.run_before_run_callback(invocation_context=Mock())
|
||||
|
||||
# Assert that the final result is the one returned by the first plugin.
|
||||
assert result is mock_response
|
||||
# Assert that the first plugin was called.
|
||||
assert "before_run_callback" in plugin1.call_log
|
||||
# CRITICAL: Assert that the second plugin was never called.
|
||||
assert "before_run_callback" not in plugin2.call_log
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_flow_all_plugins_are_called(
|
||||
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
|
||||
):
|
||||
"""Tests that if no plugin returns a value, all plugins in the chain
|
||||
|
||||
are executed in order.
|
||||
"""
|
||||
# By default, plugins are configured to return None.
|
||||
service.register_plugin(plugin1)
|
||||
service.register_plugin(plugin2)
|
||||
|
||||
result = await service.run_before_run_callback(invocation_context=Mock())
|
||||
|
||||
# The final result should be None as no plugin interrupted the flow.
|
||||
assert result is None
|
||||
# Both plugins must have been called.
|
||||
assert "before_run_callback" in plugin1.call_log
|
||||
assert "before_run_callback" in plugin2.call_log
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_exception_is_wrapped_in_runtime_error(
|
||||
service: PluginManager, plugin1: TestPlugin
|
||||
):
|
||||
"""Tests that if a plugin callback raises an exception, the PluginManager
|
||||
|
||||
catches it and raises a descriptive RuntimeError.
|
||||
"""
|
||||
# Configure the plugin to raise an error during a specific callback.
|
||||
original_exception = ValueError("Something went wrong inside the plugin!")
|
||||
plugin1.exceptions_to_raise["before_run_callback"] = original_exception
|
||||
service.register_plugin(plugin1)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
await service.run_before_run_callback(invocation_context=Mock())
|
||||
|
||||
# Check that the error message is informative.
|
||||
assert "Error in plugin 'plugin1'" in str(excinfo.value)
|
||||
assert "before_run_callback" in str(excinfo.value)
|
||||
# Check that the original exception is chained for better tracebacks.
|
||||
assert excinfo.value.__cause__ is original_exception
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_callbacks_are_supported(
|
||||
service: PluginManager, plugin1: TestPlugin
|
||||
):
|
||||
"""Tests that all callbacks defined in the BasePlugin interface are supported
|
||||
|
||||
by the PluginManager.
|
||||
"""
|
||||
service.register_plugin(plugin1)
|
||||
mock_context = Mock()
|
||||
mock_user_message = Mock()
|
||||
|
||||
# Test all callbacks
|
||||
await service.run_on_user_message_callback(
|
||||
user_message=mock_user_message, invocation_context=mock_context
|
||||
)
|
||||
await service.run_before_run_callback(invocation_context=mock_context)
|
||||
await service.run_after_run_callback(invocation_context=mock_context)
|
||||
await service.run_on_event_callback(
|
||||
invocation_context=mock_context, event=mock_context
|
||||
)
|
||||
await service.run_before_agent_callback(
|
||||
agent=mock_context, callback_context=mock_context
|
||||
)
|
||||
await service.run_after_agent_callback(
|
||||
agent=mock_context, callback_context=mock_context
|
||||
)
|
||||
await service.run_before_tool_callback(
|
||||
tool=mock_context, tool_args={}, tool_context=mock_context
|
||||
)
|
||||
await service.run_after_tool_callback(
|
||||
tool=mock_context, tool_args={}, tool_context=mock_context, result={}
|
||||
)
|
||||
await service.run_before_model_callback(
|
||||
callback_context=mock_context, llm_request=mock_context
|
||||
)
|
||||
await service.run_after_model_callback(
|
||||
callback_context=mock_context, llm_response=mock_context
|
||||
)
|
||||
|
||||
# Verify all callbacks were logged
|
||||
expected_callbacks = [
|
||||
"on_user_message_callback",
|
||||
"before_run_callback",
|
||||
"after_run_callback",
|
||||
"on_event_callback",
|
||||
"before_agent_callback",
|
||||
"after_agent_callback",
|
||||
"before_tool_callback",
|
||||
"after_tool_callback",
|
||||
"before_model_callback",
|
||||
"after_model_callback",
|
||||
]
|
||||
assert set(plugin1.call_log) == set(expected_callbacks)
|
||||
Reference in New Issue
Block a user