fix: Refactor context filtering to better handle multi-turn invocations

The definition of an "invocation" for context filtering has been updated. An invocation now starts with a user message and can include multiple model turns (such as tool calls and their responses) until the next user message. The filtering logic has been rewritten to identify invocation start points based on human user messages, so tool outputs (function responses, which also carry role="user") no longer count as invocation boundaries.
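
To illustrate the new boundary rule, the sketch below copies the three helpers introduced in this change and runs them on a toy history built with google.genai types (the get_weather call and its response payload are invented for the example):

from google.genai import types

def _is_function_response_content(content: types.Content) -> bool:
  return bool(content.parts) and any(
      part.function_response is not None for part in content.parts
  )

def _is_human_user_content(content: types.Content) -> bool:
  return content.role == "user" and not _is_function_response_content(content)

def _get_invocation_start_indices(contents) -> list[int]:
  starts = []
  previous_was_human_user = False
  for i, content in enumerate(contents):
    is_human_user = _is_human_user_content(content)
    if is_human_user and not previous_was_human_user:
      starts.append(i)
    previous_was_human_user = is_human_user
  return starts

history = [
    types.Content(role="user", parts=[types.Part(text="Hello")]),  # 0: starts invocation 1
    types.Content(role="model", parts=[types.Part(text="Hi!")]),
    types.Content(role="user", parts=[types.Part(text="Weather?")]),  # 2: starts invocation 2
    types.Content(
        role="model",
        parts=[types.Part.from_function_call(name="get_weather", args={})],
    ),
    types.Content(
        role="user",  # tool output: role "user" but NOT an invocation start
        parts=[
            types.Part.from_function_response(
                name="get_weather", response={"temp_c": 20}
            )
        ],
    ),
    types.Content(role="model", parts=[types.Part(text="Sunny.")]),
]

assert _get_invocation_start_indices(history) == [0, 2]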

Close #4296

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 866023290
@@ -17,13 +17,11 @@ from __future__ import annotations
 from collections.abc import Sequence
 import logging
 from typing import Callable
-from typing import List
 from typing import Optional

 from google.genai import types

 from ..agents.callback_context import CallbackContext
-from ..events.event import Event
 from ..models.llm_request import LlmRequest
 from ..models.llm_response import LlmResponse
 from .base_plugin import BasePlugin
@@ -62,21 +60,61 @@ def _adjust_split_index_to_avoid_orphaned_function_responses(
   return 0


+def _is_function_response_content(content: types.Content) -> bool:
+  """Returns whether a content contains function responses."""
+  return bool(content.parts) and any(
+      part.function_response is not None for part in content.parts
+  )
+
+
+def _is_human_user_content(content: types.Content) -> bool:
+  """Returns whether a content represents user input (not tool output)."""
+  return content.role == "user" and not _is_function_response_content(content)
+
+
+def _get_invocation_start_indices(
+    contents: Sequence[types.Content],
+) -> list[int]:
+  """Returns indices that begin a user-started invocation.
+
+  An invocation begins with one or more consecutive user messages. Tool outputs
+  (function responses) are role="user" but are *not* considered invocation
+  starts.
+
+  Args:
+    contents: Full conversation contents in chronological order.
+
+  Returns:
+    A list of indices where each index marks the beginning of an invocation.
+  """
+  invocation_start_indices = []
+  previous_was_human_user = False
+  for i, content in enumerate(contents):
+    is_human_user = _is_human_user_content(content)
+    if is_human_user and not previous_was_human_user:
+      invocation_start_indices.append(i)
+    previous_was_human_user = is_human_user
+  return invocation_start_indices


 class ContextFilterPlugin(BasePlugin):
   """A plugin that filters the LLM context to reduce its size."""

   def __init__(
       self,
       num_invocations_to_keep: Optional[int] = None,
-      custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
+      custom_filter: Optional[
+          Callable[[list[types.Content]], list[types.Content]]
+      ] = None,
       name: str = "context_filter_plugin",
   ):
     """Initializes the context management plugin.

     Args:
       num_invocations_to_keep: The number of last invocations to keep. An
-        invocation is defined as one or more consecutive user messages followed
-        by a model response.
+        invocation starts with one or more consecutive user messages and can
+        contain multiple model turns (e.g. tool calls) until the next user
+        message starts a new invocation.
       custom_filter: A function to filter the context.
       name: The name of the plugin instance.
     """
@@ -89,27 +127,16 @@ class ContextFilterPlugin(BasePlugin):
   ) -> Optional[LlmResponse]:
     """Filters the LLM request's context before it is sent to the model."""
     try:
-      contents = llm_request.contents
+      contents: list[types.Content] = llm_request.contents
       if (
           self._num_invocations_to_keep is not None
          and self._num_invocations_to_keep > 0
       ):
-        num_model_turns = sum(1 for c in contents if c.role == "model")
-        if num_model_turns >= self._num_invocations_to_keep:
-          model_turns_to_find = self._num_invocations_to_keep
-          split_index = 0
-          for i in range(len(contents) - 1, -1, -1):
-            if contents[i].role == "model":
-              model_turns_to_find -= 1
-              if model_turns_to_find == 0:
-                start_index = i
-                while (
-                    start_index > 0 and contents[start_index - 1].role == "user"
-                ):
-                  start_index -= 1
-                split_index = start_index
-                break
+        invocation_start_indices = _get_invocation_start_indices(contents)
+        if len(invocation_start_indices) > self._num_invocations_to_keep:
+          split_index = invocation_start_indices[-self._num_invocations_to_keep]
           # Adjust split_index to avoid orphaned function_responses.
           split_index = (
               _adjust_split_index_to_avoid_orphaned_function_responses(
@@ -122,7 +149,7 @@ class ContextFilterPlugin(BasePlugin):
         contents = self._custom_filter(contents)

       llm_request.contents = contents
-    except Exception as e:
-      logger.error(f"Failed to reduce context for request: {e}")
+    except Exception:
+      logger.exception("Failed to reduce context for request")
     return None
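
With this change, custom_filter receives and returns list[types.Content] rather than List[Event]. A minimal compatible filter might look like the sketch below (the import path and the keep_last_six helper are assumptions for illustration; they are not shown in the diff):

from google.genai import types

# Assumed import path, inferred from the plugin's module layout.
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin


def keep_last_six(contents: list[types.Content]) -> list[types.Content]:
  # Naive tail truncation; a production filter should also avoid splitting
  # function_call/function_response pairs.
  return contents[-6:]


plugin = ContextFilterPlugin(custom_filter=keep_last_six)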
@@ -14,7 +14,7 @@

 """Unit tests for the ContextFilteringPlugin."""

-from unittest.mock import Mock
+from unittest import mock

 from google.adk.agents.callback_context import CallbackContext
 from google.adk.models.llm_request import LlmRequest
@@ -40,7 +40,8 @@ async def test_filter_last_n_invocations():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert len(llm_request.contents) == 2
@@ -65,7 +66,8 @@ async def test_filter_with_function():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert len(llm_request.contents) == 2
@@ -93,7 +95,8 @@ async def test_filter_with_function_and_last_n_invocations():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert len(llm_request.contents) == 0
@@ -111,7 +114,8 @@ async def test_no_filtering_when_no_options_provided():
   original_contents = list(llm_request.contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert llm_request.contents == original_contents
@@ -131,7 +135,8 @@ async def test_last_n_invocations_with_multiple_user_turns():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert len(llm_request.contents) == 3
@@ -157,7 +162,8 @@ async def test_last_n_invocations_more_than_existing_invocations():
   original_contents = list(llm_request.contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert llm_request.contents == original_contents
@@ -179,7 +185,8 @@ async def test_filter_function_raises_exception():
   original_contents = list(llm_request.contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   assert llm_request.contents == original_contents
@@ -237,7 +244,8 @@ async def test_filter_preserves_function_call_response_pairs():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   # Verify function_call for call_1 is included (not orphaned function_response)
@@ -276,7 +284,8 @@ async def test_filter_with_nested_function_calls():
   llm_request = LlmRequest(contents=contents)

   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )

   # Verify no orphaned function_responses
@@ -290,4 +299,47 @@ async def test_filter_with_nested_function_calls():
       if part.function_response and part.function_response.id:
         response_ids.add(part.function_response.id)

+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+  assert "Do task" in texts
+  assert "Done with tasks" in texts
+  assert "Hello" not in texts
+  assert "Hi!" not in texts
+
   assert response_ids.issubset(call_ids)
+
+
+@pytest.mark.asyncio
+async def test_last_invocation_with_tool_call_keeps_user_prompt():
+  """Tests that multi-model-turn invocations keep the initial user prompt."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_function_call_content("get_weather", "call_1"),
+      _create_function_response_content("get_weather", "call_1"),
+      _create_content("model", "final_answer_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
+  )
+
+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+  assert "user_prompt_2" in texts
+  assert "final_answer_2" in texts