fix: Prevent ContextFilterPlugin from creating orphaned function responses

When truncating conversation history, make sure function_response messages always have their corresponding function_call included

Close #4027

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 852377919
This commit is contained in:
George Weale
2026-01-05 11:09:02 -08:00
committed by Copybara-Service
parent 688f48fffb
commit e32f017979
2 changed files with 148 additions and 0 deletions
@@ -14,11 +14,14 @@
from __future__ import annotations
from collections.abc import Sequence
import logging
from typing import Callable
from typing import List
from typing import Optional
from google.genai import types
from ..agents.callback_context import CallbackContext
from ..events.event import Event
from ..models.llm_request import LlmRequest
@@ -28,6 +31,37 @@ from .base_plugin import BasePlugin
logger = logging.getLogger("google_adk." + __name__)
def _adjust_split_index_to_avoid_orphaned_function_responses(
contents: Sequence[types.Content], split_index: int
) -> int:
"""Moves `split_index` left until function calls/responses stay paired.
When truncating context, we must avoid keeping a `function_response` while
dropping its matching preceding `function_call`.
Args:
contents: Full conversation contents in chronological order.
split_index: Candidate split index (keep `contents[split_index:]`).
Returns:
A (possibly smaller) split index that preserves call/response pairs.
"""
needed_call_ids = set()
for i in range(len(contents) - 1, -1, -1):
parts = contents[i].parts
if parts:
for part in reversed(parts):
if part.function_response and part.function_response.id:
needed_call_ids.add(part.function_response.id)
if part.function_call and part.function_call.id:
needed_call_ids.discard(part.function_call.id)
if i <= split_index and not needed_call_ids:
return i
return 0
class ContextFilterPlugin(BasePlugin):
"""A plugin that filters the LLM context to reduce its size."""
@@ -76,6 +110,12 @@ class ContextFilterPlugin(BasePlugin):
start_index -= 1
split_index = start_index
break
# Adjust split_index to avoid orphaned function_responses.
split_index = (
_adjust_split_index_to_avoid_orphaned_function_responses(
contents, split_index
)
)
contents = contents[split_index:]
if self._custom_filter:
@@ -183,3 +183,111 @@ async def test_filter_function_raises_exception():
)
assert llm_request.contents == original_contents
def _create_function_call_content(name: str, call_id: str) -> types.Content:
"""Creates a model content with a function call."""
return types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(id=call_id, name=name, args={})
)
],
role="model",
)
def _create_function_response_content(name: str, call_id: str) -> types.Content:
"""Creates a user content with a function response."""
return types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
id=call_id, name=name, response={"result": "ok"}
)
)
],
role="user",
)
@pytest.mark.asyncio
async def test_filter_preserves_function_call_response_pairs():
"""Tests that function_call and function_response pairs are kept together.
This tests the fix for issue #4027 where filtering could create orphaned
function_response messages without their corresponding function_call.
"""
plugin = ContextFilterPlugin(num_invocations_to_keep=2)
# Simulate conversation from issue #4027:
# user -> model -> user -> model(function_call) -> user(function_response)
# -> model -> user -> model(function_call) -> user(function_response)
contents = [
_create_content("user", "Hello"),
_create_content("model", "Hi there!"),
_create_content("user", "I want to know about X"),
_create_function_call_content("knowledge_base", "call_1"),
_create_function_response_content("knowledge_base", "call_1"),
_create_content("model", "I found some information..."),
_create_content("user", "can you explain more about Y"),
_create_function_call_content("knowledge_base", "call_2"),
_create_function_response_content("knowledge_base", "call_2"),
]
llm_request = LlmRequest(contents=contents)
await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)
# Verify function_call for call_1 is included (not orphaned function_response)
call_ids_present = set()
response_ids_present = set()
for content in llm_request.contents:
if content.parts:
for part in content.parts:
if part.function_call and part.function_call.id:
call_ids_present.add(part.function_call.id)
if part.function_response and part.function_response.id:
response_ids_present.add(part.function_response.id)
# Every function_response should have a matching function_call
assert response_ids_present.issubset(call_ids_present), (
"Orphaned function_responses found. "
f"Responses: {response_ids_present}, Calls: {call_ids_present}"
)
@pytest.mark.asyncio
async def test_filter_with_nested_function_calls():
"""Tests filtering with multiple nested function call sequences."""
plugin = ContextFilterPlugin(num_invocations_to_keep=1)
contents = [
_create_content("user", "Hello"),
_create_content("model", "Hi!"),
_create_content("user", "Do task"),
_create_function_call_content("tool_a", "call_a"),
_create_function_response_content("tool_a", "call_a"),
_create_function_call_content("tool_b", "call_b"),
_create_function_response_content("tool_b", "call_b"),
_create_content("model", "Done with tasks"),
]
llm_request = LlmRequest(contents=contents)
await plugin.before_model_callback(
callback_context=Mock(spec=CallbackContext), llm_request=llm_request
)
# Verify no orphaned function_responses
call_ids = set()
response_ids = set()
for content in llm_request.contents:
if content.parts:
for part in content.parts:
if part.function_call and part.function_call.id:
call_ids.add(part.function_call.id)
if part.function_response and part.function_response.id:
response_ids.add(part.function_response.id)
assert response_ids.issubset(call_ids)