feat: Add ContextFilterPlugin

This commit introduces a new ContextFilterPlugin that filters the LlmRequest contents before they are sent to the LLM, helping to manage and reduce the size of the LLM context.

The plugin provides two primary filtering mechanisms (a usage sketch follows the list):

- num_invocations_to_keep: Keeps only the specified number of the most recent user-model invocations. An invocation is defined as one or more consecutive user messages followed by a model response.
- custom_filter: Applies a user-defined callable to the contents for more flexible filtering.
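
A minimal usage sketch (the constructor arguments are the ones added in this change; the lambda is just an illustrative custom filter):

    from google.adk.plugins.context_filter_plugin import ContextFilterPlugin

    # Keep the last two invocations, then drop any remaining model responses.
    plugin = ContextFilterPlugin(
        num_invocations_to_keep=2,
        custom_filter=lambda contents: [c for c in contents if c.role != "model"],
    )
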
Unit tests have been added to cover the different filtering scenarios, including:

- Filtering by the last N invocations.
- Filtering using a custom function.
- Combining both filtering methods.
- Handling cases with multiple user turns in a single invocation.
- Ensuring no filtering occurs when options are not provided.
- Gracefully handling exceptions from custom filter functions.

For example, when num_invocations_to_keep=2:
-----------------------------------------------------------
Contents:
{"parts":[{"text":"9"}],"role":"user"}
{"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"}
{"parts":[{"text":"1"}],"role":"user"}
{"parts":[{"text":"I am sorry, I cannot fulfill this request. I need more information on what you would like me to do. I can roll a die or check prime numbers.\n"}],"role":"model"}
{"parts":[{"text":"10"}],"role":"user"}
-----------------------------------------------------------
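
With these five contents and num_invocations_to_keep=2, the split point lands at the very first message, so both invocations (and the trailing user message awaiting a response) are retained.

The plugin hooks in through before_model_callback, so it registers like any other plugin. A sketch, assuming a runner that accepts a plugins list (e.g. InMemoryRunner) and a root_agent defined elsewhere:

    from google.adk.runners import InMemoryRunner

    runner = InMemoryRunner(
        agent=root_agent,  # hypothetical agent defined elsewhere
        plugins=[ContextFilterPlugin(num_invocations_to_keep=2)],
    )
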
PiperOrigin-RevId: 808355316
2 changed files with 273 additions and 0 deletions
@@ -0,0 +1,88 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Callable
from typing import List
from typing import Optional

from google.genai import types

from ..agents.callback_context import CallbackContext
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from .base_plugin import BasePlugin

logger = logging.getLogger("google_adk." + __name__)

class ContextFilterPlugin(BasePlugin):
  """A plugin that filters the LLM context to reduce its size."""

  def __init__(
      self,
      num_invocations_to_keep: Optional[int] = None,
      custom_filter: Optional[
          Callable[[List[types.Content]], List[types.Content]]
      ] = None,
      name: str = "context_filter_plugin",
  ):
    """Initializes the context filter plugin.

    Args:
      num_invocations_to_keep: The number of most recent invocations to keep.
        An invocation is defined as one or more consecutive user messages
        followed by a model response.
      custom_filter: A callable that receives the request contents and returns
        the filtered list.
      name: The name of the plugin instance.
    """
    super().__init__(name)
    self._num_invocations_to_keep = num_invocations_to_keep
    self._custom_filter = custom_filter

  async def before_model_callback(
      self, *, callback_context: CallbackContext, llm_request: LlmRequest
  ) -> Optional[LlmResponse]:
    """Filters the LLM request's context before it is sent to the model."""
    try:
      contents = llm_request.contents
      if (
          self._num_invocations_to_keep is not None
          and self._num_invocations_to_keep > 0
      ):
        num_model_turns = sum(1 for c in contents if c.role == "model")
        if num_model_turns >= self._num_invocations_to_keep:
          # Walk backwards from the end until the Nth-from-last model turn is
          # found, then pull the split point back over the consecutive user
          # messages that opened that invocation.
          model_turns_to_find = self._num_invocations_to_keep
          split_index = 0
          for i in range(len(contents) - 1, -1, -1):
            if contents[i].role == "model":
              model_turns_to_find -= 1
              if model_turns_to_find == 0:
                start_index = i
                while (
                    start_index > 0 and contents[start_index - 1].role == "user"
                ):
                  start_index -= 1
                split_index = start_index
                break
          contents = contents[split_index:]
      if self._custom_filter:
        contents = self._custom_filter(contents)
      llm_request.contents = contents
    except Exception as e:
      logger.error(f"Failed to reduce context for request: {e}")
    return None
@@ -0,0 +1,185 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the ContextFilteringPlugin."""
from unittest.mock import Mock
from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
from google.genai import types
import pytest
def _create_content(role: str, text: str) -> types.Content:
return types.Content(parts=[types.Part(text=text)], role=role)


@pytest.mark.asyncio
async def test_filter_last_n_invocations():
  """Tests that the context is truncated to the last N invocations."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 2
  assert llm_request.contents[0].parts[0].text == "user_prompt_2"
  assert llm_request.contents[1].parts[0].text == "model_response_2"


@pytest.mark.asyncio
async def test_filter_with_function():
  """Tests that a custom filter function is applied to the context."""

  def remove_model_responses(contents):
    return [c for c in contents if c.role != "model"]

  plugin = ContextFilterPlugin(custom_filter=remove_model_responses)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 2
  assert all(c.role == "user" for c in llm_request.contents)


@pytest.mark.asyncio
async def test_filter_with_function_and_last_n_invocations():
  """Tests that both filtering methods are applied correctly."""

  def remove_first_invocation(contents):
    return contents[2:]

  plugin = ContextFilterPlugin(
      num_invocations_to_keep=1, custom_filter=remove_first_invocation
  )
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
      _create_content("user", "user_prompt_3"),
      _create_content("model", "model_response_3"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 0


@pytest.mark.asyncio
async def test_no_filtering_when_no_options_provided():
  """Tests that no filtering occurs when no options are provided."""
  plugin = ContextFilterPlugin()
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_last_n_invocations_with_multiple_user_turns():
  """Tests filtering with multiple user turns in a single invocation."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2a"),
      _create_content("user", "user_prompt_2b"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert len(llm_request.contents) == 3
  assert llm_request.contents[0].parts[0].text == "user_prompt_2a"
  assert llm_request.contents[1].parts[0].text == "user_prompt_2b"
  assert llm_request.contents[2].parts[0].text == "model_response_2"


@pytest.mark.asyncio
async def test_last_n_invocations_more_than_existing_invocations():
  """Tests that no filtering occurs if num_invocations_to_keep exceeds the
  number of invocations in the context.
  """
  plugin = ContextFilterPlugin(num_invocations_to_keep=3)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
      _create_content("user", "user_prompt_2"),
      _create_content("model", "model_response_2"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents


@pytest.mark.asyncio
async def test_filter_function_raises_exception():
  """Tests that the plugin handles exceptions from the filter function."""

  def faulty_filter(contents):
    raise ValueError("Filter error")

  plugin = ContextFilterPlugin(custom_filter=faulty_filter)
  contents = [
      _create_content("user", "user_prompt_1"),
      _create_content("model", "model_response_1"),
  ]
  llm_request = LlmRequest(contents=contents)
  original_contents = list(llm_request.contents)

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  assert llm_request.contents == original_contents