You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Improve asyncio loop handling and test cleanup
This CL enhances asyncio event loop management and test isolation. - **BigQuery Analytics Plugin:** Ensure the asyncio event loop is consistently closed within the BigQuery analytics plugin. This prevents potential resource leaks. Add checks to handle potential deadlocks in Python 3.13+ when creating loops during interpreter shutdown. - **Test Thread Pool Cleanup:** Introduce a pytest fixture (`cleanup_thread_pools`) to automatically shut down and clear all tool-related thread pools after each test run in `test_functions_thread_pool.py`. This improves test isolation and prevents order-dependent test failures. - **Streaming Test Loop Restoration:** Refactor event loop handling in `test_streaming.py`. A new `_run_with_loop` method is introduced in the custom test runners to create a temporary event loop for each test execution, run the coroutine, and crucially, restore the original event loop afterwards. This prevents tests from interfering with each other's loop state. - **Resource Closure:** Ensure services are closed properly in tests by adding `await service.close()` in `test_service_factory.py` and using `async with session_service` in `test_session_service.py`. PiperOrigin-RevId: 863305565
This commit is contained in:
committed by
Copybara-Service
parent
585ebfdac7
commit
00aba2d884
@@ -1695,10 +1695,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
except Exception as e:
|
||||
logger.error("Rescue flush failed: %s", e)
|
||||
|
||||
# In Python 3.13+, creating a new event loop during interpreter shutdown
|
||||
# (inside atexit) can cause deadlocks if the threading module is already
|
||||
# shutting down. We attempt to run only if safe.
|
||||
try:
|
||||
# Check if we can safely create a loop
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(rescue_flush())
|
||||
loop.close()
|
||||
try:
|
||||
loop.run_until_complete(rescue_flush())
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.error("Failed to run rescue loop: %s", e)
|
||||
except ReferenceError:
|
||||
|
||||
@@ -168,7 +168,8 @@ async def test_create_session_service_respects_app_name_mapping(
|
||||
assert (agent_dir / ".adk" / "session.db").exists()
|
||||
|
||||
|
||||
def test_create_session_service_fallbacks_to_database(
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_service_fallbacks_to_database(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True)
|
||||
@@ -189,6 +190,7 @@ def test_create_session_service_fallbacks_to_database(
|
||||
agents_dir=str(tmp_path),
|
||||
echo=True,
|
||||
)
|
||||
await service.close()
|
||||
|
||||
|
||||
def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch):
|
||||
|
||||
@@ -32,6 +32,17 @@ import pytest
|
||||
from ... import testing_utils
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_thread_pools():
|
||||
yield
|
||||
from google.adk.flows.llm_flows import functions
|
||||
|
||||
# Shutdown all pools
|
||||
for pool in functions._TOOL_THREAD_POOLS.values():
|
||||
pool.shutdown(wait=False)
|
||||
functions._TOOL_THREAD_POOLS.clear()
|
||||
|
||||
|
||||
class TestIsSyncTool:
|
||||
"""Tests for the _is_sync_tool helper function."""
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -509,60 +509,61 @@ async def test_append_event_to_stale_session():
|
||||
service_type=SessionServiceType.DATABASE
|
||||
)
|
||||
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
current_time = datetime.now().astimezone(timezone.utc).timestamp()
|
||||
async with session_service:
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
current_time = datetime.now().astimezone(timezone.utc).timestamp()
|
||||
|
||||
original_session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event1 = Event(
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
timestamp=current_time + 1,
|
||||
actions=EventActions(state_delta={'sk1': 'v1'}),
|
||||
)
|
||||
await session_service.append_event(original_session, event1)
|
||||
original_session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event1 = Event(
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
timestamp=current_time + 1,
|
||||
actions=EventActions(state_delta={'sk1': 'v1'}),
|
||||
)
|
||||
await session_service.append_event(original_session, event1)
|
||||
|
||||
updated_session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=original_session.id
|
||||
)
|
||||
event2 = Event(
|
||||
invocation_id='inv2',
|
||||
author='user',
|
||||
timestamp=current_time + 2,
|
||||
actions=EventActions(state_delta={'sk2': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(updated_session, event2)
|
||||
updated_session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=original_session.id
|
||||
)
|
||||
event2 = Event(
|
||||
invocation_id='inv2',
|
||||
author='user',
|
||||
timestamp=current_time + 2,
|
||||
actions=EventActions(state_delta={'sk2': 'v2'}),
|
||||
)
|
||||
await session_service.append_event(updated_session, event2)
|
||||
|
||||
# original_session is now stale
|
||||
assert original_session.last_update_time < updated_session.last_update_time
|
||||
assert len(original_session.events) == 1
|
||||
assert 'sk2' not in original_session.state
|
||||
# original_session is now stale
|
||||
assert original_session.last_update_time < updated_session.last_update_time
|
||||
assert len(original_session.events) == 1
|
||||
assert 'sk2' not in original_session.state
|
||||
|
||||
# Appending another event to stale original_session
|
||||
event3 = Event(
|
||||
invocation_id='inv3',
|
||||
author='user',
|
||||
timestamp=current_time + 3,
|
||||
actions=EventActions(state_delta={'sk3': 'v3'}),
|
||||
)
|
||||
await session_service.append_event(original_session, event3)
|
||||
# Appending another event to stale original_session
|
||||
event3 = Event(
|
||||
invocation_id='inv3',
|
||||
author='user',
|
||||
timestamp=current_time + 3,
|
||||
actions=EventActions(state_delta={'sk3': 'v3'}),
|
||||
)
|
||||
await session_service.append_event(original_session, event3)
|
||||
|
||||
# If we fetch session from DB, it should contain all 3 events and all state
|
||||
# changes.
|
||||
session_final = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=original_session.id
|
||||
)
|
||||
assert len(session_final.events) == 3
|
||||
assert session_final.state.get('sk1') == 'v1'
|
||||
assert session_final.state.get('sk2') == 'v2'
|
||||
assert session_final.state.get('sk3') == 'v3'
|
||||
assert [e.invocation_id for e in session_final.events] == [
|
||||
'inv1',
|
||||
'inv2',
|
||||
'inv3',
|
||||
]
|
||||
# If we fetch session from DB, it should contain all 3 events and all state
|
||||
# changes.
|
||||
session_final = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=original_session.id
|
||||
)
|
||||
assert len(session_final.events) == 3
|
||||
assert session_final.state.get('sk1') == 'v1'
|
||||
assert session_final.state.get('sk2') == 'v2'
|
||||
assert session_final.state.get('sk3') == 'v3'
|
||||
assert [e.invocation_id for e in session_final.events] == [
|
||||
'inv1',
|
||||
'inv2',
|
||||
'inv3',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -88,6 +88,22 @@ def test_live_streaming_function_call_single():
|
||||
# Create a custom runner class that collects all events
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -108,20 +124,9 @@ def test_live_streaming_function_call_single():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
# Return whatever we collected so far
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -203,6 +208,22 @@ def test_live_streaming_function_call_multiple():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -222,19 +243,9 @@ def test_live_streaming_function_call_multiple():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -309,6 +320,22 @@ def test_live_streaming_function_call_parallel():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -328,19 +355,9 @@ def test_live_streaming_function_call_parallel():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -409,6 +426,22 @@ def test_live_streaming_function_call_with_error():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -428,19 +461,9 @@ def test_live_streaming_function_call_with_error():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -500,6 +523,22 @@ def test_live_streaming_function_call_sync_tool():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -519,19 +558,9 @@ def test_live_streaming_function_call_sync_tool():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -600,6 +629,22 @@ def test_live_streaming_simple_streaming_tool():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -619,19 +664,9 @@ def test_live_streaming_simple_streaming_tool():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -712,6 +747,22 @@ def test_live_streaming_video_streaming_tool():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -731,19 +782,9 @@ def test_live_streaming_video_streaming_tool():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -828,6 +869,22 @@ def test_live_streaming_stop_streaming_tool():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -847,19 +904,9 @@ def test_live_streaming_stop_streaming_tool():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -945,6 +992,22 @@ def test_live_streaming_multiple_streaming_tools():
|
||||
# Use the custom runner
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -964,19 +1027,9 @@ def test_live_streaming_multiple_streaming_tools():
|
||||
if len(collected_responses) >= 3:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
# Create a new event loop to avoid nested event loop issues
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -1055,6 +1108,22 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
|
||||
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -1074,18 +1143,9 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
|
||||
if len(collected_responses) >= 5:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
return collected_responses
|
||||
|
||||
@@ -1141,6 +1201,22 @@ def test_live_streaming_text_content_persisted_in_session():
|
||||
|
||||
class CustomTestRunner(testing_utils.InMemoryRunner):
|
||||
|
||||
def _run_with_loop(self, coro):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
old_loop = None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def run_live_and_get_session(
|
||||
self,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
@@ -1159,24 +1235,15 @@ def test_live_streaming_text_content_persisted_in_session():
|
||||
if len(collected_responses) >= 1:
|
||||
return
|
||||
|
||||
try:
|
||||
session = self.session
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(consume_responses(session), timeout=5.0)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
self._run_with_loop(
|
||||
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
|
||||
)
|
||||
|
||||
# Get the updated session
|
||||
updated_session = self.runner.session_service.get_session_sync(
|
||||
app_name=self.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
user_id=self.session.user_id,
|
||||
session_id=self.session.id,
|
||||
)
|
||||
return collected_responses, updated_session
|
||||
|
||||
|
||||
Reference in New Issue
Block a user