diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py
index d76bff43..48974923 100644
--- a/src/google/adk/plugins/context_filter_plugin.py
+++ b/src/google/adk/plugins/context_filter_plugin.py
@@ -17,13 +17,11 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 import logging
 from typing import Callable
-from typing import List
 from typing import Optional
 
 from google.genai import types
 
 from ..agents.callback_context import CallbackContext
-from ..events.event import Event
 from ..models.llm_request import LlmRequest
 from ..models.llm_response import LlmResponse
 from .base_plugin import BasePlugin
@@ -62,21 +60,61 @@ def _adjust_split_index_to_avoid_orphaned_function_responses(
   return 0
 
 
+def _is_function_response_content(content: types.Content) -> bool:
+  """Returns whether a content contains function responses."""
+  return bool(content.parts) and any(
+      part.function_response is not None for part in content.parts
+  )
+
+
+def _is_human_user_content(content: types.Content) -> bool:
+  """Returns whether a content represents user input (not tool output)."""
+  return content.role == "user" and not _is_function_response_content(content)
+
+
+def _get_invocation_start_indices(
+    contents: Sequence[types.Content],
+) -> list[int]:
+  """Returns indices that begin a user-started invocation.
+
+  An invocation begins with one or more consecutive user messages. Tool outputs
+  (function responses) are role="user" but are *not* considered invocation
+  starts.
+
+  Args:
+    contents: Full conversation contents in chronological order.
+
+  Returns:
+    A list of indices where each index marks the beginning of an invocation.
+  """
+  invocation_start_indices = []
+  previous_was_human_user = False
+  for i, content in enumerate(contents):
+    is_human_user = _is_human_user_content(content)
+    if is_human_user and not previous_was_human_user:
+      invocation_start_indices.append(i)
+    previous_was_human_user = is_human_user
+  return invocation_start_indices
+
+
 class ContextFilterPlugin(BasePlugin):
   """A plugin that filters the LLM context to reduce its size."""
 
   def __init__(
       self,
       num_invocations_to_keep: Optional[int] = None,
-      custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
+      custom_filter: Optional[
+          Callable[[list[types.Content]], list[types.Content]]
+      ] = None,
       name: str = "context_filter_plugin",
   ):
     """Initializes the context management plugin.
 
     Args:
       num_invocations_to_keep: The number of last invocations to keep. An
-        invocation is defined as one or more consecutive user messages followed
-        by a model response.
+        invocation starts with one or more consecutive user messages and can
+        contain multiple model turns (e.g. tool calls) until the next user
+        message starts a new invocation.
       custom_filter: A function to filter the context.
       name: The name of the plugin instance.
     """
@@ -89,27 +127,16 @@ class ContextFilterPlugin(BasePlugin):
   ) -> Optional[LlmResponse]:
     """Filters the LLM request's context before it is sent to the model."""
     try:
-      contents = llm_request.contents
+      contents: list[types.Content] = llm_request.contents
       if (
           self._num_invocations_to_keep is not None
           and self._num_invocations_to_keep > 0
       ):
-        num_model_turns = sum(1 for c in contents if c.role == "model")
-        if num_model_turns >= self._num_invocations_to_keep:
-          model_turns_to_find = self._num_invocations_to_keep
-          split_index = 0
-          for i in range(len(contents) - 1, -1, -1):
-            if contents[i].role == "model":
-              model_turns_to_find -= 1
-              if model_turns_to_find == 0:
-                start_index = i
-                while (
-                    start_index > 0 and contents[start_index - 1].role == "user"
-                ):
-                  start_index -= 1
-                split_index = start_index
-                break
+        invocation_start_indices = _get_invocation_start_indices(contents)
+        if len(invocation_start_indices) > self._num_invocations_to_keep:
+          split_index = invocation_start_indices[-self._num_invocations_to_keep]
+
           # Adjust split_index to avoid orphaned function_responses.
           split_index = (
               _adjust_split_index_to_avoid_orphaned_function_responses(
@@ -122,7 +149,7 @@ class ContextFilterPlugin(BasePlugin):
         contents = self._custom_filter(contents)
 
       llm_request.contents = contents
-    except Exception as e:
-      logger.error(f"Failed to reduce context for request: {e}")
+    except Exception:
+      logger.exception("Failed to reduce context for request")
 
     return None
diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py
index e821b7e7..b3393245 100644
--- a/tests/unittests/plugins/test_context_filtering_plugin.py
+++ b/tests/unittests/plugins/test_context_filtering_plugin.py
@@ -14,7 +14,7 @@
 
 """Unit tests for the ContextFilteringPlugin."""
 
-from unittest.mock import Mock
+from unittest import mock
 
 from google.adk.agents.callback_context import CallbackContext
 from google.adk.models.llm_request import LlmRequest
@@ -40,7 +40,8 @@ async def test_filter_last_n_invocations():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 2
@@ -65,7 +66,8 @@ async def test_filter_with_function():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 2
@@ -93,7 +95,8 @@ async def test_filter_with_function_and_last_n_invocations():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 0
@@ -111,7 +114,8 @@ async def test_no_filtering_when_no_options_provided():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -131,7 +135,8 @@ async def test_last_n_invocations_with_multiple_user_turns():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert len(llm_request.contents) == 3
@@ -157,7 +162,8 @@ async def test_last_n_invocations_more_than_existing_invocations():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -179,7 +185,8 @@ async def test_filter_function_raises_exception():
   original_contents = list(llm_request.contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   assert llm_request.contents == original_contents
@@ -237,7 +244,8 @@ async def test_filter_preserves_function_call_response_pairs():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   # Verify function_call for call_1 is included (not orphaned function_response)
@@ -276,7 +284,8 @@ async def test_filter_with_nested_function_calls():
   llm_request = LlmRequest(contents=contents)
 
   await plugin.before_model_callback(
-      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
   )
 
   # Verify no orphaned function_responses
@@ -290,4 +299,47 @@ async def test_filter_with_nested_function_calls():
       if part.function_response and part.function_response.id:
         response_ids.add(part.function_response.id)
 
+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+
+  assert "Do task" in texts
+  assert "Done with tasks" in texts
+  assert "Hello" not in texts
+  assert "Hi!" not in texts
   assert response_ids.issubset(call_ids)
+
+
+@pytest.mark.asyncio
+async def test_last_invocation_with_tool_call_keeps_user_prompt():
+  """Tests that multi-model-turn invocations keep the initial user prompt."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=1)
+
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_function_call_content("get_weather", "call_1"),
+      _create_function_response_content("get_weather", "call_1"),
+      _create_content("model", "final_answer_2"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=mock.create_autospec(CallbackContext, instance=True),
+      llm_request=llm_request,
+  )
+
+  texts = []
+  for content in llm_request.contents:
+    if content.parts:
+      for part in content.parts:
+        if part.text:
+          texts.append(part.text)
+
+  assert "user_prompt_2" in texts
+  assert "final_answer_2" in texts
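The `custom_filter` signature change means user-supplied filters now receive `list[types.Content]` rather than lists of ADK `Event`s. A minimal sketch of a filter under the new contract; the drop-binary-only policy and the helper name are hypothetical, chosen only to illustrate the shape (take a content list, return the filtered list):

```python
# Sketch of a custom_filter under the new list[types.Content] signature.
# The drop-binary-only policy is hypothetical, for illustration only.
from google.genai import types

from google.adk.plugins.context_filter_plugin import ContextFilterPlugin


def drop_binary_only_contents(
    contents: list[types.Content],
) -> list[types.Content]:
  kept = []
  for content in contents:
    parts = content.parts or []
    if parts and all(part.inline_data is not None for part in parts):
      continue  # drop contents whose every part is raw inline bytes
    kept.append(content)
  return kept


plugin = ContextFilterPlugin(
    num_invocations_to_keep=3,
    custom_filter=drop_binary_only_contents,
)
```

Note that when both options are set, the plugin applies `num_invocations_to_keep` first and then passes the truncated list to `custom_filter`, as `test_filter_with_function_and_last_n_invocations` above relies on.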