From 16ba91cd01eede0d7411ae85be12134e449d4c51 Mon Sep 17 00:00:00 2001 From: Che Liu Date: Thu, 10 Jul 2025 17:24:59 -0700 Subject: [PATCH] feat: Implement PluginService for registering and executing plugins PluginService takes the registration of plugins, and provide the wrapper utilities to execute all plugins. PiperOrigin-RevId: 781745769 --- src/google/adk/plugins/plugin_manager.py | 265 ++++++++++++++++++ tests/unittests/plugins/test_base_plugin.py | 239 ++++++++++++++++ .../unittests/plugins/test_plugin_manager.py | 250 +++++++++++++++++ 3 files changed, 754 insertions(+) create mode 100644 src/google/adk/plugins/plugin_manager.py create mode 100644 tests/unittests/plugins/test_base_plugin.py create mode 100644 tests/unittests/plugins/test_plugin_manager.py diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py new file mode 100644 index 00000000..3680c351 --- /dev/null +++ b/src/google/adk/plugins/plugin_manager.py @@ -0,0 +1,265 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any +from typing import List +from typing import Literal +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types + +from .base_plugin import BasePlugin + +if TYPE_CHECKING: + from ..agents.base_agent import BaseAgent + from ..agents.callback_context import CallbackContext + from ..agents.invocation_context import InvocationContext + from ..events.event import Event + from ..models.llm_request import LlmRequest + from ..models.llm_response import LlmResponse + from ..tools.base_tool import BaseTool + from ..tools.tool_context import ToolContext + +# A type alias for the names of the available plugin callbacks. +# This helps with static analysis and prevents typos when calling run_callbacks. +PluginCallbackName = Literal[ + "on_user_message_callback", + "before_run_callback", + "after_run_callback", + "on_event_callback", + "before_agent_callback", + "after_agent_callback", + "before_tool_callback", + "after_tool_callback", + "before_model_callback", + "after_model_callback", +] + +logger = logging.getLogger("google_adk." + __name__) + + +class PluginManager: + """Manages the registration and execution of plugins. + + The PluginManager is an internal class that orchestrates the invocation of + plugin callbacks at key points in the SDK's execution lifecycle. It maintains + a list of registered plugins and ensures they are called in the order they + were registered. + + The core execution logic implements an "early exit" strategy: if any plugin + callback returns a non-`None` value, the execution of subsequent plugins for + that specific event is halted, and the returned value is propagated up the + call stack. This allows plugins to short-circuit operations like agent runs, + tool calls, or model requests. + """ + + def __init__(self, plugins: Optional[List[BasePlugin]] = None): + """Initializes the plugin service. + + Args: + plugins: An optional list of plugins to register upon initialization. + """ + self.plugins: List[BasePlugin] = [] + if plugins: + for plugin in plugins: + self.register_plugin(plugin) + + def register_plugin(self, plugin: BasePlugin) -> None: + """Registers a new plugin. + + Args: + plugin: The plugin instance to register. + + Raises: + ValueError: If a plugin with the same name is already registered. + """ + if any(p.name == plugin.name for p in self.plugins): + raise ValueError(f"Plugin with name '{plugin.name}' already registered.") + self.plugins.append(plugin) + logger.info("Plugin '%s' registered.", plugin.name) + + def get_plugin(self, plugin_name: str) -> Optional[BasePlugin]: + """Retrieves a registered plugin by its name. + + Args: + plugin_name: The name of the plugin to retrieve. + + Returns: + The plugin instance if found, otherwise `None`. + """ + return next((p for p in self.plugins if p.name == plugin_name), None) + + async def run_on_user_message_callback( + self, + *, + user_message: types.Content, + invocation_context: InvocationContext, + ) -> Optional[types.Content]: + """Runs the `on_user_message_callback` for all plugins.""" + return await self._run_callbacks( + "on_user_message_callback", + user_message=user_message, + invocation_context=invocation_context, + ) + + async def run_before_run_callback( + self, *, invocation_context: InvocationContext + ) -> Optional[types.Content]: + """Runs the `before_run_callback` for all plugins.""" + return await self._run_callbacks( + "before_run_callback", invocation_context=invocation_context + ) + + async def run_after_run_callback( + self, *, invocation_context: InvocationContext + ) -> Optional[None]: + """Runs the `after_run_callback` for all plugins.""" + return await self._run_callbacks( + "after_run_callback", invocation_context=invocation_context + ) + + async def run_on_event_callback( + self, *, invocation_context: InvocationContext, event: Event + ) -> Optional[Event]: + """Runs the `on_event_callback` for all plugins.""" + return await self._run_callbacks( + "on_event_callback", + invocation_context=invocation_context, + event=event, + ) + + async def run_before_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> Optional[types.Content]: + """Runs the `before_agent_callback` for all plugins.""" + return await self._run_callbacks( + "before_agent_callback", + agent=agent, + callback_context=callback_context, + ) + + async def run_after_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> Optional[types.Content]: + """Runs the `after_agent_callback` for all plugins.""" + return await self._run_callbacks( + "after_agent_callback", + agent=agent, + callback_context=callback_context, + ) + + async def run_before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> Optional[dict]: + """Runs the `before_tool_callback` for all plugins.""" + return await self._run_callbacks( + "before_tool_callback", + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + ) + + async def run_after_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + result: dict, + ) -> Optional[dict]: + """Runs the `after_tool_callback` for all plugins.""" + return await self._run_callbacks( + "after_tool_callback", + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + result=result, + ) + + async def run_before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Runs the `before_model_callback` for all plugins.""" + return await self._run_callbacks( + "before_model_callback", + callback_context=callback_context, + llm_request=llm_request, + ) + + async def run_after_model_callback( + self, *, callback_context: CallbackContext, llm_response: LlmResponse + ) -> Optional[LlmResponse]: + """Runs the `after_model_callback` for all plugins.""" + return await self._run_callbacks( + "after_model_callback", + callback_context=callback_context, + llm_response=llm_response, + ) + + async def _run_callbacks( + self, callback_name: PluginCallbackName, **kwargs: Any + ) -> Optional[Any]: + """Executes a specific callback for all registered plugins. + + This private method iterates through the plugins and calls the specified + callback method on each one, passing the provided keyword arguments. + + The execution stops as soon as a plugin's callback returns a non-`None` + value. This "early exit" value is then returned by this method. If all + plugins are executed and all return `None`, this method also returns `None`. + + Args: + callback_name: The name of the callback method to execute. + **kwargs: Keyword arguments to be passed to the callback method. + + Returns: + The first non-`None` value returned by a plugin callback, or `None` if + all callbacks return `None`. + + Raises: + RuntimeError: If a plugin encounters an unhandled exception during + execution. The original exception is chained. + """ + for plugin in self.plugins: + # Each plugin might not implement all callbacks. The base class provides + # default `pass` implementations, so `getattr` will always succeed. + callback_method = getattr(plugin, callback_name) + try: + result = await callback_method(**kwargs) + if result is not None: + # Early exit: A plugin has returned a value. We stop + # processing further plugins and return this value immediately. + logger.debug( + "Plugin '%s' returned a value for callback '%s', exiting early.", + plugin.name, + callback_name, + ) + return result + except Exception as e: + error_message = ( + f"Error in plugin '{plugin.name}' during '{callback_name}'" + f" callback: {e}" + ) + logger.error(error_message, exc_info=True) + raise RuntimeError(error_message) from e + + return None diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py new file mode 100644 index 00000000..04b1c3e9 --- /dev/null +++ b/tests/unittests/plugins/test_base_plugin.py @@ -0,0 +1,239 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +class TestablePlugin(BasePlugin): + __test__ = False + """A concrete implementation of BasePlugin for testing purposes.""" + pass + + +class FullOverridePlugin(BasePlugin): + __test__ = False + + """A plugin that overrides every single callback method for testing.""" + + def __init__(self, name: str = "full_override"): + super().__init__(name) + + async def on_user_message_callback(self, **kwargs) -> str: + return "overridden_on_user_message" + + async def before_run_callback(self, **kwargs) -> str: + return "overridden_before_run" + + async def after_run_callback(self, **kwargs) -> str: + return "overridden_after_run" + + async def on_event_callback(self, **kwargs) -> str: + return "overridden_on_event" + + async def before_agent_callback(self, **kwargs) -> str: + return "overridden_before_agent" + + async def after_agent_callback(self, **kwargs) -> str: + return "overridden_after_agent" + + async def before_tool_callback(self, **kwargs) -> str: + return "overridden_before_tool" + + async def after_tool_callback(self, **kwargs) -> str: + return "overridden_after_tool" + + async def before_model_callback(self, **kwargs) -> str: + return "overridden_before_model" + + async def after_model_callback(self, **kwargs) -> str: + return "overridden_after_model" + + +def test_base_plugin_initialization(): + """Tests that a plugin is initialized with the correct name.""" + plugin_name = "my_test_plugin" + plugin = TestablePlugin(name=plugin_name) + assert plugin.name == plugin_name + + +@pytest.mark.asyncio +async def test_base_plugin_default_callbacks_return_none(): + """Tests that the default (non-overridden) callbacks in BasePlugin exist + + and return None as expected. + """ + plugin = TestablePlugin(name="default_plugin") + + # Mocking all necessary context objects + mock_context = Mock() + mock_user_message = Mock() + + # The default implementations should do nothing and return None. + assert ( + await plugin.on_user_message_callback( + user_message=mock_user_message, + invocation_context=mock_context, + ) + is None + ) + assert ( + await plugin.before_run_callback(invocation_context=mock_context) is None + ) + assert ( + await plugin.after_run_callback(invocation_context=mock_context) is None + ) + assert ( + await plugin.on_event_callback( + invocation_context=mock_context, event=mock_context + ) + is None + ) + assert ( + await plugin.before_agent_callback( + agent=mock_context, callback_context=mock_context + ) + is None + ) + assert ( + await plugin.after_agent_callback( + agent=mock_context, callback_context=mock_context + ) + is None + ) + assert ( + await plugin.before_tool_callback( + tool=mock_context, tool_args={}, tool_context=mock_context + ) + is None + ) + assert ( + await plugin.after_tool_callback( + tool=mock_context, tool_args={}, tool_context=mock_context, result={} + ) + is None + ) + assert ( + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_context + ) + is None + ) + assert ( + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_context + ) + is None + ) + + +@pytest.mark.asyncio +async def test_base_plugin_all_callbacks_can_be_overridden(): + """Verifies that a user can create a subclass of BasePlugin and that all + + overridden methods are correctly called. + """ + plugin = FullOverridePlugin() + + # Create mock objects for all required arguments. We don't need real + # objects, just placeholders to satisfy the method signatures. + mock_user_message = Mock(spec=types.Content) + mock_invocation_context = Mock(spec=InvocationContext) + mock_callback_context = Mock(spec=CallbackContext) + mock_agent = Mock(spec=BaseAgent) + mock_tool = Mock(spec=BaseTool) + mock_tool_context = Mock(spec=ToolContext) + mock_llm_request = Mock(spec=LlmRequest) + mock_llm_response = Mock(spec=LlmResponse) + mock_event = Mock(spec=Event) + + # Call each method and assert it returns the unique string from the override. + # This proves that the subclass's method was executed. + assert ( + await plugin.on_user_message_callback( + user_message=mock_user_message, + invocation_context=mock_invocation_context, + ) + == "overridden_on_user_message" + ) + assert ( + await plugin.before_run_callback( + invocation_context=mock_invocation_context + ) + == "overridden_before_run" + ) + assert ( + await plugin.after_run_callback( + invocation_context=mock_invocation_context + ) + == "overridden_after_run" + ) + assert ( + await plugin.on_event_callback( + invocation_context=mock_invocation_context, event=mock_event + ) + == "overridden_on_event" + ) + assert ( + await plugin.before_agent_callback( + agent=mock_agent, callback_context=mock_callback_context + ) + == "overridden_before_agent" + ) + assert ( + await plugin.after_agent_callback( + agent=mock_agent, callback_context=mock_callback_context + ) + == "overridden_after_agent" + ) + assert ( + await plugin.before_model_callback( + callback_context=mock_callback_context, llm_request=mock_llm_request + ) + == "overridden_before_model" + ) + assert ( + await plugin.after_model_callback( + callback_context=mock_callback_context, llm_response=mock_llm_response + ) + == "overridden_after_model" + ) + assert ( + await plugin.before_tool_callback( + tool=mock_tool, tool_args={}, tool_context=mock_tool_context + ) + == "overridden_before_tool" + ) + assert ( + await plugin.after_tool_callback( + tool=mock_tool, + tool_args={}, + tool_context=mock_tool_context, + result={}, + ) + == "overridden_after_tool" + ) diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py new file mode 100644 index 00000000..76d32a61 --- /dev/null +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -0,0 +1,250 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the PluginManager.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +# Assume the following path to your modules +# You might need to adjust this based on your project structure. +from google.adk.plugins.plugin_manager import PluginCallbackName +from google.adk.plugins.plugin_manager import PluginManager +import pytest + + +# A helper class to use in tests instead of mocks. +# This makes tests more explicit and easier to debug. +class TestPlugin(BasePlugin): + __test__ = False + """ + A test plugin that can be configured to return specific values or raise + exceptions for any callback, and it logs which callbacks were invoked. + """ + + def __init__(self, name: str): + super().__init__(name) + # A log to track the names of callbacks that have been called. + self.call_log: list[PluginCallbackName] = [] + # A map to configure return values for specific callbacks. + self.return_values: dict[PluginCallbackName, any] = {} + # A map to configure exceptions to be raised by specific callbacks. + self.exceptions_to_raise: dict[PluginCallbackName, Exception] = {} + + async def _handle_callback(self, name: PluginCallbackName): + """Generic handler for all callback methods.""" + self.call_log.append(name) + if name in self.exceptions_to_raise: + raise self.exceptions_to_raise[name] + return self.return_values.get(name) + + # Implement all callback methods from the BasePlugin interface. + async def on_user_message_callback(self, **kwargs): + return await self._handle_callback("on_user_message_callback") + + async def before_run_callback(self, **kwargs): + return await self._handle_callback("before_run_callback") + + async def after_run_callback(self, **kwargs): + return await self._handle_callback("after_run_callback") + + async def on_event_callback(self, **kwargs): + return await self._handle_callback("on_event_callback") + + async def before_agent_callback(self, **kwargs): + return await self._handle_callback("before_agent_callback") + + async def after_agent_callback(self, **kwargs): + return await self._handle_callback("after_agent_callback") + + async def before_tool_callback(self, **kwargs): + return await self._handle_callback("before_tool_callback") + + async def after_tool_callback(self, **kwargs): + return await self._handle_callback("after_tool_callback") + + async def before_model_callback(self, **kwargs): + return await self._handle_callback("before_model_callback") + + async def after_model_callback(self, **kwargs): + return await self._handle_callback("after_model_callback") + + +@pytest.fixture +def service() -> PluginManager: + """Provides a clean PluginManager instance for each test.""" + return PluginManager() + + +@pytest.fixture +def plugin1() -> TestPlugin: + """Provides a clean instance of our test plugin named 'plugin1'.""" + return TestPlugin(name="plugin1") + + +@pytest.fixture +def plugin2() -> TestPlugin: + """Provides a clean instance of our test plugin named 'plugin2'.""" + return TestPlugin(name="plugin2") + + +def test_register_and_get_plugin(service: PluginManager, plugin1: TestPlugin): + """Tests successful registration and retrieval of a plugin.""" + service.register_plugin(plugin1) + + assert len(service.plugins) == 1 + assert service.plugins[0] is plugin1 + assert service.get_plugin("plugin1") is plugin1 + + +def test_register_duplicate_plugin_name_raises_value_error( + service: PluginManager, plugin1: TestPlugin +): + """Tests that registering a plugin with a duplicate name raises an error.""" + plugin1_duplicate = TestPlugin(name="plugin1") + service.register_plugin(plugin1) + + with pytest.raises( + ValueError, match="Plugin with name 'plugin1' already registered." + ): + service.register_plugin(plugin1_duplicate) + + +@pytest.mark.asyncio +async def test_early_exit_stops_subsequent_plugins( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests the core "early exit" logic: if a plugin returns a value, + + subsequent plugins for that callback should not be executed. + """ + # Configure plugin1 to return a value, simulating a cache hit. + mock_response = Mock(spec=LlmResponse) + plugin1.return_values["before_run_callback"] = mock_response + + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + # Execute the callback chain. + result = await service.run_before_run_callback(invocation_context=Mock()) + + # Assert that the final result is the one returned by the first plugin. + assert result is mock_response + # Assert that the first plugin was called. + assert "before_run_callback" in plugin1.call_log + # CRITICAL: Assert that the second plugin was never called. + assert "before_run_callback" not in plugin2.call_log + + +@pytest.mark.asyncio +async def test_normal_flow_all_plugins_are_called( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that if no plugin returns a value, all plugins in the chain + + are executed in order. + """ + # By default, plugins are configured to return None. + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + result = await service.run_before_run_callback(invocation_context=Mock()) + + # The final result should be None as no plugin interrupted the flow. + assert result is None + # Both plugins must have been called. + assert "before_run_callback" in plugin1.call_log + assert "before_run_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_plugin_exception_is_wrapped_in_runtime_error( + service: PluginManager, plugin1: TestPlugin +): + """Tests that if a plugin callback raises an exception, the PluginManager + + catches it and raises a descriptive RuntimeError. + """ + # Configure the plugin to raise an error during a specific callback. + original_exception = ValueError("Something went wrong inside the plugin!") + plugin1.exceptions_to_raise["before_run_callback"] = original_exception + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_before_run_callback(invocation_context=Mock()) + + # Check that the error message is informative. + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "before_run_callback" in str(excinfo.value) + # Check that the original exception is chained for better tracebacks. + assert excinfo.value.__cause__ is original_exception + + +@pytest.mark.asyncio +async def test_all_callbacks_are_supported( + service: PluginManager, plugin1: TestPlugin +): + """Tests that all callbacks defined in the BasePlugin interface are supported + + by the PluginManager. + """ + service.register_plugin(plugin1) + mock_context = Mock() + mock_user_message = Mock() + + # Test all callbacks + await service.run_on_user_message_callback( + user_message=mock_user_message, invocation_context=mock_context + ) + await service.run_before_run_callback(invocation_context=mock_context) + await service.run_after_run_callback(invocation_context=mock_context) + await service.run_on_event_callback( + invocation_context=mock_context, event=mock_context + ) + await service.run_before_agent_callback( + agent=mock_context, callback_context=mock_context + ) + await service.run_after_agent_callback( + agent=mock_context, callback_context=mock_context + ) + await service.run_before_tool_callback( + tool=mock_context, tool_args={}, tool_context=mock_context + ) + await service.run_after_tool_callback( + tool=mock_context, tool_args={}, tool_context=mock_context, result={} + ) + await service.run_before_model_callback( + callback_context=mock_context, llm_request=mock_context + ) + await service.run_after_model_callback( + callback_context=mock_context, llm_response=mock_context + ) + + # Verify all callbacks were logged + expected_callbacks = [ + "on_user_message_callback", + "before_run_callback", + "after_run_callback", + "on_event_callback", + "before_agent_callback", + "after_agent_callback", + "before_tool_callback", + "after_tool_callback", + "before_model_callback", + "after_model_callback", + ] + assert set(plugin1.call_log) == set(expected_callbacks)