diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index 039eaf8c..e5f58490 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response( for function_call in function_calls: if function_call.id in function_responses_ids: function_call_event_idx = idx - break - if function_call_event_idx != -1: - # in case the last response event only have part of the responses - # for the function calls in the function call event - for function_call in function_calls: - function_responses_ids.add(function_call.id) + function_call_ids = { + function_call.id for function_call in function_calls + } + # last response event should only contain the responses for the + # function calls in the same function call event + if not function_responses_ids.issubset(function_call_ids): + raise ValueError( + 'Last response event should only contain the responses for the' + ' function calls in the same function call event. Function' + f' call ids found : {function_call_ids}, function response' + f' ids provided: {function_responses_ids}' + ) + # collect all function responses from the function call event to + # the last response event + function_responses_ids = function_call_ids break if function_call_event_idx == -1: @@ -363,10 +372,7 @@ def _merge_function_response_events( list is in increasing order of timestamp; 2. the first event is the initial function_response event; 3. all later events should contain at least one function_response part that related to the function_call - event. (Note, 3. may not be true when aync function return some - intermediate response, there could also be some intermediate model - response event without any function_response and such event will be - ignored.) + event. Caveat: This implementation doesn't support when a parallel function_call event contains async function_call of the same name. diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index a330852a..995b3868 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -359,3 +359,179 @@ def test_rearrange_events_for_latest_function_response(): # Should remove intermediate events and merge responses assert len(rearranged) == 2 assert rearranged[0] == call_event + assert rearranged[1] == response_event + + +def test_rearrange_events_for_latest_function_response_multiple_calls(): + """Test _rearrange_events_for_latest_function_response with multiple function calls.""" + # Create function call event with multiple calls + function_call1 = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value1"} + ) + function_call2 = types.FunctionCall( + id="func_456", name="test_function2", args={"param": "value2"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part(function_call=function_call1), + types.Part(function_call=function_call2), + ], + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event with only one response + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test with matching function call and response + events = [call_event, intermediate_event, response_event] + rearranged = _rearrange_events_for_latest_function_response(events) + + # Should remove intermediate events and merge responses + assert len(rearranged) == 2 + assert rearranged[0] == call_event + assert rearranged[1] == response_event + + +def test_rearrange_events_for_latest_function_response_validation_error(): + """Test _rearrange_events_for_latest_function_response with validation error.""" + # Create function call event with one function call + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event with the matching function call AND an extra one + function_response1 = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + function_response2 = types.FunctionResponse( + id="func_456", name="other_function", response={"result": "other"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + # Test with mismatched function call and response + events = [call_event, intermediate_event, response_event] + + with pytest.raises( + ValueError, + match=( + "Last response event should only contain the responses for the" + " function calls in the same function call event" + ), + ): + _rearrange_events_for_latest_function_response(events) + + +def test_rearrange_events_for_latest_function_response_mixed_responses(): + """Test _rearrange_events_for_latest_function_response with mixed function responses.""" + # Create function call event with two calls + function_call1 = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value1"} + ) + function_call2 = types.FunctionCall( + id="func_456", name="test_function2", args={"param": "value2"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part(function_call=function_call1), + types.Part(function_call=function_call2), + ], + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event with one matching and one non-matching response + function_response1 = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + function_response2 = types.FunctionResponse( + id="func_789", name="test_function3", response={"result": "other"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + # Test with mixed function responses + events = [call_event, intermediate_event, response_event] + + with pytest.raises( + ValueError, + match=( + "Last response event should only contain the responses for the" + " function calls in the same function call event" + ), + ): + _rearrange_events_for_latest_function_response(events) diff --git a/tests/unittests/flows/llm_flows/test_functions_request_euc.py b/tests/unittests/flows/llm_flows/test_functions_request_euc.py index afb3b73a..03b66a55 100644 --- a/tests/unittests/flows/llm_flows/test_functions_request_euc.py +++ b/tests/unittests/flows/llm_flows/test_functions_request_euc.py @@ -549,13 +549,13 @@ def test_function_get_auth_response_partial(): ], ), ) - # assert function_invoked == 4 + assert function_invoked == 4 assert len(mock_model.requests) == 4 request = mock_model.requests[-1] content = request.contents[-1] parts = content.parts assert len(parts) == 2 assert parts[0].function_response.name == 'call_external_api1' - assert parts[0].function_response.response == {'result': None} + assert parts[0].function_response.response == {'result': 1} assert parts[1].function_response.name == 'call_external_api2' assert parts[1].function_response.response == {'result': 2}