feat: Support parallel execution of parallel function calls

PiperOrigin-RevId: 790182046
This commit is contained in:
Xiang (Sean) Zhou
2025-08-02 12:27:38 -07:00
committed by Copybara-Service
parent 7556ebc76a
commit 57cd41f424
3 changed files with 593 additions and 192 deletions
File diff suppressed because it is too large Load Diff
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any
from typing import Callable
@@ -676,3 +677,222 @@ def test_shallow_vs_deep_copy_demonstration():
deep_copy['nested_dict']['inner']['value'] == 'modified'
) # Copy is modified
assert 'new_item' in deep_copy['list_param'] # Copy is modified
@pytest.mark.asyncio
async def test_parallel_function_execution_timing():
"""Test that multiple function calls are executed in parallel, not sequentially."""
import time
execution_order = []
execution_times = {}
async def slow_function_1(delay: float = 0.1) -> dict:
start_time = time.time()
execution_order.append('start_1')
await asyncio.sleep(delay)
end_time = time.time()
execution_times['func_1'] = (start_time, end_time)
execution_order.append('end_1')
return {'result': 'function_1_result'}
async def slow_function_2(delay: float = 0.1) -> dict:
start_time = time.time()
execution_order.append('start_2')
await asyncio.sleep(delay)
end_time = time.time()
execution_times['func_2'] = (start_time, end_time)
execution_order.append('end_2')
return {'result': 'function_2_result'}
# Create function calls
function_calls = [
types.Part.from_function_call(
name='slow_function_1', args={'delay': 0.1}
),
types.Part.from_function_call(
name='slow_function_2', args={'delay': 0.1}
),
]
function_responses = [
types.Part.from_function_response(
name='slow_function_1', response={'result': 'function_1_result'}
),
types.Part.from_function_response(
name='slow_function_2', response={'result': 'function_2_result'}
),
]
responses: list[types.Content] = [
function_calls,
'response1',
]
mock_model = testing_utils.MockModel.create(responses=responses)
agent = Agent(
name='test_agent',
model=mock_model,
tools=[slow_function_1, slow_function_2],
)
runner = testing_utils.TestInMemoryRunner(agent)
# Measure total execution time
start_time = time.time()
events = await runner.run_async_with_new_session('test')
total_time = time.time() - start_time
# Verify parallel execution by checking execution order
# In parallel execution, both functions should start before either finishes
assert 'start_1' in execution_order
assert 'start_2' in execution_order
assert 'end_1' in execution_order
assert 'end_2' in execution_order
# Verify both functions started within a reasonable time window
func_1_start, func_1_end = execution_times['func_1']
func_2_start, func_2_end = execution_times['func_2']
# Functions should start at approximately the same time (within 10ms)
start_time_diff = abs(func_1_start - func_2_start)
assert (
start_time_diff < 0.01
), f'Functions started too far apart: {start_time_diff}s'
# Total execution time should be closer to 0.1s (parallel) than 0.2s (sequential)
# Allow some overhead for task creation and synchronization
assert (
total_time < 0.15
), f'Execution took too long: {total_time}s, expected < 0.15s'
# Verify the results are correct
assert testing_utils.simplify_events(events) == [
('test_agent', function_calls),
('test_agent', function_responses),
('test_agent', 'response1'),
]
@pytest.mark.asyncio
async def test_parallel_state_modifications_thread_safety():
"""Test that parallel function calls modifying state are thread-safe."""
state_modifications = []
def modify_state_1(tool_context: ToolContext) -> dict:
# Track when this function modifies state
current_state = dict(tool_context.state.to_dict())
state_modifications.append(('func_1_start', current_state))
tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1
tool_context.state['func_1_executed'] = True
final_state = dict(tool_context.state.to_dict())
state_modifications.append(('func_1_end', final_state))
return {'result': 'modified_state_1'}
def modify_state_2(tool_context: ToolContext) -> dict:
# Track when this function modifies state
current_state = dict(tool_context.state.to_dict())
state_modifications.append(('func_2_start', current_state))
tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1
tool_context.state['func_2_executed'] = True
final_state = dict(tool_context.state.to_dict())
state_modifications.append(('func_2_end', final_state))
return {'result': 'modified_state_2'}
# Create function calls
function_calls = [
types.Part.from_function_call(name='modify_state_1', args={}),
types.Part.from_function_call(name='modify_state_2', args={}),
]
responses: list[types.Content] = [
function_calls,
'response1',
]
mock_model = testing_utils.MockModel.create(responses=responses)
agent = Agent(
name='test_agent',
model=mock_model,
tools=[modify_state_1, modify_state_2],
)
runner = testing_utils.TestInMemoryRunner(agent)
events = await runner.run_async_with_new_session('test')
# Verify the parallel execution worked correctly by checking the events
# The function response event should have the merged state_delta
function_response_event = events[
1
] # Second event should be the function response
assert function_response_event.actions.state_delta['counter'] == 2
assert function_response_event.actions.state_delta['func_1_executed'] is True
assert function_response_event.actions.state_delta['func_2_executed'] is True
# Verify both functions were called
assert len(state_modifications) == 4 # 2 functions × 2 events each
# Extract function names from modifications
func_names = [mod[0] for mod in state_modifications]
assert 'func_1_start' in func_names
assert 'func_1_end' in func_names
assert 'func_2_start' in func_names
assert 'func_2_end' in func_names
@pytest.mark.asyncio
async def test_parallel_mixed_sync_async_functions():
"""Test parallel execution with mix of sync and async functions."""
execution_log = []
def sync_function(value: int) -> dict:
execution_log.append(f'sync_start_{value}')
# Simulate some work
import time
time.sleep(0.05) # 50ms
execution_log.append(f'sync_end_{value}')
return {'result': f'sync_{value}'}
async def async_function(value: int) -> dict:
execution_log.append(f'async_start_{value}')
await asyncio.sleep(0.05) # 50ms
execution_log.append(f'async_end_{value}')
return {'result': f'async_{value}'}
# Create function calls
function_calls = [
types.Part.from_function_call(name='sync_function', args={'value': 1}),
types.Part.from_function_call(name='async_function', args={'value': 2}),
types.Part.from_function_call(name='sync_function', args={'value': 3}),
]
responses: list[types.Content] = [function_calls, 'response1']
mock_model = testing_utils.MockModel.create(responses=responses)
agent = Agent(
name='test_agent',
model=mock_model,
tools=[sync_function, async_function],
)
runner = testing_utils.TestInMemoryRunner(agent)
import time
start_time = time.time()
events = await runner.run_async_with_new_session('test')
total_time = time.time() - start_time
# Should complete in less than 120ms (parallel) rather than 150ms (sequential)
# Allow for overhead from task creation and synchronization
assert total_time < 0.12, f'Execution took {total_time}s, expected < 0.12s'
# Verify all functions were called
assert 'sync_start_1' in execution_log
assert 'sync_end_1' in execution_log
assert 'async_start_2' in execution_log
assert 'async_end_2' in execution_log
assert 'sync_start_3' in execution_log
assert 'sync_end_3' in execution_log
+81 -10
View File
@@ -110,8 +110,15 @@ def test_live_streaming_function_call_single():
try:
session = self.session
# Add timeout to prevent hanging
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
# Return whatever we collected so far
pass
@@ -217,7 +224,15 @@ def test_live_streaming_function_call_multiple():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -315,7 +330,15 @@ def test_live_streaming_function_call_parallel():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -407,7 +430,15 @@ def test_live_streaming_function_call_with_error():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -490,7 +521,15 @@ def test_live_streaming_function_call_sync_tool():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -582,7 +621,15 @@ def test_live_streaming_simple_streaming_tool():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -686,7 +733,15 @@ def test_live_streaming_video_streaming_tool():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -794,7 +849,15 @@ def test_live_streaming_stop_streaming_tool():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -903,7 +966,15 @@ def test_live_streaming_multiple_streaming_tools():
try:
session = self.session
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass