Merge pull request #1955 from iwknow:func_fix

PiperOrigin-RevId: 783391431
This commit is contained in:
Copybara-Service
2025-07-15 10:47:16 -07:00
2 changed files with 127 additions and 8 deletions
+20 -8
View File
@@ -511,6 +511,16 @@ def __build_response_event(
return function_response_event
def deep_merge_dicts(d1: dict, d2: dict) -> dict:
"""Recursively merges d2 into d1."""
for key, value in d2.items():
if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict):
d1[key] = deep_merge_dicts(d1[key], value)
else:
d1[key] = value
return d1
def merge_parallel_function_response_events(
function_response_events: list['Event'],
) -> 'Event':
@@ -529,15 +539,17 @@ def merge_parallel_function_response_events(
base_event = function_response_events[0]
# Merge actions from all events
merged_actions = EventActions()
merged_requested_auth_configs = {}
merged_actions_data = {}
for event in function_response_events:
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
merged_actions = merged_actions.model_copy(
update=event.actions.model_dump()
)
merged_actions.requested_auth_configs = merged_requested_auth_configs
if event.actions:
# Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate.
merged_actions_data = deep_merge_dicts(
merged_actions_data,
event.actions.model_dump(exclude_none=True, by_alias=True),
)
merged_actions = EventActions.model_validate(merged_actions_data)
# Create the new merged event
merged_event = Event(
invocation_id=Event.new_id(),
@@ -0,0 +1,107 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents import Agent
from google.adk.events.event_actions import EventActions
from google.adk.tools import ToolContext
from google.genai import types
import pytest
from ... import testing_utils
@pytest.mark.asyncio
async def test_parallel_function_calls_with_state_change():
function_calls = [
types.Part.from_function_call(
name='update_session_state',
args={'key': 'test_key1', 'value': 'test_value1'},
),
types.Part.from_function_call(
name='update_session_state',
args={'key': 'test_key2', 'value': 'test_value2'},
),
types.Part.from_function_call(
name='transfer_to_agent', args={'agent_name': 'test_sub_agent'}
),
]
function_responses = [
types.Part.from_function_response(
name='update_session_state', response={'result': None}
),
types.Part.from_function_response(
name='update_session_state', response={'result': None}
),
types.Part.from_function_response(
name='transfer_to_agent', response={'result': None}
),
]
responses: list[types.Content] = [
function_calls,
'response1',
]
function_called = 0
mock_model = testing_utils.MockModel.create(responses=responses)
async def update_session_state(
key: str, value: str, tool_context: ToolContext
) -> None:
nonlocal function_called
function_called += 1
tool_context.state.update({key: value})
return
async def transfer_to_agent(
agent_name: str, tool_context: ToolContext
) -> None:
nonlocal function_called
function_called += 1
tool_context.actions.transfer_to_agent = agent_name
return
test_sub_agent = Agent(
name='test_sub_agent',
)
agent = Agent(
name='root_agent',
model=mock_model,
tools=[update_session_state, transfer_to_agent],
sub_agents=[test_sub_agent],
)
runner = testing_utils.TestInMemoryRunner(agent)
events = await runner.run_async_with_new_session('test')
# Notice that the following assertion only checks the "contents" part of the events.
# The "actions" part will be checked later.
assert testing_utils.simplify_events(events) == [
('root_agent', function_calls),
('root_agent', function_responses),
('test_sub_agent', 'response1'),
]
# Asserts the function calls.
assert function_called == 3
# Asserts the actions in response event.
response_event = events[1]
assert response_event.actions == EventActions(
state_delta={
'test_key1': 'test_value1',
'test_key2': 'test_value2',
},
transfer_to_agent='test_sub_agent',
)