You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support parallel execution of parallel function calls
PiperOrigin-RevId: 790182046
This commit is contained in:
committed by
Copybara-Service
parent
7556ebc76a
commit
57cd41f424
File diff suppressed because it is too large
Load Diff
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
@@ -676,3 +677,222 @@ def test_shallow_vs_deep_copy_demonstration():
|
||||
deep_copy['nested_dict']['inner']['value'] == 'modified'
|
||||
) # Copy is modified
|
||||
assert 'new_item' in deep_copy['list_param'] # Copy is modified
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_function_execution_timing():
|
||||
"""Test that multiple function calls are executed in parallel, not sequentially."""
|
||||
import time
|
||||
|
||||
execution_order = []
|
||||
execution_times = {}
|
||||
|
||||
async def slow_function_1(delay: float = 0.1) -> dict:
|
||||
start_time = time.time()
|
||||
execution_order.append('start_1')
|
||||
await asyncio.sleep(delay)
|
||||
end_time = time.time()
|
||||
execution_times['func_1'] = (start_time, end_time)
|
||||
execution_order.append('end_1')
|
||||
return {'result': 'function_1_result'}
|
||||
|
||||
async def slow_function_2(delay: float = 0.1) -> dict:
|
||||
start_time = time.time()
|
||||
execution_order.append('start_2')
|
||||
await asyncio.sleep(delay)
|
||||
end_time = time.time()
|
||||
execution_times['func_2'] = (start_time, end_time)
|
||||
execution_order.append('end_2')
|
||||
return {'result': 'function_2_result'}
|
||||
|
||||
# Create function calls
|
||||
function_calls = [
|
||||
types.Part.from_function_call(
|
||||
name='slow_function_1', args={'delay': 0.1}
|
||||
),
|
||||
types.Part.from_function_call(
|
||||
name='slow_function_2', args={'delay': 0.1}
|
||||
),
|
||||
]
|
||||
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='slow_function_1', response={'result': 'function_1_result'}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='slow_function_2', response={'result': 'function_2_result'}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
]
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
agent = Agent(
|
||||
name='test_agent',
|
||||
model=mock_model,
|
||||
tools=[slow_function_1, slow_function_2],
|
||||
)
|
||||
runner = testing_utils.TestInMemoryRunner(agent)
|
||||
|
||||
# Measure total execution time
|
||||
start_time = time.time()
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# Verify parallel execution by checking execution order
|
||||
# In parallel execution, both functions should start before either finishes
|
||||
assert 'start_1' in execution_order
|
||||
assert 'start_2' in execution_order
|
||||
assert 'end_1' in execution_order
|
||||
assert 'end_2' in execution_order
|
||||
|
||||
# Verify both functions started within a reasonable time window
|
||||
func_1_start, func_1_end = execution_times['func_1']
|
||||
func_2_start, func_2_end = execution_times['func_2']
|
||||
|
||||
# Functions should start at approximately the same time (within 10ms)
|
||||
start_time_diff = abs(func_1_start - func_2_start)
|
||||
assert (
|
||||
start_time_diff < 0.01
|
||||
), f'Functions started too far apart: {start_time_diff}s'
|
||||
|
||||
# Total execution time should be closer to 0.1s (parallel) than 0.2s (sequential)
|
||||
# Allow some overhead for task creation and synchronization
|
||||
assert (
|
||||
total_time < 0.15
|
||||
), f'Execution took too long: {total_time}s, expected < 0.15s'
|
||||
|
||||
# Verify the results are correct
|
||||
assert testing_utils.simplify_events(events) == [
|
||||
('test_agent', function_calls),
|
||||
('test_agent', function_responses),
|
||||
('test_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_state_modifications_thread_safety():
|
||||
"""Test that parallel function calls modifying state are thread-safe."""
|
||||
state_modifications = []
|
||||
|
||||
def modify_state_1(tool_context: ToolContext) -> dict:
|
||||
# Track when this function modifies state
|
||||
current_state = dict(tool_context.state.to_dict())
|
||||
state_modifications.append(('func_1_start', current_state))
|
||||
|
||||
tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1
|
||||
tool_context.state['func_1_executed'] = True
|
||||
|
||||
final_state = dict(tool_context.state.to_dict())
|
||||
state_modifications.append(('func_1_end', final_state))
|
||||
return {'result': 'modified_state_1'}
|
||||
|
||||
def modify_state_2(tool_context: ToolContext) -> dict:
|
||||
# Track when this function modifies state
|
||||
current_state = dict(tool_context.state.to_dict())
|
||||
state_modifications.append(('func_2_start', current_state))
|
||||
|
||||
tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1
|
||||
tool_context.state['func_2_executed'] = True
|
||||
|
||||
final_state = dict(tool_context.state.to_dict())
|
||||
state_modifications.append(('func_2_end', final_state))
|
||||
return {'result': 'modified_state_2'}
|
||||
|
||||
# Create function calls
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='modify_state_1', args={}),
|
||||
types.Part.from_function_call(name='modify_state_2', args={}),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
]
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
agent = Agent(
|
||||
name='test_agent',
|
||||
model=mock_model,
|
||||
tools=[modify_state_1, modify_state_2],
|
||||
)
|
||||
runner = testing_utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
|
||||
# Verify the parallel execution worked correctly by checking the events
|
||||
# The function response event should have the merged state_delta
|
||||
function_response_event = events[
|
||||
1
|
||||
] # Second event should be the function response
|
||||
assert function_response_event.actions.state_delta['counter'] == 2
|
||||
assert function_response_event.actions.state_delta['func_1_executed'] is True
|
||||
assert function_response_event.actions.state_delta['func_2_executed'] is True
|
||||
|
||||
# Verify both functions were called
|
||||
assert len(state_modifications) == 4 # 2 functions × 2 events each
|
||||
|
||||
# Extract function names from modifications
|
||||
func_names = [mod[0] for mod in state_modifications]
|
||||
assert 'func_1_start' in func_names
|
||||
assert 'func_1_end' in func_names
|
||||
assert 'func_2_start' in func_names
|
||||
assert 'func_2_end' in func_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_mixed_sync_async_functions():
|
||||
"""Test parallel execution with mix of sync and async functions."""
|
||||
execution_log = []
|
||||
|
||||
def sync_function(value: int) -> dict:
|
||||
execution_log.append(f'sync_start_{value}')
|
||||
# Simulate some work
|
||||
import time
|
||||
|
||||
time.sleep(0.05) # 50ms
|
||||
execution_log.append(f'sync_end_{value}')
|
||||
return {'result': f'sync_{value}'}
|
||||
|
||||
async def async_function(value: int) -> dict:
|
||||
execution_log.append(f'async_start_{value}')
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
execution_log.append(f'async_end_{value}')
|
||||
return {'result': f'async_{value}'}
|
||||
|
||||
# Create function calls
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='sync_function', args={'value': 1}),
|
||||
types.Part.from_function_call(name='async_function', args={'value': 2}),
|
||||
types.Part.from_function_call(name='sync_function', args={'value': 3}),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [function_calls, 'response1']
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
|
||||
agent = Agent(
|
||||
name='test_agent',
|
||||
model=mock_model,
|
||||
tools=[sync_function, async_function],
|
||||
)
|
||||
runner = testing_utils.TestInMemoryRunner(agent)
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# Should complete in less than 120ms (parallel) rather than 150ms (sequential)
|
||||
# Allow for overhead from task creation and synchronization
|
||||
assert total_time < 0.12, f'Execution took {total_time}s, expected < 0.12s'
|
||||
|
||||
# Verify all functions were called
|
||||
assert 'sync_start_1' in execution_log
|
||||
assert 'sync_end_1' in execution_log
|
||||
assert 'async_start_2' in execution_log
|
||||
assert 'async_end_2' in execution_log
|
||||
assert 'sync_start_3' in execution_log
|
||||
assert 'sync_end_3' in execution_log
|
||||
|
||||
@@ -110,8 +110,15 @@ def test_live_streaming_function_call_single():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Add timeout to prevent hanging
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
# Return whatever we collected so far
|
||||
pass
|
||||
@@ -217,7 +224,15 @@ def test_live_streaming_function_call_multiple():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -315,7 +330,15 @@ def test_live_streaming_function_call_parallel():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -407,7 +430,15 @@ def test_live_streaming_function_call_with_error():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -490,7 +521,15 @@ def test_live_streaming_function_call_sync_tool():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -582,7 +621,15 @@ def test_live_streaming_simple_streaming_tool():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -686,7 +733,15 @@ def test_live_streaming_video_streaming_tool():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -794,7 +849,15 @@ def test_live_streaming_stop_streaming_tool():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@@ -903,7 +966,15 @@ def test_live_streaming_multiple_streaming_tools():
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0))
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user