feat: Implement PluginService for registering and executing plugins

PluginService takes the registration of plugins, and provide the wrapper utilities to execute all plugins.

PiperOrigin-RevId: 781745769
This commit is contained in:
Che Liu
2025-07-10 17:24:59 -07:00
committed by Copybara-Service
parent 4dce9ef519
commit 16ba91cd01
3 changed files with 754 additions and 0 deletions
+265
View File
@@ -0,0 +1,265 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Any
from typing import List
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from .base_plugin import BasePlugin
if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
from ..agents.callback_context import CallbackContext
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from ..tools.base_tool import BaseTool
from ..tools.tool_context import ToolContext
# A type alias for the names of the available plugin callbacks.
# This helps with static analysis and prevents typos when calling run_callbacks.
PluginCallbackName = Literal[
"on_user_message_callback",
"before_run_callback",
"after_run_callback",
"on_event_callback",
"before_agent_callback",
"after_agent_callback",
"before_tool_callback",
"after_tool_callback",
"before_model_callback",
"after_model_callback",
]
logger = logging.getLogger("google_adk." + __name__)
class PluginManager:
"""Manages the registration and execution of plugins.
The PluginManager is an internal class that orchestrates the invocation of
plugin callbacks at key points in the SDK's execution lifecycle. It maintains
a list of registered plugins and ensures they are called in the order they
were registered.
The core execution logic implements an "early exit" strategy: if any plugin
callback returns a non-`None` value, the execution of subsequent plugins for
that specific event is halted, and the returned value is propagated up the
call stack. This allows plugins to short-circuit operations like agent runs,
tool calls, or model requests.
"""
def __init__(self, plugins: Optional[List[BasePlugin]] = None):
"""Initializes the plugin service.
Args:
plugins: An optional list of plugins to register upon initialization.
"""
self.plugins: List[BasePlugin] = []
if plugins:
for plugin in plugins:
self.register_plugin(plugin)
def register_plugin(self, plugin: BasePlugin) -> None:
"""Registers a new plugin.
Args:
plugin: The plugin instance to register.
Raises:
ValueError: If a plugin with the same name is already registered.
"""
if any(p.name == plugin.name for p in self.plugins):
raise ValueError(f"Plugin with name '{plugin.name}' already registered.")
self.plugins.append(plugin)
logger.info("Plugin '%s' registered.", plugin.name)
def get_plugin(self, plugin_name: str) -> Optional[BasePlugin]:
"""Retrieves a registered plugin by its name.
Args:
plugin_name: The name of the plugin to retrieve.
Returns:
The plugin instance if found, otherwise `None`.
"""
return next((p for p in self.plugins if p.name == plugin_name), None)
async def run_on_user_message_callback(
self,
*,
user_message: types.Content,
invocation_context: InvocationContext,
) -> Optional[types.Content]:
"""Runs the `on_user_message_callback` for all plugins."""
return await self._run_callbacks(
"on_user_message_callback",
user_message=user_message,
invocation_context=invocation_context,
)
async def run_before_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[types.Content]:
"""Runs the `before_run_callback` for all plugins."""
return await self._run_callbacks(
"before_run_callback", invocation_context=invocation_context
)
async def run_after_run_callback(
self, *, invocation_context: InvocationContext
) -> Optional[None]:
"""Runs the `after_run_callback` for all plugins."""
return await self._run_callbacks(
"after_run_callback", invocation_context=invocation_context
)
async def run_on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
"""Runs the `on_event_callback` for all plugins."""
return await self._run_callbacks(
"on_event_callback",
invocation_context=invocation_context,
event=event,
)
async def run_before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
"""Runs the `before_agent_callback` for all plugins."""
return await self._run_callbacks(
"before_agent_callback",
agent=agent,
callback_context=callback_context,
)
async def run_after_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
"""Runs the `after_agent_callback` for all plugins."""
return await self._run_callbacks(
"after_agent_callback",
agent=agent,
callback_context=callback_context,
)
async def run_before_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
"""Runs the `before_tool_callback` for all plugins."""
return await self._run_callbacks(
"before_tool_callback",
tool=tool,
tool_args=tool_args,
tool_context=tool_context,
)
async def run_after_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
result: dict,
) -> Optional[dict]:
"""Runs the `after_tool_callback` for all plugins."""
return await self._run_callbacks(
"after_tool_callback",
tool=tool,
tool_args=tool_args,
tool_context=tool_context,
result=result,
)
async def run_before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
"""Runs the `before_model_callback` for all plugins."""
return await self._run_callbacks(
"before_model_callback",
callback_context=callback_context,
llm_request=llm_request,
)
async def run_after_model_callback(
self, *, callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
"""Runs the `after_model_callback` for all plugins."""
return await self._run_callbacks(
"after_model_callback",
callback_context=callback_context,
llm_response=llm_response,
)
async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
"""Executes a specific callback for all registered plugins.
This private method iterates through the plugins and calls the specified
callback method on each one, passing the provided keyword arguments.
The execution stops as soon as a plugin's callback returns a non-`None`
value. This "early exit" value is then returned by this method. If all
plugins are executed and all return `None`, this method also returns `None`.
Args:
callback_name: The name of the callback method to execute.
**kwargs: Keyword arguments to be passed to the callback method.
Returns:
The first non-`None` value returned by a plugin callback, or `None` if
all callbacks return `None`.
Raises:
RuntimeError: If a plugin encounters an unhandled exception during
execution. The original exception is chained.
"""
for plugin in self.plugins:
# Each plugin might not implement all callbacks. The base class provides
# default `pass` implementations, so `getattr` will always succeed.
callback_method = getattr(plugin, callback_name)
try:
result = await callback_method(**kwargs)
if result is not None:
# Early exit: A plugin has returned a value. We stop
# processing further plugins and return this value immediately.
logger.debug(
"Plugin '%s' returned a value for callback '%s', exiting early.",
plugin.name,
callback_name,
)
return result
except Exception as e:
error_message = (
f"Error in plugin '{plugin.name}' during '{callback_name}'"
f" callback: {e}"
)
logger.error(error_message, exc_info=True)
raise RuntimeError(error_message) from e
return None
+239
View File
@@ -0,0 +1,239 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from unittest.mock import Mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events.event import Event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
import pytest
class TestablePlugin(BasePlugin):
__test__ = False
"""A concrete implementation of BasePlugin for testing purposes."""
pass
class FullOverridePlugin(BasePlugin):
__test__ = False
"""A plugin that overrides every single callback method for testing."""
def __init__(self, name: str = "full_override"):
super().__init__(name)
async def on_user_message_callback(self, **kwargs) -> str:
return "overridden_on_user_message"
async def before_run_callback(self, **kwargs) -> str:
return "overridden_before_run"
async def after_run_callback(self, **kwargs) -> str:
return "overridden_after_run"
async def on_event_callback(self, **kwargs) -> str:
return "overridden_on_event"
async def before_agent_callback(self, **kwargs) -> str:
return "overridden_before_agent"
async def after_agent_callback(self, **kwargs) -> str:
return "overridden_after_agent"
async def before_tool_callback(self, **kwargs) -> str:
return "overridden_before_tool"
async def after_tool_callback(self, **kwargs) -> str:
return "overridden_after_tool"
async def before_model_callback(self, **kwargs) -> str:
return "overridden_before_model"
async def after_model_callback(self, **kwargs) -> str:
return "overridden_after_model"
def test_base_plugin_initialization():
"""Tests that a plugin is initialized with the correct name."""
plugin_name = "my_test_plugin"
plugin = TestablePlugin(name=plugin_name)
assert plugin.name == plugin_name
@pytest.mark.asyncio
async def test_base_plugin_default_callbacks_return_none():
"""Tests that the default (non-overridden) callbacks in BasePlugin exist
and return None as expected.
"""
plugin = TestablePlugin(name="default_plugin")
# Mocking all necessary context objects
mock_context = Mock()
mock_user_message = Mock()
# The default implementations should do nothing and return None.
assert (
await plugin.on_user_message_callback(
user_message=mock_user_message,
invocation_context=mock_context,
)
is None
)
assert (
await plugin.before_run_callback(invocation_context=mock_context) is None
)
assert (
await plugin.after_run_callback(invocation_context=mock_context) is None
)
assert (
await plugin.on_event_callback(
invocation_context=mock_context, event=mock_context
)
is None
)
assert (
await plugin.before_agent_callback(
agent=mock_context, callback_context=mock_context
)
is None
)
assert (
await plugin.after_agent_callback(
agent=mock_context, callback_context=mock_context
)
is None
)
assert (
await plugin.before_tool_callback(
tool=mock_context, tool_args={}, tool_context=mock_context
)
is None
)
assert (
await plugin.after_tool_callback(
tool=mock_context, tool_args={}, tool_context=mock_context, result={}
)
is None
)
assert (
await plugin.before_model_callback(
callback_context=mock_context, llm_request=mock_context
)
is None
)
assert (
await plugin.after_model_callback(
callback_context=mock_context, llm_response=mock_context
)
is None
)
@pytest.mark.asyncio
async def test_base_plugin_all_callbacks_can_be_overridden():
"""Verifies that a user can create a subclass of BasePlugin and that all
overridden methods are correctly called.
"""
plugin = FullOverridePlugin()
# Create mock objects for all required arguments. We don't need real
# objects, just placeholders to satisfy the method signatures.
mock_user_message = Mock(spec=types.Content)
mock_invocation_context = Mock(spec=InvocationContext)
mock_callback_context = Mock(spec=CallbackContext)
mock_agent = Mock(spec=BaseAgent)
mock_tool = Mock(spec=BaseTool)
mock_tool_context = Mock(spec=ToolContext)
mock_llm_request = Mock(spec=LlmRequest)
mock_llm_response = Mock(spec=LlmResponse)
mock_event = Mock(spec=Event)
# Call each method and assert it returns the unique string from the override.
# This proves that the subclass's method was executed.
assert (
await plugin.on_user_message_callback(
user_message=mock_user_message,
invocation_context=mock_invocation_context,
)
== "overridden_on_user_message"
)
assert (
await plugin.before_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_before_run"
)
assert (
await plugin.after_run_callback(
invocation_context=mock_invocation_context
)
== "overridden_after_run"
)
assert (
await plugin.on_event_callback(
invocation_context=mock_invocation_context, event=mock_event
)
== "overridden_on_event"
)
assert (
await plugin.before_agent_callback(
agent=mock_agent, callback_context=mock_callback_context
)
== "overridden_before_agent"
)
assert (
await plugin.after_agent_callback(
agent=mock_agent, callback_context=mock_callback_context
)
== "overridden_after_agent"
)
assert (
await plugin.before_model_callback(
callback_context=mock_callback_context, llm_request=mock_llm_request
)
== "overridden_before_model"
)
assert (
await plugin.after_model_callback(
callback_context=mock_callback_context, llm_response=mock_llm_response
)
== "overridden_after_model"
)
assert (
await plugin.before_tool_callback(
tool=mock_tool, tool_args={}, tool_context=mock_tool_context
)
== "overridden_before_tool"
)
assert (
await plugin.after_tool_callback(
tool=mock_tool,
tool_args={},
tool_context=mock_tool_context,
result={},
)
== "overridden_after_tool"
)
@@ -0,0 +1,250 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the PluginManager."""
from __future__ import annotations
from unittest.mock import Mock
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
# Assume the following path to your modules
# You might need to adjust this based on your project structure.
from google.adk.plugins.plugin_manager import PluginCallbackName
from google.adk.plugins.plugin_manager import PluginManager
import pytest
# A helper class to use in tests instead of mocks.
# This makes tests more explicit and easier to debug.
class TestPlugin(BasePlugin):
__test__ = False
"""
A test plugin that can be configured to return specific values or raise
exceptions for any callback, and it logs which callbacks were invoked.
"""
def __init__(self, name: str):
super().__init__(name)
# A log to track the names of callbacks that have been called.
self.call_log: list[PluginCallbackName] = []
# A map to configure return values for specific callbacks.
self.return_values: dict[PluginCallbackName, any] = {}
# A map to configure exceptions to be raised by specific callbacks.
self.exceptions_to_raise: dict[PluginCallbackName, Exception] = {}
async def _handle_callback(self, name: PluginCallbackName):
"""Generic handler for all callback methods."""
self.call_log.append(name)
if name in self.exceptions_to_raise:
raise self.exceptions_to_raise[name]
return self.return_values.get(name)
# Implement all callback methods from the BasePlugin interface.
async def on_user_message_callback(self, **kwargs):
return await self._handle_callback("on_user_message_callback")
async def before_run_callback(self, **kwargs):
return await self._handle_callback("before_run_callback")
async def after_run_callback(self, **kwargs):
return await self._handle_callback("after_run_callback")
async def on_event_callback(self, **kwargs):
return await self._handle_callback("on_event_callback")
async def before_agent_callback(self, **kwargs):
return await self._handle_callback("before_agent_callback")
async def after_agent_callback(self, **kwargs):
return await self._handle_callback("after_agent_callback")
async def before_tool_callback(self, **kwargs):
return await self._handle_callback("before_tool_callback")
async def after_tool_callback(self, **kwargs):
return await self._handle_callback("after_tool_callback")
async def before_model_callback(self, **kwargs):
return await self._handle_callback("before_model_callback")
async def after_model_callback(self, **kwargs):
return await self._handle_callback("after_model_callback")
@pytest.fixture
def service() -> PluginManager:
"""Provides a clean PluginManager instance for each test."""
return PluginManager()
@pytest.fixture
def plugin1() -> TestPlugin:
"""Provides a clean instance of our test plugin named 'plugin1'."""
return TestPlugin(name="plugin1")
@pytest.fixture
def plugin2() -> TestPlugin:
"""Provides a clean instance of our test plugin named 'plugin2'."""
return TestPlugin(name="plugin2")
def test_register_and_get_plugin(service: PluginManager, plugin1: TestPlugin):
"""Tests successful registration and retrieval of a plugin."""
service.register_plugin(plugin1)
assert len(service.plugins) == 1
assert service.plugins[0] is plugin1
assert service.get_plugin("plugin1") is plugin1
def test_register_duplicate_plugin_name_raises_value_error(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that registering a plugin with a duplicate name raises an error."""
plugin1_duplicate = TestPlugin(name="plugin1")
service.register_plugin(plugin1)
with pytest.raises(
ValueError, match="Plugin with name 'plugin1' already registered."
):
service.register_plugin(plugin1_duplicate)
@pytest.mark.asyncio
async def test_early_exit_stops_subsequent_plugins(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests the core "early exit" logic: if a plugin returns a value,
subsequent plugins for that callback should not be executed.
"""
# Configure plugin1 to return a value, simulating a cache hit.
mock_response = Mock(spec=LlmResponse)
plugin1.return_values["before_run_callback"] = mock_response
service.register_plugin(plugin1)
service.register_plugin(plugin2)
# Execute the callback chain.
result = await service.run_before_run_callback(invocation_context=Mock())
# Assert that the final result is the one returned by the first plugin.
assert result is mock_response
# Assert that the first plugin was called.
assert "before_run_callback" in plugin1.call_log
# CRITICAL: Assert that the second plugin was never called.
assert "before_run_callback" not in plugin2.call_log
@pytest.mark.asyncio
async def test_normal_flow_all_plugins_are_called(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that if no plugin returns a value, all plugins in the chain
are executed in order.
"""
# By default, plugins are configured to return None.
service.register_plugin(plugin1)
service.register_plugin(plugin2)
result = await service.run_before_run_callback(invocation_context=Mock())
# The final result should be None as no plugin interrupted the flow.
assert result is None
# Both plugins must have been called.
assert "before_run_callback" in plugin1.call_log
assert "before_run_callback" in plugin2.call_log
@pytest.mark.asyncio
async def test_plugin_exception_is_wrapped_in_runtime_error(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that if a plugin callback raises an exception, the PluginManager
catches it and raises a descriptive RuntimeError.
"""
# Configure the plugin to raise an error during a specific callback.
original_exception = ValueError("Something went wrong inside the plugin!")
plugin1.exceptions_to_raise["before_run_callback"] = original_exception
service.register_plugin(plugin1)
with pytest.raises(RuntimeError) as excinfo:
await service.run_before_run_callback(invocation_context=Mock())
# Check that the error message is informative.
assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "before_run_callback" in str(excinfo.value)
# Check that the original exception is chained for better tracebacks.
assert excinfo.value.__cause__ is original_exception
@pytest.mark.asyncio
async def test_all_callbacks_are_supported(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that all callbacks defined in the BasePlugin interface are supported
by the PluginManager.
"""
service.register_plugin(plugin1)
mock_context = Mock()
mock_user_message = Mock()
# Test all callbacks
await service.run_on_user_message_callback(
user_message=mock_user_message, invocation_context=mock_context
)
await service.run_before_run_callback(invocation_context=mock_context)
await service.run_after_run_callback(invocation_context=mock_context)
await service.run_on_event_callback(
invocation_context=mock_context, event=mock_context
)
await service.run_before_agent_callback(
agent=mock_context, callback_context=mock_context
)
await service.run_after_agent_callback(
agent=mock_context, callback_context=mock_context
)
await service.run_before_tool_callback(
tool=mock_context, tool_args={}, tool_context=mock_context
)
await service.run_after_tool_callback(
tool=mock_context, tool_args={}, tool_context=mock_context, result={}
)
await service.run_before_model_callback(
callback_context=mock_context, llm_request=mock_context
)
await service.run_after_model_callback(
callback_context=mock_context, llm_response=mock_context
)
# Verify all callbacks were logged
expected_callbacks = [
"on_user_message_callback",
"before_run_callback",
"after_run_callback",
"on_event_callback",
"before_agent_callback",
"after_agent_callback",
"before_tool_callback",
"after_tool_callback",
"before_model_callback",
"after_model_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)