diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 4fa44caf..05ab2e65 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -20,10 +20,12 @@ import asyncio import copy import inspect import logging +import threading from typing import Any from typing import AsyncGenerator from typing import cast from typing import Optional +from typing import TYPE_CHECKING import uuid from google.genai import types @@ -39,6 +41,9 @@ from ...telemetry import tracer from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext +if TYPE_CHECKING: + from ...agents.llm_agent import LlmAgent + AF_FUNCTION_CALL_ID_PREFIX = 'adk-' REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential' @@ -135,117 +140,42 @@ async def handle_function_calls_async( agent = invocation_context.agent if not isinstance(agent, LlmAgent): - return + return None function_calls = function_call_event.get_function_calls() - function_response_events: list[Event] = [] - for function_call in function_calls: - if filters and function_call.id not in filters: - continue - tool, tool_context = _get_tool_and_context( - invocation_context, - function_call_event, - function_call, - tools_dict, - ) + # Filter function calls + filtered_calls = [ + fc for fc in function_calls if not filters or fc.id in filters + ] - with tracer.start_as_current_span(f'execute_tool {tool.name}'): - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} - ) + if not filtered_calls: + return None - # Step 1: Check if plugin before_tool_callback overrides the function - # response. - function_response = ( - await invocation_context.plugin_manager.run_before_tool_callback( - tool=tool, tool_args=function_args, tool_context=tool_context + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_async( + invocation_context, + function_call, + tools_dict, + agent, ) ) + for function_call in filtered_calls + ] - # Step 2: If no overrides are provided from the plugins, further run the - # canonical callback. - if function_response is None: - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) - # Step 3: Otherwise, proceed calling the tool normally. - if function_response is None: - try: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - except Exception as tool_error: - error_response = await invocation_context.plugin_manager.run_on_tool_error_callback( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - function_response = error_response - else: - raise tool_error - - # Step 4: Check if plugin after_tool_callback overrides the function - # response. - altered_function_response = ( - await invocation_context.plugin_manager.run_after_tool_callback( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - result=function_response, - ) - ) - - # Step 5: If no overrides are provided from the plugins, further run the - # canonical after_tool_callbacks. - if altered_function_response is None: - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - # Step 6: If alternative response exists from after_tool_callback, use it - # instead of the original function response. - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow long running function to return None to not provide function - # response. - if not function_response: - continue - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - function_response_events.append(function_response_event) + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] if not function_response_events: return None + merged_event = merge_parallel_function_response_events( function_response_events ) @@ -262,6 +192,120 @@ async def handle_function_calls_async( return merged_event +async def _execute_single_function_call_async( + invocation_context: InvocationContext, + function_call: types.FunctionCall, + tools_dict: dict[str, BaseTool], + agent: LlmAgent, +) -> Optional[Event]: + """Execute a single function call with thread safety for state modifications.""" + tool, tool_context = _get_tool_and_context( + invocation_context, + function_call, + tools_dict, + ) + + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) + + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. + if function_response is None: + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + # Step 3: Otherwise, proceed calling the tool normally. + if function_response is None: + try: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + except Exception as tool_error: + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error + + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, + ) + ) + + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow long running function to return None to not provide function + # response. + if not function_response: + return None + + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + + async def handle_function_calls_live( invocation_context: InvocationContext, function_call_event: Event, @@ -273,71 +317,37 @@ async def handle_function_calls_live( agent = cast(LlmAgent, invocation_context.agent) function_calls = function_call_event.get_function_calls() - function_response_events: list[Event] = [] - for function_call in function_calls: - tool, tool_context = _get_tool_and_context( - invocation_context, function_call_event, function_call, tools_dict - ) - with tracer.start_as_current_span(f'execute_tool {tool.name}'): - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} + if not function_calls: + return None + + # Create thread-safe lock for active_streaming_tools modifications + streaming_lock = threading.Lock() + + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_live( + invocation_context, + function_call, + tools_dict, + agent, + streaming_lock, + ) ) - function_response = None + for function_call in function_calls + ] - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) - if function_response is None: - function_response = await _process_function_live_helper( - tool, tool_context, function_call, function_args, invocation_context - ) - - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow async function to return None to not provide function response. - if not function_response: - continue - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - function_response_events.append(function_response_event) + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] if not function_response_events: return None + merged_event = merge_parallel_function_response_events( function_response_events ) @@ -353,8 +363,92 @@ async def handle_function_calls_live( return merged_event +async def _execute_single_function_call_live( + invocation_context: InvocationContext, + function_call: types.FunctionCall, + tools_dict: dict[str, BaseTool], + agent: LlmAgent, + streaming_lock: threading.Lock, +) -> Optional[Event]: + """Execute a single function call for live mode with thread safety.""" + tool, tool_context = _get_tool_and_context( + invocation_context, function_call, tools_dict + ) + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + function_response = None + + # Handle before_tool_callbacks - iterate through the canonical callback + # list + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + if function_response is None: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) + + # Calls after_tool_callback if it exists. + altered_function_response = None + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow async function to return None to not provide function response. + if not function_response: + return None + + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + + async def _process_function_live_helper( - tool, tool_context, function_call, function_args, invocation_context + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock: threading.Lock, ): function_response = None # Check if this is a stop_streaming function call @@ -363,13 +457,20 @@ async def _process_function_live_helper( and 'function_name' in function_args ): function_name = function_args['function_name'] - active_tasks = invocation_context.active_streaming_tools - if ( - function_name in active_tasks - and active_tasks[function_name].task - and not active_tasks[function_name].task.done() - ): - task = active_tasks[function_name].task + # Thread-safe access to active_streaming_tools + with streaming_lock: + active_tasks = invocation_context.active_streaming_tools + if ( + active_tasks + and function_name in active_tasks + and active_tasks[function_name].task + and not active_tasks[function_name].task.done() + ): + task = active_tasks[function_name].task + else: + task = None + + if task: task.cancel() try: # Wait for the task to be cancelled @@ -377,20 +478,25 @@ async def _process_function_live_helper( except (asyncio.CancelledError, asyncio.TimeoutError): # Log the specific condition if task.cancelled(): - logging.info(f'Task {function_name} was cancelled successfully') + logging.info('Task %s was cancelled successfully', function_name) elif task.done(): - logging.info(f'Task {function_name} completed during cancellation') + logging.info('Task %s completed during cancellation', function_name) else: logging.warning( - f'Task {function_name} might still be running after' - ' cancellation timeout' + 'Task %s might still be running after cancellation timeout', + function_name, ) function_response = { 'status': f'The task is not cancelled yet for {function_name}.' } if not function_response: - # Clean up the reference - active_tasks[function_name].task = None + # Clean up the reference under lock + with streaming_lock: + if ( + invocation_context.active_streaming_tools + and function_name in invocation_context.active_streaming_tools + ): + invocation_context.active_streaming_tools[function_name].task = None function_response = { 'status': f'Successfully stopped streaming function {function_name}' @@ -425,14 +531,19 @@ async def _process_function_live_helper( task = asyncio.create_task( run_tool_and_update_queue(tool, function_args, tool_context) ) - if invocation_context.active_streaming_tools is None: - invocation_context.active_streaming_tools = {} - if tool.name in invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools[tool.name].task = task - else: - invocation_context.active_streaming_tools[tool.name] = ( - ActiveStreamingTool(task=task) - ) + + # Register streaming tool using original logic + with streaming_lock: + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + + if tool.name in invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools[tool.name].task = task + else: + invocation_context.active_streaming_tools[tool.name] = ( + ActiveStreamingTool(task=task) + ) + # Immediately return a pending response. # This is required by current live model. function_response = { @@ -450,7 +561,6 @@ async def _process_function_live_helper( def _get_tool_and_context( invocation_context: InvocationContext, - function_call_event: Event, function_call: types.FunctionCall, tools_dict: dict[str, BaseTool], ): @@ -552,7 +662,7 @@ def merge_parallel_function_response_events( base_event = function_response_events[0] # Merge actions from all events - merged_actions_data = {} + merged_actions_data: dict[str, Any] = {} for event in function_response_events: if event.actions: # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index df6fcb3c..dbaf3c8c 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from typing import Any from typing import Callable @@ -676,3 +677,222 @@ def test_shallow_vs_deep_copy_demonstration(): deep_copy['nested_dict']['inner']['value'] == 'modified' ) # Copy is modified assert 'new_item' in deep_copy['list_param'] # Copy is modified + + +@pytest.mark.asyncio +async def test_parallel_function_execution_timing(): + """Test that multiple function calls are executed in parallel, not sequentially.""" + import time + + execution_order = [] + execution_times = {} + + async def slow_function_1(delay: float = 0.1) -> dict: + start_time = time.time() + execution_order.append('start_1') + await asyncio.sleep(delay) + end_time = time.time() + execution_times['func_1'] = (start_time, end_time) + execution_order.append('end_1') + return {'result': 'function_1_result'} + + async def slow_function_2(delay: float = 0.1) -> dict: + start_time = time.time() + execution_order.append('start_2') + await asyncio.sleep(delay) + end_time = time.time() + execution_times['func_2'] = (start_time, end_time) + execution_order.append('end_2') + return {'result': 'function_2_result'} + + # Create function calls + function_calls = [ + types.Part.from_function_call( + name='slow_function_1', args={'delay': 0.1} + ), + types.Part.from_function_call( + name='slow_function_2', args={'delay': 0.1} + ), + ] + + function_responses = [ + types.Part.from_function_response( + name='slow_function_1', response={'result': 'function_1_result'} + ), + types.Part.from_function_response( + name='slow_function_2', response={'result': 'function_2_result'} + ), + ] + + responses: list[types.Content] = [ + function_calls, + 'response1', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[slow_function_1, slow_function_2], + ) + runner = testing_utils.TestInMemoryRunner(agent) + + # Measure total execution time + start_time = time.time() + events = await runner.run_async_with_new_session('test') + total_time = time.time() - start_time + + # Verify parallel execution by checking execution order + # In parallel execution, both functions should start before either finishes + assert 'start_1' in execution_order + assert 'start_2' in execution_order + assert 'end_1' in execution_order + assert 'end_2' in execution_order + + # Verify both functions started within a reasonable time window + func_1_start, func_1_end = execution_times['func_1'] + func_2_start, func_2_end = execution_times['func_2'] + + # Functions should start at approximately the same time (within 10ms) + start_time_diff = abs(func_1_start - func_2_start) + assert ( + start_time_diff < 0.01 + ), f'Functions started too far apart: {start_time_diff}s' + + # Total execution time should be closer to 0.1s (parallel) than 0.2s (sequential) + # Allow some overhead for task creation and synchronization + assert ( + total_time < 0.15 + ), f'Execution took too long: {total_time}s, expected < 0.15s' + + # Verify the results are correct + assert testing_utils.simplify_events(events) == [ + ('test_agent', function_calls), + ('test_agent', function_responses), + ('test_agent', 'response1'), + ] + + +@pytest.mark.asyncio +async def test_parallel_state_modifications_thread_safety(): + """Test that parallel function calls modifying state are thread-safe.""" + state_modifications = [] + + def modify_state_1(tool_context: ToolContext) -> dict: + # Track when this function modifies state + current_state = dict(tool_context.state.to_dict()) + state_modifications.append(('func_1_start', current_state)) + + tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1 + tool_context.state['func_1_executed'] = True + + final_state = dict(tool_context.state.to_dict()) + state_modifications.append(('func_1_end', final_state)) + return {'result': 'modified_state_1'} + + def modify_state_2(tool_context: ToolContext) -> dict: + # Track when this function modifies state + current_state = dict(tool_context.state.to_dict()) + state_modifications.append(('func_2_start', current_state)) + + tool_context.state['counter'] = tool_context.state.get('counter', 0) + 1 + tool_context.state['func_2_executed'] = True + + final_state = dict(tool_context.state.to_dict()) + state_modifications.append(('func_2_end', final_state)) + return {'result': 'modified_state_2'} + + # Create function calls + function_calls = [ + types.Part.from_function_call(name='modify_state_1', args={}), + types.Part.from_function_call(name='modify_state_2', args={}), + ] + + responses: list[types.Content] = [ + function_calls, + 'response1', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[modify_state_1, modify_state_2], + ) + runner = testing_utils.TestInMemoryRunner(agent) + events = await runner.run_async_with_new_session('test') + + # Verify the parallel execution worked correctly by checking the events + # The function response event should have the merged state_delta + function_response_event = events[ + 1 + ] # Second event should be the function response + assert function_response_event.actions.state_delta['counter'] == 2 + assert function_response_event.actions.state_delta['func_1_executed'] is True + assert function_response_event.actions.state_delta['func_2_executed'] is True + + # Verify both functions were called + assert len(state_modifications) == 4 # 2 functions × 2 events each + + # Extract function names from modifications + func_names = [mod[0] for mod in state_modifications] + assert 'func_1_start' in func_names + assert 'func_1_end' in func_names + assert 'func_2_start' in func_names + assert 'func_2_end' in func_names + + +@pytest.mark.asyncio +async def test_parallel_mixed_sync_async_functions(): + """Test parallel execution with mix of sync and async functions.""" + execution_log = [] + + def sync_function(value: int) -> dict: + execution_log.append(f'sync_start_{value}') + # Simulate some work + import time + + time.sleep(0.05) # 50ms + execution_log.append(f'sync_end_{value}') + return {'result': f'sync_{value}'} + + async def async_function(value: int) -> dict: + execution_log.append(f'async_start_{value}') + await asyncio.sleep(0.05) # 50ms + execution_log.append(f'async_end_{value}') + return {'result': f'async_{value}'} + + # Create function calls + function_calls = [ + types.Part.from_function_call(name='sync_function', args={'value': 1}), + types.Part.from_function_call(name='async_function', args={'value': 2}), + types.Part.from_function_call(name='sync_function', args={'value': 3}), + ] + + responses: list[types.Content] = [function_calls, 'response1'] + mock_model = testing_utils.MockModel.create(responses=responses) + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[sync_function, async_function], + ) + runner = testing_utils.TestInMemoryRunner(agent) + + import time + + start_time = time.time() + events = await runner.run_async_with_new_session('test') + total_time = time.time() - start_time + + # Should complete in less than 120ms (parallel) rather than 150ms (sequential) + # Allow for overhead from task creation and synchronization + assert total_time < 0.12, f'Execution took {total_time}s, expected < 0.12s' + + # Verify all functions were called + assert 'sync_start_1' in execution_log + assert 'sync_end_1' in execution_log + assert 'async_start_2' in execution_log + assert 'async_end_2' in execution_log + assert 'sync_start_3' in execution_log + assert 'sync_end_3' in execution_log diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index dd0e6d5c..ac827a45 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -110,8 +110,15 @@ def test_live_streaming_function_call_single(): try: session = self.session - # Add timeout to prevent hanging - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): # Return whatever we collected so far pass @@ -217,7 +224,15 @@ def test_live_streaming_function_call_multiple(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -315,7 +330,15 @@ def test_live_streaming_function_call_parallel(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -407,7 +430,15 @@ def test_live_streaming_function_call_with_error(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -490,7 +521,15 @@ def test_live_streaming_function_call_sync_tool(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -582,7 +621,15 @@ def test_live_streaming_simple_streaming_tool(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -686,7 +733,15 @@ def test_live_streaming_video_streaming_tool(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -794,7 +849,15 @@ def test_live_streaming_stop_streaming_tool(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -903,7 +966,15 @@ def test_live_streaming_multiple_streaming_tools(): try: session = self.session - asyncio.run(asyncio.wait_for(consume_responses(session), timeout=5.0)) + # Create a new event loop to avoid nested event loop issues + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() except (asyncio.TimeoutError, asyncio.CancelledError): pass