From 94b5aaf0a1c2bba8a3b80b26a1eb09f35099c488 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 13 Jul 2025 03:37:09 +0000 Subject: [PATCH] fix: Correct EventAction merging logic and add corresponding tests --- src/google/adk/flows/llm_flows/functions.py | 28 +++-- .../llm_flows/test_functions_parallel.py | 107 ++++++++++++++++++ 2 files changed, 127 insertions(+), 8 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_functions_parallel.py diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index c64b7cec..06feab74 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -511,6 +511,16 @@ def __build_response_event( return function_response_event +def deep_merge_dicts(d1: dict, d2: dict) -> dict: + """Recursively merges d2 into d1.""" + for key, value in d2.items(): + if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): + d1[key] = deep_merge_dicts(d1[key], value) + else: + d1[key] = value + return d1 + + def merge_parallel_function_response_events( function_response_events: list['Event'], ) -> 'Event': @@ -529,15 +539,17 @@ def merge_parallel_function_response_events( base_event = function_response_events[0] # Merge actions from all events - - merged_actions = EventActions() - merged_requested_auth_configs = {} + merged_actions_data = {} for event in function_response_events: - merged_requested_auth_configs.update(event.actions.requested_auth_configs) - merged_actions = merged_actions.model_copy( - update=event.actions.model_dump() - ) - merged_actions.requested_auth_configs = merged_requested_auth_configs + if event.actions: + # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. + merged_actions_data = deep_merge_dicts( + merged_actions_data, + event.actions.model_dump(exclude_none=True, by_alias=True), + ) + + merged_actions = EventActions.model_validate(merged_actions_data) + # Create the new merged event merged_event = Event( invocation_id=Event.new_id(), diff --git a/tests/unittests/flows/llm_flows/test_functions_parallel.py b/tests/unittests/flows/llm_flows/test_functions_parallel.py new file mode 100644 index 00000000..626dfcf6 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_functions_parallel.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents import Agent +from google.adk.events.event_actions import EventActions +from google.adk.tools import ToolContext +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_parallel_function_calls_with_state_change(): + function_calls = [ + types.Part.from_function_call( + name='update_session_state', + args={'key': 'test_key1', 'value': 'test_value1'}, + ), + types.Part.from_function_call( + name='update_session_state', + args={'key': 'test_key2', 'value': 'test_value2'}, + ), + types.Part.from_function_call( + name='transfer_to_agent', args={'agent_name': 'test_sub_agent'} + ), + ] + function_responses = [ + types.Part.from_function_response( + name='update_session_state', response={'result': None} + ), + types.Part.from_function_response( + name='update_session_state', response={'result': None} + ), + types.Part.from_function_response( + name='transfer_to_agent', response={'result': None} + ), + ] + + responses: list[types.Content] = [ + function_calls, + 'response1', + ] + function_called = 0 + mock_model = testing_utils.MockModel.create(responses=responses) + + async def update_session_state( + key: str, value: str, tool_context: ToolContext + ) -> None: + nonlocal function_called + function_called += 1 + tool_context.state.update({key: value}) + return + + async def transfer_to_agent( + agent_name: str, tool_context: ToolContext + ) -> None: + nonlocal function_called + function_called += 1 + tool_context.actions.transfer_to_agent = agent_name + return + + test_sub_agent = Agent( + name='test_sub_agent', + ) + + agent = Agent( + name='root_agent', + model=mock_model, + tools=[update_session_state, transfer_to_agent], + sub_agents=[test_sub_agent], + ) + runner = testing_utils.TestInMemoryRunner(agent) + events = await runner.run_async_with_new_session('test') + + # Notice that the following assertion only checks the "contents" part of the events. + # The "actions" part will be checked later. + assert testing_utils.simplify_events(events) == [ + ('root_agent', function_calls), + ('root_agent', function_responses), + ('test_sub_agent', 'response1'), + ] + + # Asserts the function calls. + assert function_called == 3 + + # Asserts the actions in response event. + response_event = events[1] + + assert response_event.actions == EventActions( + state_delta={ + 'test_key1': 'test_value1', + 'test_key2': 'test_value2', + }, + transfer_to_agent='test_sub_agent', + )