diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py
new file mode 100644
index 00000000..b778de02
--- /dev/null
+++ b/src/google/adk/plugins/context_filter_plugin.py
@@ -0,0 +1,88 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import Callable
+from typing import List
+from typing import Optional
+
+from ..agents.callback_context import CallbackContext
+from ..events.event import Event
+from ..models.llm_request import LlmRequest
+from ..models.llm_response import LlmResponse
+from .base_plugin import BasePlugin
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+class ContextFilterPlugin(BasePlugin):
+  """A plugin that filters the LLM context to reduce its size."""
+
+  def __init__(
+      self,
+      num_invocations_to_keep: Optional[int] = None,
+      custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
+      name: str = "context_filter_plugin",
+  ):
+    """Initializes the context management plugin.
+
+    Args:
+      num_invocations_to_keep: The number of last invocations to keep. An
+        invocation is defined as one or more consecutive user messages followed
+        by a model response.
+      custom_filter: A function to filter the context.
+      name: The name of the plugin instance.
+    """
+    super().__init__(name)
+    self._num_invocations_to_keep = num_invocations_to_keep
+    self._custom_filter = custom_filter
+
+  async def before_model_callback(
+      self, *, callback_context: CallbackContext, llm_request: LlmRequest
+  ) -> Optional[LlmResponse]:
+    """Filters the LLM request's context before it is sent to the model."""
+    try:
+      contents = llm_request.contents
+
+      if (
+          self._num_invocations_to_keep is not None
+          and self._num_invocations_to_keep > 0
+      ):
+        num_model_turns = sum(1 for c in contents if c.role == "model")
+        if num_model_turns >= self._num_invocations_to_keep:
+          model_turns_to_find = self._num_invocations_to_keep
+          split_index = 0
+          for i in range(len(contents) - 1, -1, -1):
+            if contents[i].role == "model":
+              model_turns_to_find -= 1
+              if model_turns_to_find == 0:
+                start_index = i
+                while (
+                    start_index > 0 and contents[start_index - 1].role == "user"
+                ):
+                  start_index -= 1
+                split_index = start_index
+                break
+          contents = contents[split_index:]
+
+      if self._custom_filter:
+        contents = self._custom_filter(contents)
+
+      llm_request.contents = contents
+    except Exception as e:
+      logger.exception("Failed to reduce context for request: %s", e)
+
+    return None
diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py
new file mode 100644
index 00000000..f9c8222e
--- /dev/null
+++ b/tests/unittests/plugins/test_context_filtering_plugin.py
@@ -0,0 +1,185 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the ContextFilteringPlugin."""
+
+from unittest.mock import Mock
+
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.models.llm_request import LlmRequest
+from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
+from google.genai import types
+import pytest
+
+
+def _create_content(role: str, text: str) -> types.Content:
+  return types.Content(parts=[types.Part(text=text)], role=role)
+
+
+@pytest.mark.asyncio
+async def test_filter_last_n_invocations():
+  """Tests that the context is truncated to the last N invocations."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert len(llm_request.contents) == 2
+  assert llm_request.contents[0].parts[0].text == "user_prompt_2"
+  assert llm_request.contents[1].parts[0].text == "model_response_2"
+
+
+@pytest.mark.asyncio
+async def test_filter_with_function():
+  """Tests that a custom filter function is applied to the context."""
+
+  def remove_model_responses(contents):
+    return [c for c in contents if c.role != "model"]
+
+  plugin = ContextFilterPlugin(custom_filter=remove_model_responses)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert len(llm_request.contents) == 2
+  assert all(c.role == "user" for c in llm_request.contents)
+
+
+@pytest.mark.asyncio
+async def test_filter_with_function_and_last_n_invocations():
+  """Tests that both filtering methods are applied correctly."""
+
+  def remove_first_invocation(contents):
+    return contents[2:]
+
+  plugin = ContextFilterPlugin(
+      num_invocations_to_keep=1, custom_filter=remove_first_invocation
+  )
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+      _create_content("user", "user_prompt_3"),
+      _create_content("model", "model_response_3"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert len(llm_request.contents) == 0
+
+
+@pytest.mark.asyncio
+async def test_no_filtering_when_no_options_provided():
+  """Tests that no filtering occurs when no options are provided."""
+  plugin = ContextFilterPlugin()
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+  original_contents = list(llm_request.contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert llm_request.contents == original_contents
+
+
+@pytest.mark.asyncio
+async def test_last_n_invocations_with_multiple_user_turns():
+  """Tests filtering with multiple user turns in a single invocation."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2a"),
+      _create_content("user", "user_prompt_2b"),
+      _create_content("model", "model_response_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert len(llm_request.contents) == 3
+  assert llm_request.contents[0].parts[0].text == "user_prompt_2a"
+  assert llm_request.contents[1].parts[0].text == "user_prompt_2b"
+  assert llm_request.contents[2].parts[0].text == "model_response_2"
+
+
+@pytest.mark.asyncio
+async def test_last_n_invocations_more_than_existing_invocations():
+  """Tests that no filtering occurs if last_n_invocations is greater than
+
+  the number of invocations.
+  """
+  plugin = ContextFilterPlugin(num_invocations_to_keep=3)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+  original_contents = list(llm_request.contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert llm_request.contents == original_contents
+
+
+@pytest.mark.asyncio
+async def test_filter_function_raises_exception():
+  """Tests that the plugin handles exceptions from the filter function."""
+
+  def faulty_filter(contents):
+    raise ValueError("Filter error")
+
+  plugin = ContextFilterPlugin(custom_filter=faulty_filter)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+  original_contents = list(llm_request.contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  assert llm_request.contents == original_contents