You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Fix the long running function response event merge logic
1) raise explicit error if the response event contains responses against multiple function call events 2) merge all function responses for the corresponding function call event PiperOrigin-RevId: 782154577
This commit is contained in:
committed by
Copybara-Service
parent
a8fcc1b8ab
commit
134ec0d71e
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
|
||||
for function_call in function_calls:
|
||||
if function_call.id in function_responses_ids:
|
||||
function_call_event_idx = idx
|
||||
break
|
||||
if function_call_event_idx != -1:
|
||||
# in case the last response event only have part of the responses
|
||||
# for the function calls in the function call event
|
||||
for function_call in function_calls:
|
||||
function_responses_ids.add(function_call.id)
|
||||
function_call_ids = {
|
||||
function_call.id for function_call in function_calls
|
||||
}
|
||||
# last response event should only contain the responses for the
|
||||
# function calls in the same function call event
|
||||
if not function_responses_ids.issubset(function_call_ids):
|
||||
raise ValueError(
|
||||
'Last response event should only contain the responses for the'
|
||||
' function calls in the same function call event. Function'
|
||||
f' call ids found : {function_call_ids}, function response'
|
||||
f' ids provided: {function_responses_ids}'
|
||||
)
|
||||
# collect all function responses from the function call event to
|
||||
# the last response event
|
||||
function_responses_ids = function_call_ids
|
||||
break
|
||||
|
||||
if function_call_event_idx == -1:
|
||||
@@ -363,10 +372,7 @@ def _merge_function_response_events(
|
||||
list is in increasing order of timestamp; 2. the first event is the
|
||||
initial function_response event; 3. all later events should contain at
|
||||
least one function_response part that related to the function_call
|
||||
event. (Note, 3. may not be true when aync function return some
|
||||
intermediate response, there could also be some intermediate model
|
||||
response event without any function_response and such event will be
|
||||
ignored.)
|
||||
event.
|
||||
Caveat: This implementation doesn't support when a parallel function_call
|
||||
event contains async function_call of the same name.
|
||||
|
||||
|
||||
@@ -359,3 +359,179 @@ def test_rearrange_events_for_latest_function_response():
|
||||
# Should remove intermediate events and merge responses
|
||||
assert len(rearranged) == 2
|
||||
assert rearranged[0] == call_event
|
||||
assert rearranged[1] == response_event
|
||||
|
||||
|
||||
def test_rearrange_events_for_latest_function_response_multiple_calls():
|
||||
"""Test _rearrange_events_for_latest_function_response with multiple function calls."""
|
||||
# Create function call event with multiple calls
|
||||
function_call1 = types.FunctionCall(
|
||||
id="func_123", name="test_function", args={"param": "value1"}
|
||||
)
|
||||
function_call2 = types.FunctionCall(
|
||||
id="func_456", name="test_function2", args={"param": "value2"}
|
||||
)
|
||||
|
||||
call_event = Event(
|
||||
invocation_id="test_inv1",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(function_call=function_call1),
|
||||
types.Part(function_call=function_call2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Create intermediate event
|
||||
intermediate_event = Event(
|
||||
invocation_id="test_inv2",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part.from_text(text="Processing...")]
|
||||
),
|
||||
)
|
||||
|
||||
# Create function response event with only one response
|
||||
function_response = types.FunctionResponse(
|
||||
id="func_123", name="test_function", response={"result": "success"}
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id="test_inv3",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user", parts=[types.Part(function_response=function_response)]
|
||||
),
|
||||
)
|
||||
|
||||
# Test with matching function call and response
|
||||
events = [call_event, intermediate_event, response_event]
|
||||
rearranged = _rearrange_events_for_latest_function_response(events)
|
||||
|
||||
# Should remove intermediate events and merge responses
|
||||
assert len(rearranged) == 2
|
||||
assert rearranged[0] == call_event
|
||||
assert rearranged[1] == response_event
|
||||
|
||||
|
||||
def test_rearrange_events_for_latest_function_response_validation_error():
|
||||
"""Test _rearrange_events_for_latest_function_response with validation error."""
|
||||
# Create function call event with one function call
|
||||
function_call = types.FunctionCall(
|
||||
id="func_123", name="test_function", args={"param": "value"}
|
||||
)
|
||||
|
||||
call_event = Event(
|
||||
invocation_id="test_inv1",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=function_call)]
|
||||
),
|
||||
)
|
||||
|
||||
# Create intermediate event
|
||||
intermediate_event = Event(
|
||||
invocation_id="test_inv2",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part.from_text(text="Processing...")]
|
||||
),
|
||||
)
|
||||
|
||||
# Create function response event with the matching function call AND an extra one
|
||||
function_response1 = types.FunctionResponse(
|
||||
id="func_123", name="test_function", response={"result": "success"}
|
||||
)
|
||||
function_response2 = types.FunctionResponse(
|
||||
id="func_456", name="other_function", response={"result": "other"}
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id="test_inv3",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Test with mismatched function call and response
|
||||
events = [call_event, intermediate_event, response_event]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Last response event should only contain the responses for the"
|
||||
" function calls in the same function call event"
|
||||
),
|
||||
):
|
||||
_rearrange_events_for_latest_function_response(events)
|
||||
|
||||
|
||||
def test_rearrange_events_for_latest_function_response_mixed_responses():
|
||||
"""Test _rearrange_events_for_latest_function_response with mixed function responses."""
|
||||
# Create function call event with two calls
|
||||
function_call1 = types.FunctionCall(
|
||||
id="func_123", name="test_function", args={"param": "value1"}
|
||||
)
|
||||
function_call2 = types.FunctionCall(
|
||||
id="func_456", name="test_function2", args={"param": "value2"}
|
||||
)
|
||||
|
||||
call_event = Event(
|
||||
invocation_id="test_inv1",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(function_call=function_call1),
|
||||
types.Part(function_call=function_call2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Create intermediate event
|
||||
intermediate_event = Event(
|
||||
invocation_id="test_inv2",
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part.from_text(text="Processing...")]
|
||||
),
|
||||
)
|
||||
|
||||
# Create function response event with one matching and one non-matching response
|
||||
function_response1 = types.FunctionResponse(
|
||||
id="func_123", name="test_function", response={"result": "success"}
|
||||
)
|
||||
function_response2 = types.FunctionResponse(
|
||||
id="func_789", name="test_function3", response={"result": "other"}
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id="test_inv3",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Test with mixed function responses
|
||||
events = [call_event, intermediate_event, response_event]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Last response event should only contain the responses for the"
|
||||
" function calls in the same function call event"
|
||||
),
|
||||
):
|
||||
_rearrange_events_for_latest_function_response(events)
|
||||
|
||||
@@ -549,13 +549,13 @@ def test_function_get_auth_response_partial():
|
||||
],
|
||||
),
|
||||
)
|
||||
# assert function_invoked == 4
|
||||
assert function_invoked == 4
|
||||
assert len(mock_model.requests) == 4
|
||||
request = mock_model.requests[-1]
|
||||
content = request.contents[-1]
|
||||
parts = content.parts
|
||||
assert len(parts) == 2
|
||||
assert parts[0].function_response.name == 'call_external_api1'
|
||||
assert parts[0].function_response.response == {'result': None}
|
||||
assert parts[0].function_response.response == {'result': 1}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': 2}
|
||||
|
||||
Reference in New Issue
Block a user