fix: Fix the long running function response event merge logic

1) raise explicit error if the response event contains responses against multiple function call events
2) merge all function responses for the corresponding function call event

PiperOrigin-RevId: 782154577
This commit is contained in:
Xiang (Sean) Zhou
2025-07-11 16:47:19 -07:00
committed by Copybara-Service
parent a8fcc1b8ab
commit 134ec0d71e
3 changed files with 194 additions and 12 deletions
+16 -10
View File
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
for function_call in function_calls:
if function_call.id in function_responses_ids:
function_call_event_idx = idx
break
if function_call_event_idx != -1:
# in case the last response event only have part of the responses
# for the function calls in the function call event
for function_call in function_calls:
function_responses_ids.add(function_call.id)
function_call_ids = {
function_call.id for function_call in function_calls
}
# last response event should only contain the responses for the
# function calls in the same function call event
if not function_responses_ids.issubset(function_call_ids):
raise ValueError(
'Last response event should only contain the responses for the'
' function calls in the same function call event. Function'
f' call ids found : {function_call_ids}, function response'
f' ids provided: {function_responses_ids}'
)
# collect all function responses from the function call event to
# the last response event
function_responses_ids = function_call_ids
break
if function_call_event_idx == -1:
@@ -363,10 +372,7 @@ def _merge_function_response_events(
list is in increasing order of timestamp; 2. the first event is the
initial function_response event; 3. all later events should contain at
least one function_response part that related to the function_call
event. (Note, 3. may not be true when aync function return some
intermediate response, there could also be some intermediate model
response event without any function_response and such event will be
ignored.)
event.
Caveat: This implementation doesn't support when a parallel function_call
event contains async function_call of the same name.
@@ -359,3 +359,179 @@ def test_rearrange_events_for_latest_function_response():
# Should remove intermediate events and merge responses
assert len(rearranged) == 2
assert rearranged[0] == call_event
assert rearranged[1] == response_event
def test_rearrange_events_for_latest_function_response_multiple_calls():
"""Test _rearrange_events_for_latest_function_response with multiple function calls."""
# Create function call event with multiple calls
function_call1 = types.FunctionCall(
id="func_123", name="test_function", args={"param": "value1"}
)
function_call2 = types.FunctionCall(
id="func_456", name="test_function2", args={"param": "value2"}
)
call_event = Event(
invocation_id="test_inv1",
author="agent",
content=types.Content(
role="model",
parts=[
types.Part(function_call=function_call1),
types.Part(function_call=function_call2),
],
),
)
# Create intermediate event
intermediate_event = Event(
invocation_id="test_inv2",
author="agent",
content=types.Content(
role="model", parts=[types.Part.from_text(text="Processing...")]
),
)
# Create function response event with only one response
function_response = types.FunctionResponse(
id="func_123", name="test_function", response={"result": "success"}
)
response_event = Event(
invocation_id="test_inv3",
author="user",
content=types.Content(
role="user", parts=[types.Part(function_response=function_response)]
),
)
# Test with matching function call and response
events = [call_event, intermediate_event, response_event]
rearranged = _rearrange_events_for_latest_function_response(events)
# Should remove intermediate events and merge responses
assert len(rearranged) == 2
assert rearranged[0] == call_event
assert rearranged[1] == response_event
def test_rearrange_events_for_latest_function_response_validation_error():
"""Test _rearrange_events_for_latest_function_response with validation error."""
# Create function call event with one function call
function_call = types.FunctionCall(
id="func_123", name="test_function", args={"param": "value"}
)
call_event = Event(
invocation_id="test_inv1",
author="agent",
content=types.Content(
role="model", parts=[types.Part(function_call=function_call)]
),
)
# Create intermediate event
intermediate_event = Event(
invocation_id="test_inv2",
author="agent",
content=types.Content(
role="model", parts=[types.Part.from_text(text="Processing...")]
),
)
# Create function response event with the matching function call AND an extra one
function_response1 = types.FunctionResponse(
id="func_123", name="test_function", response={"result": "success"}
)
function_response2 = types.FunctionResponse(
id="func_456", name="other_function", response={"result": "other"}
)
response_event = Event(
invocation_id="test_inv3",
author="user",
content=types.Content(
role="user",
parts=[
types.Part(function_response=function_response1),
types.Part(function_response=function_response2),
],
),
)
# Test with mismatched function call and response
events = [call_event, intermediate_event, response_event]
with pytest.raises(
ValueError,
match=(
"Last response event should only contain the responses for the"
" function calls in the same function call event"
),
):
_rearrange_events_for_latest_function_response(events)
def test_rearrange_events_for_latest_function_response_mixed_responses():
"""Test _rearrange_events_for_latest_function_response with mixed function responses."""
# Create function call event with two calls
function_call1 = types.FunctionCall(
id="func_123", name="test_function", args={"param": "value1"}
)
function_call2 = types.FunctionCall(
id="func_456", name="test_function2", args={"param": "value2"}
)
call_event = Event(
invocation_id="test_inv1",
author="agent",
content=types.Content(
role="model",
parts=[
types.Part(function_call=function_call1),
types.Part(function_call=function_call2),
],
),
)
# Create intermediate event
intermediate_event = Event(
invocation_id="test_inv2",
author="agent",
content=types.Content(
role="model", parts=[types.Part.from_text(text="Processing...")]
),
)
# Create function response event with one matching and one non-matching response
function_response1 = types.FunctionResponse(
id="func_123", name="test_function", response={"result": "success"}
)
function_response2 = types.FunctionResponse(
id="func_789", name="test_function3", response={"result": "other"}
)
response_event = Event(
invocation_id="test_inv3",
author="user",
content=types.Content(
role="user",
parts=[
types.Part(function_response=function_response1),
types.Part(function_response=function_response2),
],
),
)
# Test with mixed function responses
events = [call_event, intermediate_event, response_event]
with pytest.raises(
ValueError,
match=(
"Last response event should only contain the responses for the"
" function calls in the same function call event"
),
):
_rearrange_events_for_latest_function_response(events)
@@ -549,13 +549,13 @@ def test_function_get_auth_response_partial():
],
),
)
# assert function_invoked == 4
assert function_invoked == 4
assert len(mock_model.requests) == 4
request = mock_model.requests[-1]
content = request.contents[-1]
parts = content.parts
assert len(parts) == 2
assert parts[0].function_response.name == 'call_external_api1'
assert parts[0].function_response.response == {'result': None}
assert parts[0].function_response.response == {'result': 1}
assert parts[1].function_response.name == 'call_external_api2'
assert parts[1].function_response.response == {'result': 2}