You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Prevent ContextFilterPlugin from creating orphaned function responses
When truncating conversation history, make sure function_response messages always have their corresponding function_call included. Close #4027. Co-authored-by: George Weale <gweale@google.com>. PiperOrigin-RevId: 852377919
This commit is contained in:
committed by
Copybara-Service
parent
688f48fffb
commit
e32f017979
@@ -14,11 +14,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..events.event import Event
|
||||
from ..models.llm_request import LlmRequest
|
||||
@@ -28,6 +31,37 @@ from .base_plugin import BasePlugin
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
def _adjust_split_index_to_avoid_orphaned_function_responses(
    contents: Sequence[types.Content], split_index: int
) -> int:
  """Moves `split_index` left until function calls/responses stay paired.

  When truncating context, we must avoid keeping a `function_response` while
  dropping its matching preceding `function_call`.

  Only responses whose `function_call` exists somewhere in `contents` are
  tracked. A response whose call is missing from `contents` altogether cannot
  be re-paired by moving the split, so it is ignored rather than forcing the
  split index all the way back to 0 (which would disable truncation).

  Args:
    contents: Full conversation contents in chronological order.
    split_index: Candidate split index (keep `contents[split_index:]`).

  Returns:
    A (possibly smaller) split index that preserves call/response pairs.
    Returns 0 when no index at or before `split_index` satisfies every
    tracked response (this includes an empty `contents`).
  """
  # IDs of every function_call present anywhere in the conversation; a
  # response can only be "rescued" if its id is in this set.
  available_call_ids = {
      part.function_call.id
      for content in contents
      for part in content.parts or ()
      if part.function_call and part.function_call.id
  }

  # IDs of responses seen so far (scanning right-to-left) whose call has not
  # been encountered yet.
  needed_call_ids = set()
  for i in range(len(contents) - 1, -1, -1):
    for part in reversed(contents[i].parts or ()):
      response = part.function_response
      if response and response.id and response.id in available_call_ids:
        needed_call_ids.add(response.id)
      call = part.function_call
      if call and call.id:
        needed_call_ids.discard(call.id)

    # The first index at or left of the candidate where nothing is pending is
    # the largest split that keeps every tracked pair together.
    if i <= split_index and not needed_call_ids:
      return i

  return 0
|
||||
|
||||
|
||||
class ContextFilterPlugin(BasePlugin):
|
||||
"""A plugin that filters the LLM context to reduce its size."""
|
||||
|
||||
@@ -76,6 +110,12 @@ class ContextFilterPlugin(BasePlugin):
|
||||
start_index -= 1
|
||||
split_index = start_index
|
||||
break
|
||||
# Adjust split_index to avoid orphaned function_responses.
|
||||
split_index = (
|
||||
_adjust_split_index_to_avoid_orphaned_function_responses(
|
||||
contents, split_index
|
||||
)
|
||||
)
|
||||
contents = contents[split_index:]
|
||||
|
||||
if self._custom_filter:
|
||||
|
||||
@@ -183,3 +183,111 @@ async def test_filter_function_raises_exception():
|
||||
)
|
||||
|
||||
assert llm_request.contents == original_contents
|
||||
|
||||
|
||||
def _create_function_call_content(name: str, call_id: str) -> types.Content:
  """Builds a `model`-role content whose only part is a function call.

  Args:
    name: Name of the function being called.
    call_id: Identifier assigned to the function call.

  Returns:
    A content with role "model" and a single function-call part that has
    empty args.
  """
  function_call = types.FunctionCall(id=call_id, name=name, args={})
  call_part = types.Part(function_call=function_call)
  return types.Content(parts=[call_part], role="model")
|
||||
|
||||
|
||||
def _create_function_response_content(name: str, call_id: str) -> types.Content:
  """Builds a `user`-role content whose only part is a function response.

  Args:
    name: Name of the function that produced the response.
    call_id: Identifier of the function call this response answers.

  Returns:
    A content with role "user" and a single function-response part carrying
    the fixed payload `{"result": "ok"}`.
  """
  function_response = types.FunctionResponse(
      id=call_id, name=name, response={"result": "ok"}
  )
  response_part = types.Part(function_response=function_response)
  return types.Content(parts=[response_part], role="user")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_filter_preserves_function_call_response_pairs():
  """Checks that filtering never leaves a function_response without its call.

  Regression test for issue #4027, where filtering could create orphaned
  function_response messages without their corresponding function_call.
  """
  plugin = ContextFilterPlugin(num_invocations_to_keep=2)

  # Conversation shaped like the one in issue #4027:
  # user -> model -> user -> model(function_call) -> user(function_response)
  # -> model -> user -> model(function_call) -> user(function_response)
  llm_request = LlmRequest(
      contents=[
          _create_content("user", "Hello"),
          _create_content("model", "Hi there!"),
          _create_content("user", "I want to know about X"),
          _create_function_call_content("knowledge_base", "call_1"),
          _create_function_response_content("knowledge_base", "call_1"),
          _create_content("model", "I found some information..."),
          _create_content("user", "can you explain more about Y"),
          _create_function_call_content("knowledge_base", "call_2"),
          _create_function_response_content("knowledge_base", "call_2"),
      ]
  )

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # Flatten the parts of whatever contents remain on the request.
  remaining_parts = [
      part
      for content in llm_request.contents
      if content.parts
      for part in content.parts
  ]
  call_ids_present = {
      part.function_call.id
      for part in remaining_parts
      if part.function_call and part.function_call.id
  }
  response_ids_present = {
      part.function_response.id
      for part in remaining_parts
      if part.function_response and part.function_response.id
  }

  # Every remaining function_response must have its function_call present.
  assert response_ids_present <= call_ids_present, (
      "Orphaned function_responses found. "
      f"Responses: {response_ids_present}, Calls: {call_ids_present}"
  )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_filter_with_nested_function_calls():
  """Checks call/response pairing with back-to-back tool calls in one turn."""
  plugin = ContextFilterPlugin(num_invocations_to_keep=1)

  # Two consecutive call/response pairs followed by a final model message.
  llm_request = LlmRequest(
      contents=[
          _create_content("user", "Hello"),
          _create_content("model", "Hi!"),
          _create_content("user", "Do task"),
          _create_function_call_content("tool_a", "call_a"),
          _create_function_response_content("tool_a", "call_a"),
          _create_function_call_content("tool_b", "call_b"),
          _create_function_response_content("tool_b", "call_b"),
          _create_content("model", "Done with tasks"),
      ]
  )

  await plugin.before_model_callback(
      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
  )

  # Flatten the parts of whatever contents remain on the request.
  remaining_parts = [
      part
      for content in llm_request.contents
      if content.parts
      for part in content.parts
  ]
  call_ids = {
      part.function_call.id
      for part in remaining_parts
      if part.function_call and part.function_call.id
  }
  response_ids = {
      part.function_response.id
      for part in remaining_parts
      if part.function_response and part.function_response.id
  }

  # No function_response may remain without its function_call.
  assert response_ids <= call_ids
|
||||
|
||||
Reference in New Issue
Block a user