feat: Improve asyncio loop handling and test cleanup

This CL enhances asyncio event loop management and test isolation.

-   **BigQuery Analytics Plugin:** Ensure the asyncio event loop is consistently closed within the BigQuery analytics plugin. This prevents potential resource leaks. Add checks to handle potential deadlocks in Python 3.13+ when creating loops during interpreter shutdown.
-   **Test Thread Pool Cleanup:** Introduce a pytest fixture (`cleanup_thread_pools`) to automatically shut down and clear all tool-related thread pools after each test run in `test_functions_thread_pool.py`. This improves test isolation and prevents order-dependent test failures.
-   **Streaming Test Loop Restoration:** Refactor event loop handling in `test_streaming.py`. A new `_run_with_loop` method is introduced in the custom test runners to create a temporary event loop for each test execution, run the coroutine, and crucially, restore the original event loop afterwards. This prevents tests from interfering with each other's loop state.
-   **Resource Closure:** Ensure services are closed properly in tests by adding `await service.close()` in `test_service_factory.py` and using `async with session_service` in `test_session_service.py`.

PiperOrigin-RevId: 863305565
This commit is contained in:
Google Team Member
2026-01-30 10:46:37 -08:00
committed by Copybara-Service
parent 585ebfdac7
commit 00aba2d884
6 changed files with 2334 additions and 2241 deletions
@@ -1695,10 +1695,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
except Exception as e:
logger.error("Rescue flush failed: %s", e)
# In Python 3.13+, creating a new event loop during interpreter shutdown
# (inside atexit) can cause deadlocks if the threading module is already
# shutting down. We attempt to run only if safe.
try:
# Check if we can safely create a loop
loop = asyncio.new_event_loop()
loop.run_until_complete(rescue_flush())
loop.close()
try:
loop.run_until_complete(rescue_flush())
finally:
loop.close()
except Exception as e:
logger.error("Failed to run rescue loop: %s", e)
except ReferenceError:
@@ -168,7 +168,8 @@ async def test_create_session_service_respects_app_name_mapping(
assert (agent_dir / ".adk" / "session.db").exists()
def test_create_session_service_fallbacks_to_database(
@pytest.mark.asyncio
async def test_create_session_service_fallbacks_to_database(
tmp_path: Path, monkeypatch
):
registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True)
@@ -189,6 +190,7 @@ def test_create_session_service_fallbacks_to_database(
agents_dir=str(tmp_path),
echo=True,
)
await service.close()
def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch):
@@ -32,6 +32,17 @@ import pytest
from ... import testing_utils
@pytest.fixture(autouse=True)
def cleanup_thread_pools():
yield
from google.adk.flows.llm_flows import functions
# Shutdown all pools
for pool in functions._TOOL_THREAD_POOLS.values():
pool.shutdown(wait=False)
functions._TOOL_THREAD_POOLS.clear()
class TestIsSyncTool:
"""Tests for the _is_sync_tool helper function."""
File diff suppressed because it is too large Load Diff
@@ -509,60 +509,61 @@ async def test_append_event_to_stale_session():
service_type=SessionServiceType.DATABASE
)
app_name = 'my_app'
user_id = 'user'
current_time = datetime.now().astimezone(timezone.utc).timestamp()
async with session_service:
app_name = 'my_app'
user_id = 'user'
current_time = datetime.now().astimezone(timezone.utc).timestamp()
original_session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event1 = Event(
invocation_id='inv1',
author='user',
timestamp=current_time + 1,
actions=EventActions(state_delta={'sk1': 'v1'}),
)
await session_service.append_event(original_session, event1)
original_session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
event1 = Event(
invocation_id='inv1',
author='user',
timestamp=current_time + 1,
actions=EventActions(state_delta={'sk1': 'v1'}),
)
await session_service.append_event(original_session, event1)
updated_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=original_session.id
)
event2 = Event(
invocation_id='inv2',
author='user',
timestamp=current_time + 2,
actions=EventActions(state_delta={'sk2': 'v2'}),
)
await session_service.append_event(updated_session, event2)
updated_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=original_session.id
)
event2 = Event(
invocation_id='inv2',
author='user',
timestamp=current_time + 2,
actions=EventActions(state_delta={'sk2': 'v2'}),
)
await session_service.append_event(updated_session, event2)
# original_session is now stale
assert original_session.last_update_time < updated_session.last_update_time
assert len(original_session.events) == 1
assert 'sk2' not in original_session.state
# original_session is now stale
assert original_session.last_update_time < updated_session.last_update_time
assert len(original_session.events) == 1
assert 'sk2' not in original_session.state
# Appending another event to stale original_session
event3 = Event(
invocation_id='inv3',
author='user',
timestamp=current_time + 3,
actions=EventActions(state_delta={'sk3': 'v3'}),
)
await session_service.append_event(original_session, event3)
# Appending another event to stale original_session
event3 = Event(
invocation_id='inv3',
author='user',
timestamp=current_time + 3,
actions=EventActions(state_delta={'sk3': 'v3'}),
)
await session_service.append_event(original_session, event3)
# If we fetch session from DB, it should contain all 3 events and all state
# changes.
session_final = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=original_session.id
)
assert len(session_final.events) == 3
assert session_final.state.get('sk1') == 'v1'
assert session_final.state.get('sk2') == 'v2'
assert session_final.state.get('sk3') == 'v3'
assert [e.invocation_id for e in session_final.events] == [
'inv1',
'inv2',
'inv3',
]
# If we fetch session from DB, it should contain all 3 events and all state
# changes.
session_final = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=original_session.id
)
assert len(session_final.events) == 3
assert session_final.state.get('sk1') == 'v1'
assert session_final.state.get('sk2') == 'v2'
assert session_final.state.get('sk3') == 'v3'
assert [e.invocation_id for e in session_final.events] == [
'inv1',
'inv2',
'inv3',
]
@pytest.mark.asyncio
+211 -144
View File
@@ -88,6 +88,22 @@ def test_live_streaming_function_call_single():
# Create a custom runner class that collects all events
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -108,20 +124,9 @@ def test_live_streaming_function_call_single():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
# Return whatever we collected so far
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -203,6 +208,22 @@ def test_live_streaming_function_call_multiple():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -222,19 +243,9 @@ def test_live_streaming_function_call_multiple():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -309,6 +320,22 @@ def test_live_streaming_function_call_parallel():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -328,19 +355,9 @@ def test_live_streaming_function_call_parallel():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -409,6 +426,22 @@ def test_live_streaming_function_call_with_error():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -428,19 +461,9 @@ def test_live_streaming_function_call_with_error():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -500,6 +523,22 @@ def test_live_streaming_function_call_sync_tool():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -519,19 +558,9 @@ def test_live_streaming_function_call_sync_tool():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -600,6 +629,22 @@ def test_live_streaming_simple_streaming_tool():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -619,19 +664,9 @@ def test_live_streaming_simple_streaming_tool():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -712,6 +747,22 @@ def test_live_streaming_video_streaming_tool():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -731,19 +782,9 @@ def test_live_streaming_video_streaming_tool():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -828,6 +869,22 @@ def test_live_streaming_stop_streaming_tool():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -847,19 +904,9 @@ def test_live_streaming_stop_streaming_tool():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -945,6 +992,22 @@ def test_live_streaming_multiple_streaming_tools():
# Use the custom runner
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -964,19 +1027,9 @@ def test_live_streaming_multiple_streaming_tools():
if len(collected_responses) >= 3:
return
try:
session = self.session
# Create a new event loop to avoid nested event loop issues
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -1055,6 +1108,22 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live(
self,
live_request_queue: LiveRequestQueue,
@@ -1074,18 +1143,9 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription():
if len(collected_responses) >= 5:
return
try:
session = self.session
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
return collected_responses
@@ -1141,6 +1201,22 @@ def test_live_streaming_text_content_persisted_in_session():
class CustomTestRunner(testing_utils.InMemoryRunner):
def _run_with_loop(self, coro):
try:
old_loop = asyncio.get_event_loop()
except RuntimeError:
old_loop = None
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
finally:
loop.close()
asyncio.set_event_loop(old_loop)
def run_live_and_get_session(
self,
live_request_queue: LiveRequestQueue,
@@ -1159,24 +1235,15 @@ def test_live_streaming_text_content_persisted_in_session():
if len(collected_responses) >= 1:
return
try:
session = self.session
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
asyncio.wait_for(consume_responses(session), timeout=5.0)
)
finally:
loop.close()
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
self._run_with_loop(
asyncio.wait_for(consume_responses(self.session), timeout=5.0)
)
# Get the updated session
updated_session = self.runner.session_service.get_session_sync(
app_name=self.app_name,
user_id=session.user_id,
session_id=session.id,
user_id=self.session.user_id,
session_id=self.session.id,
)
return collected_responses, updated_session