fix: Refactor context filtering to better handle multi-turn invocations
The definition of an "invocation" for context filtering has been updated. An invocation now starts with a user message and can include multiple model turns (such as tool calls and responses) until the next user message. The filtering logic has been rewritten to identify invocation start points based on human user messages.

Closes #4296

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 866023290
committed by Copybara-Service
parent a08bf62b95
commit 9b112e2d13
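Illustration (not part of the commit): the new invocation-boundary rule can be sketched standalone. The helper below mirrors _get_invocation_start_indices from the diff; the conversation data is hypothetical.

    from google.genai import types

    def get_invocation_start_indices(contents):
      # An index starts an invocation when it is a human user message
      # (role == "user" with no function_response parts) that does not
      # directly follow another human user message.
      starts, previous_was_human_user = [], False
      for i, content in enumerate(contents):
        is_human_user = content.role == "user" and not any(
            part.function_response is not None for part in content.parts or []
        )
        if is_human_user and not previous_was_human_user:
          starts.append(i)
        previous_was_human_user = is_human_user
      return starts

    conversation = [
        types.Content(role="user", parts=[types.Part(text="What's the weather?")]),
        types.Content(role="model", parts=[types.Part(
            function_call=types.FunctionCall(name="get_weather", args={}))]),
        # Tool output: role="user" but NOT an invocation start.
        types.Content(role="user", parts=[types.Part(
            function_response=types.FunctionResponse(name="get_weather", response={}))]),
        types.Content(role="model", parts=[types.Part(text="Sunny.")]),
        types.Content(role="user", parts=[types.Part(text="Thanks!")]),
    ]
    print(get_invocation_start_indices(conversation))  # [0, 4]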
@@ -17,13 +17,11 @@ from __future__ import annotations
 from collections.abc import Sequence
 import logging
 from typing import Callable
-from typing import List
 from typing import Optional
 
 from google.genai import types
 
 from ..agents.callback_context import CallbackContext
-from ..events.event import Event
 from ..models.llm_request import LlmRequest
 from ..models.llm_response import LlmResponse
 from .base_plugin import BasePlugin
@@ -62,21 +60,61 @@ def _adjust_split_index_to_avoid_orphaned_function_responses(
   return 0
 
 
+def _is_function_response_content(content: types.Content) -> bool:
+  """Returns whether a content contains function responses."""
+  return bool(content.parts) and any(
+      part.function_response is not None for part in content.parts
+  )
+
+
+def _is_human_user_content(content: types.Content) -> bool:
+  """Returns whether a content represents user input (not tool output)."""
+  return content.role == "user" and not _is_function_response_content(content)
+
+
+def _get_invocation_start_indices(
+    contents: Sequence[types.Content],
+) -> list[int]:
+  """Returns indices that begin a user-started invocation.
+
+  An invocation begins with one or more consecutive user messages. Tool outputs
+  (function responses) are role="user" but are *not* considered invocation
+  starts.
+
+  Args:
+    contents: Full conversation contents in chronological order.
+
+  Returns:
+    A list of indices where each index marks the beginning of an invocation.
+  """
+  invocation_start_indices = []
+  previous_was_human_user = False
+  for i, content in enumerate(contents):
+    is_human_user = _is_human_user_content(content)
+    if is_human_user and not previous_was_human_user:
+      invocation_start_indices.append(i)
+    previous_was_human_user = is_human_user
+  return invocation_start_indices
+
+
 class ContextFilterPlugin(BasePlugin):
   """A plugin that filters the LLM context to reduce its size."""
 
   def __init__(
       self,
       num_invocations_to_keep: Optional[int] = None,
-      custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
+      custom_filter: Optional[
+          Callable[[list[types.Content]], list[types.Content]]
+      ] = None,
       name: str = "context_filter_plugin",
   ):
     """Initializes the context management plugin.
 
     Args:
       num_invocations_to_keep: The number of last invocations to keep. An
-        invocation is defined as one or more consecutive user messages followed
-        by a model response.
+        invocation starts with one or more consecutive user messages and can
+        contain multiple model turns (e.g. tool calls) until the next user
+        message starts a new invocation.
       custom_filter: A function to filter the context.
       name: The name of the plugin instance.
     """
@@ -89,27 +127,16 @@ class ContextFilterPlugin(BasePlugin):
   ) -> Optional[LlmResponse]:
     """Filters the LLM request's context before it is sent to the model."""
     try:
-      contents = llm_request.contents
+      contents: list[types.Content] = llm_request.contents
 
       if (
           self._num_invocations_to_keep is not None
           and self._num_invocations_to_keep > 0
       ):
-        num_model_turns = sum(1 for c in contents if c.role == "model")
-        if num_model_turns >= self._num_invocations_to_keep:
-          model_turns_to_find = self._num_invocations_to_keep
-          split_index = 0
-          for i in range(len(contents) - 1, -1, -1):
-            if contents[i].role == "model":
-              model_turns_to_find -= 1
-              if model_turns_to_find == 0:
-                start_index = i
-                while (
-                    start_index > 0 and contents[start_index - 1].role == "user"
-                ):
-                  start_index -= 1
-                split_index = start_index
-                break
+        invocation_start_indices = _get_invocation_start_indices(contents)
+        if len(invocation_start_indices) > self._num_invocations_to_keep:
+          split_index = invocation_start_indices[-self._num_invocations_to_keep]
+
           # Adjust split_index to avoid orphaned function_responses.
           split_index = (
               _adjust_split_index_to_avoid_orphaned_function_responses(
@@ -122,7 +149,7 @@ class ContextFilterPlugin(BasePlugin):
         contents = self._custom_filter(contents)
 
       llm_request.contents = contents
-    except Exception as e:
-      logger.error(f"Failed to reduce context for request: {e}")
+    except Exception:
+      logger.exception("Failed to reduce context for request")
 
     return None
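Illustration (not part of the commit): with the start indices in hand, keeping the last N invocations is a single slice, sketched below with hypothetical stand-in contents. The real plugin additionally walks split_index back so no function_response is separated from its function_call.

    # Stand-ins for types.Content entries, in chronological order.
    contents = ["user:hi", "model:hello", "user:do task",
                "model:tool_call", "user:tool_result", "model:done"]
    invocation_start_indices = [0, 2]  # "user:hi" and "user:do task"
    num_invocations_to_keep = 1

    if len(invocation_start_indices) > num_invocations_to_keep:
      # Keep everything from the start of the Nth-from-last invocation on.
      split_index = invocation_start_indices[-num_invocations_to_keep]
      contents = contents[split_index:]

    print(contents)
    # ['user:do task', 'model:tool_call', 'user:tool_result', 'model:done']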
@@ -14,7 +14,7 @@
 
 """Unit tests for the ContextFilteringPlugin."""
 
-from unittest.mock import Mock
+from unittest import mock
 
 from google.adk.agents.callback_context import CallbackContext
 from google.adk.models.llm_request import LlmRequest
@@ -40,7 +40,8 @@ async def test_filter_last_n_invocations():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 2
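Illustration (not part of the commit): the tests replace Mock(spec=CallbackContext) with mock.create_autospec(CallbackContext, instance=True). The difference, on a hypothetical class: spec= only restricts attribute names, while create_autospec also enforces method signatures.

    from unittest import mock

    class Service:
      def fetch(self, key: str) -> str:
        return key

    loose = mock.Mock(spec=Service)
    loose.fetch(1, 2, 3)  # accepted: spec= does not check call signatures

    strict = mock.create_autospec(Service, instance=True)
    strict.fetch("k")  # accepted: matches fetch(self, key)
    try:
      strict.fetch(1, 2, 3)  # rejected: the real signature is enforced
    except TypeError as e:
      print(e)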
@@ -65,7 +66,8 @@ async def test_filter_with_function():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 2
@@ -93,7 +95,8 @@ async def test_filter_with_function_and_last_n_invocations():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 0
@@ -111,7 +114,8 @@ async def test_no_filtering_when_no_options_provided():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -131,7 +135,8 @@ async def test_last_n_invocations_with_multiple_user_turns():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 3
@@ -157,7 +162,8 @@ async def test_last_n_invocations_more_than_existing_invocations():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -179,7 +185,8 @@ async def test_filter_function_raises_exception():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -237,7 +244,8 @@ async def test_filter_preserves_function_call_response_pairs():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   # Verify function_call for call_1 is included (not orphaned function_response)
@@ -276,7 +284,8 @@ async def test_filter_with_nested_function_calls():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   # Verify no orphaned function_responses
@@ -290,4 +299,47 @@ async def test_filter_with_nested_function_calls():
         if part.function_response and part.function_response.id:
           response_ids.add(part.function_response.id)
 
+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+
+  assert "Do task" in texts
+  assert "Done with tasks" in texts
+  assert "Hello" not in texts
+  assert "Hi!" not in texts
+
   assert response_ids.issubset(call_ids)
+
+
+@pytest.mark.asyncio
+async def test_last_invocation_with_tool_call_keeps_user_prompt():
+  """Tests that multi-model-turn invocations keep the initial user prompt."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
+
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_function_call_content("get_weather", "call_1"),
+      _create_function_response_content("get_weather", "call_1"),
+      _create_content("model", "final_answer_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
+  )
+
+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+
+  assert "user_prompt_2" in texts
+  assert "final_answer_2" in texts
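Illustration (not part of the commit): a usage sketch of the updated constructor. The import path is assumed from the adk-python package layout; the custom filter is hypothetical and, per the change above, runs after the invocation-based trim.

    from google.genai import types
    # Import path assumed from the adk-python package layout.
    from google.adk.plugins.context_filter_plugin import ContextFilterPlugin

    def drop_empty(contents: list[types.Content]) -> list[types.Content]:
      # Hypothetical custom filter: drop entries that carry no parts.
      return [c for c in contents if c.parts]

    plugin = ContextFilterPlugin(
        num_invocations_to_keep=2,  # keep the two most recent user-started invocations
        custom_filter=drop_empty,   # applied after the invocation trim
    )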