diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py index b778de02..8b12f92f 100644 --- a/src/google/adk/plugins/context_filter_plugin.py +++ b/src/google/adk/plugins/context_filter_plugin.py @@ -14,11 +14,14 @@ from __future__ import annotations +from collections.abc import Sequence import logging from typing import Callable from typing import List from typing import Optional +from google.genai import types + from ..agents.callback_context import CallbackContext from ..events.event import Event from ..models.llm_request import LlmRequest @@ -28,6 +31,37 @@ from .base_plugin import BasePlugin logger = logging.getLogger("google_adk." + __name__) +def _adjust_split_index_to_avoid_orphaned_function_responses( + contents: Sequence[types.Content], split_index: int +) -> int: + """Moves `split_index` left until function calls/responses stay paired. + + When truncating context, we must avoid keeping a `function_response` while + dropping its matching preceding `function_call`. + + Args: + contents: Full conversation contents in chronological order. + split_index: Candidate split index (keep `contents[split_index:]`). + + Returns: + A (possibly smaller) split index that preserves call/response pairs. + """ + needed_call_ids = set() + for i in range(len(contents) - 1, -1, -1): + parts = contents[i].parts + if parts: + for part in reversed(parts): + if part.function_response and part.function_response.id: + needed_call_ids.add(part.function_response.id) + if part.function_call and part.function_call.id: + needed_call_ids.discard(part.function_call.id) + + if i <= split_index and not needed_call_ids: + return i + + return 0 + + class ContextFilterPlugin(BasePlugin): """A plugin that filters the LLM context to reduce its size.""" @@ -76,6 +110,12 @@ class ContextFilterPlugin(BasePlugin): start_index -= 1 split_index = start_index break + # Adjust split_index to avoid orphaned function_responses. + split_index = ( + _adjust_split_index_to_avoid_orphaned_function_responses( + contents, split_index + ) + ) contents = contents[split_index:] if self._custom_filter: diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py index f9c8222e..de72b32b 100644 --- a/tests/unittests/plugins/test_context_filtering_plugin.py +++ b/tests/unittests/plugins/test_context_filtering_plugin.py @@ -183,3 +183,111 @@ async def test_filter_function_raises_exception(): ) assert llm_request.contents == original_contents + + +def _create_function_call_content(name: str, call_id: str) -> types.Content: + """Creates a model content with a function call.""" + return types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(id=call_id, name=name, args={}) + ) + ], + role="model", + ) + + +def _create_function_response_content(name: str, call_id: str) -> types.Content: + """Creates a user content with a function response.""" + return types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + id=call_id, name=name, response={"result": "ok"} + ) + ) + ], + role="user", + ) + + +@pytest.mark.asyncio +async def test_filter_preserves_function_call_response_pairs(): + """Tests that function_call and function_response pairs are kept together. + + This tests the fix for issue #4027 where filtering could create orphaned + function_response messages without their corresponding function_call. + """ + plugin = ContextFilterPlugin(num_invocations_to_keep=2) + + # Simulate conversation from issue #4027: + # user -> model -> user -> model(function_call) -> user(function_response) + # -> model -> user -> model(function_call) -> user(function_response) + contents = [ + _create_content("user", "Hello"), + _create_content("model", "Hi there!"), + _create_content("user", "I want to know about X"), + _create_function_call_content("knowledge_base", "call_1"), + _create_function_response_content("knowledge_base", "call_1"), + _create_content("model", "I found some information..."), + _create_content("user", "can you explain more about Y"), + _create_function_call_content("knowledge_base", "call_2"), + _create_function_response_content("knowledge_base", "call_2"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Verify function_call for call_1 is included (not orphaned function_response) + call_ids_present = set() + response_ids_present = set() + for content in llm_request.contents: + if content.parts: + for part in content.parts: + if part.function_call and part.function_call.id: + call_ids_present.add(part.function_call.id) + if part.function_response and part.function_response.id: + response_ids_present.add(part.function_response.id) + + # Every function_response should have a matching function_call + assert response_ids_present.issubset(call_ids_present), ( + "Orphaned function_responses found. " + f"Responses: {response_ids_present}, Calls: {call_ids_present}" + ) + + +@pytest.mark.asyncio +async def test_filter_with_nested_function_calls(): + """Tests filtering with multiple nested function call sequences.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1) + + contents = [ + _create_content("user", "Hello"), + _create_content("model", "Hi!"), + _create_content("user", "Do task"), + _create_function_call_content("tool_a", "call_a"), + _create_function_response_content("tool_a", "call_a"), + _create_function_call_content("tool_b", "call_b"), + _create_function_response_content("tool_b", "call_b"), + _create_content("model", "Done with tasks"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Verify no orphaned function_responses + call_ids = set() + response_ids = set() + for content in llm_request.contents: + if content.parts: + for part in content.parts: + if part.function_call and part.function_call.id: + call_ids.add(part.function_call.id) + if part.function_response and part.function_response.id: + response_ids.add(part.function_response.id) + + assert response_ids.issubset(call_ids)