fix: Fix event loop closed bug in McpSessionManager

Sessions were being erroneously cached and reused across different asyncio event loops, causing "Event loop is closed" in environments with transient loops. This updates the session caching to be loop-aware: before reusing a cached session, check that the stored loop matches the current loop. Also, if session is disconnected and loops do not match, discard the cached entry without calling aclose().

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 868380746
This commit is contained in:
Kathy Wu
2026-02-10 16:50:19 -08:00
committed by Copybara-Service
parent 7110336788
commit 4aa475145f
2 changed files with 260 additions and 34 deletions
@@ -23,6 +23,7 @@ import hashlib
import json
import logging
import sys
import threading
from typing import Any
from typing import Dict
from typing import Optional
@@ -220,11 +221,24 @@ class MCPSessionManager:
self._connection_params = connection_params
self._errlog = errlog
# Session pool: maps session keys to (session, exit_stack) tuples
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
# Session pool: maps session keys to (session, exit_stack, loop) tuples
self._sessions: Dict[
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
] = {}
# Lock to prevent race conditions in session creation
self._session_lock = asyncio.Lock()
# Map of event loops to their respective locks to prevent race conditions
# across different event loops in session creation.
self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
self._lock_map_lock = threading.Lock()
@property
def _session_lock(self) -> asyncio.Lock:
"""Returns an asyncio.Lock bound to the current event loop."""
current_loop = asyncio.get_running_loop()
with self._lock_map_lock:
if current_loop not in self._session_lock_map:
self._session_lock_map[current_loop] = asyncio.Lock()
return self._session_lock_map[current_loop]
def _generate_session_key(
self, merged_headers: Optional[Dict[str, str]] = None
@@ -293,6 +307,62 @@ class MCPSessionManager:
"""
return session._read_stream._closed or session._write_stream._closed
async def _cleanup_session(
self,
session_key: str,
exit_stack: AsyncExitStack,
stored_loop: asyncio.AbstractEventLoop,
):
"""Cleans up a session, handling different event loops safely.
Args:
session_key: The session key to clean up.
exit_stack: The AsyncExitStack managing the session resources.
stored_loop: The event loop on which the session was created.
"""
current_loop = asyncio.get_running_loop()
try:
if stored_loop is current_loop:
await exit_stack.aclose()
elif stored_loop.is_closed():
logger.warning(
f'Error cleaning up session {session_key}: original event loop'
' is closed, resources may be leaked.'
)
else:
# The old loop is still running in another thread;
# schedule cleanup on it.
logger.info(
f'Scheduling cleanup of session {session_key} on its original'
' event loop.'
)
future = asyncio.run_coroutine_threadsafe(
exit_stack.aclose(), stored_loop
)
# Attach a callback so errors don't go unnoticed
def cleanup_done(f: asyncio.Future):
try:
if f.exception():
logger.warning(
f'Error cleaning up session {session_key} on original'
f' loop: {f.exception()}'
)
except Exception as e:
logger.warning(
f'Failed to check cleanup status for {session_key}: {e}'
)
future.add_done_callback(cleanup_done)
except Exception as e:
logger.warning(
f'Error during session cleanup for {session_key}: {e}',
exc_info=True,
)
finally:
if session_key in self._sessions:
del self._sessions[session_key]
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
"""Creates an MCP client based on the connection parameters.
@@ -364,21 +434,22 @@ class MCPSessionManager:
async with self._session_lock:
# Check if we have an existing session
if session_key in self._sessions:
session, exit_stack = self._sessions[session_key]
session, exit_stack, stored_loop = self._sessions[session_key]
# Check if the existing session is still connected
if not self._is_session_disconnected(session):
# Check if the existing session is still connected and bound to the current loop
current_loop = asyncio.get_running_loop()
if stored_loop is current_loop and not self._is_session_disconnected(
session
):
# Session is still good, return it
return session
else:
# Session is disconnected, clean it up
logger.info('Cleaning up disconnected session: %s', session_key)
try:
await exit_stack.aclose()
except Exception as e:
logger.warning('Error during disconnected session cleanup: %s', e)
finally:
del self._sessions[session_key]
# Session is disconnected or from a different loop, clean it up
logger.info(
'Cleaning up session (disconnected or different loop): %s',
session_key,
)
await self._cleanup_session(session_key, exit_stack, stored_loop)
# Create a new session (either first time or replacing disconnected one)
exit_stack = AsyncExitStack()
@@ -409,8 +480,12 @@ class MCPSessionManager:
timeout=timeout_in_seconds,
)
# Store session and exit stack in the pool
self._sessions[session_key] = (session, exit_stack)
# Store session, exit stack, and loop in the pool
self._sessions[session_key] = (
session,
exit_stack,
asyncio.get_running_loop(),
)
logger.debug('Created new session: %s', session_key)
return session
@@ -429,17 +504,8 @@ class MCPSessionManager:
"""Closes all sessions and cleans up resources."""
async with self._session_lock:
for session_key in list(self._sessions.keys()):
_, exit_stack = self._sessions[session_key]
try:
await exit_stack.aclose()
except Exception as e:
# Log the error but don't re-raise to avoid blocking shutdown
logger.warning(
f'Error during MCP session cleanup for {session_key}',
exc_info=True,
)
finally:
del self._sessions[session_key]
_, exit_stack, stored_loop = self._sessions[session_key]
await self._cleanup_session(session_key, exit_stack, stored_loop)
SseServerParams = SseConnectionParams
@@ -18,10 +18,12 @@ import hashlib
from io import StringIO
import json
import sys
from unittest.mock import ANY
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
from google.adk.platform import thread as platform_thread
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
@@ -298,6 +300,10 @@ class TestMCPSessionManager:
assert session == mock_session
assert len(manager._sessions) == 1
assert "stdio_session" in manager._sessions
session_data = manager._sessions["stdio_session"]
assert len(session_data) == 3
assert session_data[0] == mock_session
assert session_data[2] == asyncio.get_running_loop()
# Verify SessionContext was created
mock_session_context_class.assert_called_once()
@@ -312,7 +318,11 @@ class TestMCPSessionManager:
# Create mock existing session
existing_session = MockClientSession()
existing_exit_stack = MockAsyncExitStack()
manager._sessions["stdio_session"] = (existing_session, existing_exit_stack)
manager._sessions["stdio_session"] = (
existing_session,
existing_exit_stack,
asyncio.get_running_loop(),
)
# Session is connected
existing_session._read_stream._closed = False
@@ -377,8 +387,16 @@ class TestMCPSessionManager:
session2 = MockClientSession()
exit_stack2 = MockAsyncExitStack()
manager._sessions["session1"] = (session1, exit_stack1)
manager._sessions["session2"] = (session2, exit_stack2)
manager._sessions["session1"] = (
session1,
exit_stack1,
asyncio.get_running_loop(),
)
manager._sessions["session2"] = (
session2,
exit_stack2,
asyncio.get_running_loop(),
)
await manager.close()
@@ -401,8 +419,16 @@ class TestMCPSessionManager:
session2 = MockClientSession()
exit_stack2 = MockAsyncExitStack()
manager._sessions["session1"] = (session1, exit_stack1)
manager._sessions["session2"] = (session2, exit_stack2)
manager._sessions["session1"] = (
session1,
exit_stack1,
asyncio.get_running_loop(),
)
manager._sessions["session2"] = (
session2,
exit_stack2,
asyncio.get_running_loop(),
)
# Should not raise exception
await manager.close()
@@ -414,7 +440,7 @@ class TestMCPSessionManager:
# Error should be logged via logger.warning
mock_logger.warning.assert_called_once()
args, kwargs = mock_logger.warning.call_args
assert "Error during MCP session cleanup for session1" in args[0]
assert "Error during session cleanup for session1: Close error 1" in args[0]
assert kwargs.get("exc_info")
@pytest.mark.asyncio
@@ -447,6 +473,140 @@ class TestMCPSessionManager:
# Verify session was closed
assert not manager._sessions
@pytest.mark.asyncio
async def test_session_lock_different_loops(self):
"""Verify that _session_lock returns different locks for different loops."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
# Access in current loop
lock1 = manager._session_lock
assert isinstance(lock1, asyncio.Lock)
# Access in a different loop (in a separate thread)
lock_container = []
def run_in_thread():
loop2 = asyncio.new_event_loop()
asyncio.set_event_loop(loop2)
try:
async def get_lock():
return manager._session_lock
lock_container.append(loop2.run_until_complete(get_lock()))
finally:
loop2.close()
thread = platform_thread.create_thread(target=run_in_thread)
thread.start()
thread.join()
assert lock_container
lock2 = lock_container[0]
assert isinstance(lock2, asyncio.Lock)
assert lock1 is not lock2
@pytest.mark.asyncio
async def test_cleanup_session_cross_loop(self):
"""Verify that _cleanup_session uses run_coroutine_threadsafe for different loops."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
mock_exit_stack = MockAsyncExitStack()
# Create a dummy loop that is "running" in another thread
loop2 = asyncio.new_event_loop()
try:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.asyncio.run_coroutine_threadsafe"
) as mock_run_threadsafe:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.logger"
) as mock_logger:
# We need to mock the return value of run_coroutine_threadsafe to be a future
mock_future = Mock()
mock_run_threadsafe.return_value = mock_future
await manager._cleanup_session("test_session", mock_exit_stack, loop2)
# Verify run_coroutine_threadsafe was called
# ANY is used because a new coroutine object is created each time
mock_run_threadsafe.assert_called_once_with(ANY, loop2)
mock_logger.info.assert_any_call(
"Scheduling cleanup of session test_session on its original"
" event loop."
)
mock_future.add_done_callback.assert_called_once()
finally:
loop2.close()
@pytest.mark.asyncio
async def test_create_session_cleans_up_without_aclose_if_loop_is_different(
self,
):
"""Verify that sessions from different loops are cleaned up without calling aclose()."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
# 1. Simulate a session created in a "different" loop
mock_session = MockClientSession()
mock_exit_stack = MockAsyncExitStack()
# Use a dummy object as a different loop
different_loop = Mock(spec=asyncio.AbstractEventLoop)
manager._sessions["stdio_session"] = (
mock_session,
mock_exit_stack,
different_loop,
)
# 2. Mock creation of a new session
# We need to mock create_client, wait_for, and SessionContext
with patch.object(manager, "_create_client") as mock_create_client:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for"
) as mock_wait_for:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
) as mock_session_context_class:
# Setup mocks for new session creation
mock_create_client.return_value = AsyncMock()
new_session = MockClientSession()
mock_wait_for.return_value = new_session
mock_session_context_class.return_value = AsyncMock()
# 3. Call create_session
session = await manager.create_session()
# 4. Verify results
assert session == new_session
assert len(manager._sessions) == 1
# Verify that old exit_stack.aclose was NOT called since loop was different
mock_exit_stack.aclose.assert_not_called()
@pytest.mark.asyncio
async def test_close_skips_aclose_for_different_loop_sessions(self):
"""Verify that close() skips aclose() for sessions from different loops."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
# Add one session from same loop and one from different loop
current_loop = asyncio.get_running_loop()
different_loop = Mock(spec=asyncio.AbstractEventLoop)
session1 = MockClientSession()
exit_stack1 = MockAsyncExitStack()
manager._sessions["session1"] = (session1, exit_stack1, current_loop)
session2 = MockClientSession()
exit_stack2 = MockAsyncExitStack()
manager._sessions["session2"] = (session2, exit_stack2, different_loop)
await manager.close()
# exit_stack1 should be closed, exit_stack2 should be skipped
exit_stack1.aclose.assert_called_once()
exit_stack2.aclose.assert_not_called()
assert len(manager._sessions) == 0
@pytest.mark.asyncio
async def test_retry_on_errors_decorator():