From a06bf278cbc89f521c187ed51b032d82ffdafe2d Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Wed, 17 Sep 2025 19:28:19 -0700 Subject: [PATCH] feat: Adding the ContextFilterPlugin This commit introduces a new ContextFilterPlugin which allows for filtering the LlmRequest contents before they are sent to the LLM. This helps in managing and potentially reducing the size of the LLM context. The plugin provides two primary filtering mechanisms: num_invocations_to_keep: Keeps only the specified number of the most recent user-model invocations. An invocation is defined as one or more user messages followed by a model response. custom_filter: Allows for a user-defined callable to be applied to the contents for more flexible filtering. Unit tests have been added to cover the different filtering scenarios, including: Filtering by the last N invocations. Filtering using a custom function. Combining both filtering methods. Handling cases with multiple user turns in a single invocation. Ensuring no filtering occurs when options are not provided. Gracefully handling exceptions from custom filter functions." For example, when num_of_innovacations=2: ----------------------------------------------------------- Contents: {"parts":[{"text":"9"}],"role":"user"} {"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"} {"parts":[{"text":"1"}],"role":"user"} {"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"} {"parts":[{"text":"10"}],"role":"user"} ----------------------------------------------------------- PiperOrigin-RevId: 808355316 --- .../adk/plugins/context_filter_plugin.py | 88 +++++++++ .../plugins/test_context_filtering_plugin.py | 185 ++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 src/google/adk/plugins/context_filter_plugin.py create mode 100644 tests/unittests/plugins/test_context_filtering_plugin.py diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py new file mode 100644 index 00000000..b778de02 --- /dev/null +++ b/src/google/adk/plugins/context_filter_plugin.py @@ -0,0 +1,88 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Callable +from typing import List +from typing import Optional + +from ..agents.callback_context import CallbackContext +from ..events.event import Event +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from .base_plugin import BasePlugin + +logger = logging.getLogger("google_adk." + __name__) + + +class ContextFilterPlugin(BasePlugin): + """A plugin that filters the LLM context to reduce its size.""" + + def __init__( + self, + num_invocations_to_keep: Optional[int] = None, + custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, + name: str = "context_filter_plugin", + ): + """Initializes the context management plugin. + + Args: + num_invocations_to_keep: The number of last invocations to keep. An + invocation is defined as one or more consecutive user messages followed + by a model response. + custom_filter: A function to filter the context. + name: The name of the plugin instance. + """ + super().__init__(name) + self._num_invocations_to_keep = num_invocations_to_keep + self._custom_filter = custom_filter + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Filters the LLM request's context before it is sent to the model.""" + try: + contents = llm_request.contents + + if ( + self._num_invocations_to_keep is not None + and self._num_invocations_to_keep > 0 + ): + num_model_turns = sum(1 for c in contents if c.role == "model") + if num_model_turns >= self._num_invocations_to_keep: + model_turns_to_find = self._num_invocations_to_keep + split_index = 0 + for i in range(len(contents) - 1, -1, -1): + if contents[i].role == "model": + model_turns_to_find -= 1 + if model_turns_to_find == 0: + start_index = i + while ( + start_index > 0 and contents[start_index - 1].role == "user" + ): + start_index -= 1 + split_index = start_index + break + contents = contents[split_index:] + + if self._custom_filter: + contents = self._custom_filter(contents) + + llm_request.contents = contents + except Exception as e: + logger.error(f"Failed to reduce context for request: {e}") + + return None diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py new file mode 100644 index 00000000..f9c8222e --- /dev/null +++ b/tests/unittests/plugins/test_context_filtering_plugin.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the ContextFilteringPlugin.""" + +from unittest.mock import Mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.context_filter_plugin import ContextFilterPlugin +from google.genai import types +import pytest + + +def _create_content(role: str, text: str) -> types.Content: + return types.Content(parts=[types.Part(text=text)], role=role) + + +@pytest.mark.asyncio +async def test_filter_last_n_invocations(): + """Tests that the context is truncated to the last N invocations.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert len(llm_request.contents) == 2 + assert llm_request.contents[0].parts[0].text == "user_prompt_2" + assert llm_request.contents[1].parts[0].text == "model_response_2" + + +@pytest.mark.asyncio +async def test_filter_with_function(): + """Tests that a custom filter function is applied to the context.""" + + def remove_model_responses(contents): + return [c for c in contents if c.role != "model"] + + plugin = ContextFilterPlugin(custom_filter=remove_model_responses) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert len(llm_request.contents) == 2 + assert all(c.role == "user" for c in llm_request.contents) + + +@pytest.mark.asyncio +async def test_filter_with_function_and_last_n_invocations(): + """Tests that both filtering methods are applied correctly.""" + + def remove_first_invocation(contents): + return contents[2:] + + plugin = ContextFilterPlugin( + num_invocations_to_keep=1, custom_filter=remove_first_invocation + ) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + _create_content("user", "user_prompt_3"), + _create_content("model", "model_response_3"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert len(llm_request.contents) == 0 + + +@pytest.mark.asyncio +async def test_no_filtering_when_no_options_provided(): + """Tests that no filtering occurs when no options are provided.""" + plugin = ContextFilterPlugin() + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + ] + llm_request = LlmRequest(contents=contents) + original_contents = list(llm_request.contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert llm_request.contents == original_contents + + +@pytest.mark.asyncio +async def test_last_n_invocations_with_multiple_user_turns(): + """Tests filtering with multiple user turns in a single invocation.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2a"), + _create_content("user", "user_prompt_2b"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert len(llm_request.contents) == 3 + assert llm_request.contents[0].parts[0].text == "user_prompt_2a" + assert llm_request.contents[1].parts[0].text == "user_prompt_2b" + assert llm_request.contents[2].parts[0].text == "model_response_2" + + +@pytest.mark.asyncio +async def test_last_n_invocations_more_than_existing_invocations(): + """Tests that no filtering occurs if last_n_invocations is greater than + + the number of invocations. + """ + plugin = ContextFilterPlugin(num_invocations_to_keep=3) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + original_contents = list(llm_request.contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert llm_request.contents == original_contents + + +@pytest.mark.asyncio +async def test_filter_function_raises_exception(): + """Tests that the plugin handles exceptions from the filter function.""" + + def faulty_filter(contents): + raise ValueError("Filter error") + + plugin = ContextFilterPlugin(custom_filter=faulty_filter) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + ] + llm_request = LlmRequest(contents=contents) + original_contents = list(llm_request.contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + assert llm_request.contents == original_contents