You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
Merge pull request #1955 from iwknow:func_fix
PiperOrigin-RevId: 783391431
This commit is contained in:
@@ -511,6 +511,16 @@ def __build_response_event(
|
||||
return function_response_event
|
||||
|
||||
|
||||
def deep_merge_dicts(d1: dict, d2: dict) -> dict:
|
||||
"""Recursively merges d2 into d1."""
|
||||
for key, value in d2.items():
|
||||
if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict):
|
||||
d1[key] = deep_merge_dicts(d1[key], value)
|
||||
else:
|
||||
d1[key] = value
|
||||
return d1
|
||||
|
||||
|
||||
def merge_parallel_function_response_events(
|
||||
function_response_events: list['Event'],
|
||||
) -> 'Event':
|
||||
@@ -529,15 +539,17 @@ def merge_parallel_function_response_events(
|
||||
base_event = function_response_events[0]
|
||||
|
||||
# Merge actions from all events
|
||||
|
||||
merged_actions = EventActions()
|
||||
merged_requested_auth_configs = {}
|
||||
merged_actions_data = {}
|
||||
for event in function_response_events:
|
||||
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
|
||||
merged_actions = merged_actions.model_copy(
|
||||
update=event.actions.model_dump()
|
||||
)
|
||||
merged_actions.requested_auth_configs = merged_requested_auth_configs
|
||||
if event.actions:
|
||||
# Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate.
|
||||
merged_actions_data = deep_merge_dicts(
|
||||
merged_actions_data,
|
||||
event.actions.model_dump(exclude_none=True, by_alias=True),
|
||||
)
|
||||
|
||||
merged_actions = EventActions.model_validate(merged_actions_data)
|
||||
|
||||
# Create the new merged event
|
||||
merged_event = Event(
|
||||
invocation_id=Event.new_id(),
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import testing_utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_function_calls_with_state_change():
|
||||
function_calls = [
|
||||
types.Part.from_function_call(
|
||||
name='update_session_state',
|
||||
args={'key': 'test_key1', 'value': 'test_value1'},
|
||||
),
|
||||
types.Part.from_function_call(
|
||||
name='update_session_state',
|
||||
args={'key': 'test_key2', 'value': 'test_value2'},
|
||||
),
|
||||
types.Part.from_function_call(
|
||||
name='transfer_to_agent', args={'agent_name': 'test_sub_agent'}
|
||||
),
|
||||
]
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='update_session_state', response={'result': None}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='update_session_state', response={'result': None}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='transfer_to_agent', response={'result': None}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
async def update_session_state(
|
||||
key: str, value: str, tool_context: ToolContext
|
||||
) -> None:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
tool_context.state.update({key: value})
|
||||
return
|
||||
|
||||
async def transfer_to_agent(
|
||||
agent_name: str, tool_context: ToolContext
|
||||
) -> None:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
tool_context.actions.transfer_to_agent = agent_name
|
||||
return
|
||||
|
||||
test_sub_agent = Agent(
|
||||
name='test_sub_agent',
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[update_session_state, transfer_to_agent],
|
||||
sub_agents=[test_sub_agent],
|
||||
)
|
||||
runner = testing_utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
|
||||
# Notice that the following assertion only checks the "contents" part of the events.
|
||||
# The "actions" part will be checked later.
|
||||
assert testing_utils.simplify_events(events) == [
|
||||
('root_agent', function_calls),
|
||||
('root_agent', function_responses),
|
||||
('test_sub_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
|
||||
# Asserts the actions in response event.
|
||||
response_event = events[1]
|
||||
|
||||
assert response_event.actions == EventActions(
|
||||
state_delta={
|
||||
'test_key1': 'test_value1',
|
||||
'test_key2': 'test_value2',
|
||||
},
|
||||
transfer_to_agent='test_sub_agent',
|
||||
)
|
||||
Reference in New Issue
Block a user