From 4aa475145f196fb35fe97290dd9f928548bc737f Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Tue, 10 Feb 2026 16:50:19 -0800 Subject: [PATCH] fix: Fix event loop closed bug in McpSessionManager Sessions were being erroneously cached and reused across different asyncio event loops, causing "Event loop is closed" in environments with transient loops. This updates the session caching to be loop-aware: before reusing a cached session, check that the stored loop matches the current loop. Also, if session is disconnected and loops do not match, discard the cached entry without calling aclose(). Co-authored-by: Kathy Wu PiperOrigin-RevId: 868380746 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 122 ++++++++++--- .../mcp_tool/test_mcp_session_manager.py | 172 +++++++++++++++++- 2 files changed, 260 insertions(+), 34 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index a9bfcfbc..0e9b9386 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -23,6 +23,7 @@ import hashlib import json import logging import sys +import threading from typing import Any from typing import Dict from typing import Optional @@ -220,11 +221,24 @@ class MCPSessionManager: self._connection_params = connection_params self._errlog = errlog - # Session pool: maps session keys to (session, exit_stack) tuples - self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Session pool: maps session keys to (session, exit_stack, loop) tuples + self._sessions: Dict[ + str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop] + ] = {} - # Lock to prevent race conditions in session creation - self._session_lock = asyncio.Lock() + # Map of event loops to their respective locks to prevent race conditions + # across different event loops in session creation. 
@property
def _session_lock(self) -> asyncio.Lock:
  """Returns the asyncio.Lock associated with the running event loop.

  One lock is kept per event loop in ``self._session_lock_map``; access to
  that map is serialized with ``self._lock_map_lock`` (a ``threading.Lock``)
  because the property may be read from loops running in different threads.

  When a lock has to be created for a loop that is not yet in the map,
  entries whose loop reports ``is_closed()`` are dropped first. A closed loop
  can never be the running loop again, so its lock can never be requested
  again; without this pruning the map would keep one closed loop (and its
  lock) alive for every short-lived loop that ever touched the manager.

  Returns:
    The lock bound to the currently running event loop.

  Raises:
    RuntimeError: If called when no event loop is running (raised by
      ``asyncio.get_running_loop()``).
  """
  current_loop = asyncio.get_running_loop()
  with self._lock_map_lock:
    lock = self._session_lock_map.get(current_loop)
    if lock is None:
      # Collect first, delete afterwards: the dict must not change size
      # while it is being iterated.
      stale_loops = [
          loop for loop in self._session_lock_map if loop.is_closed()
      ]
      for stale_loop in stale_loops:
        del self._session_lock_map[stale_loop]
      lock = asyncio.Lock()
      self._session_lock_map[current_loop] = lock
    return lock


async def _cleanup_session(
    self,
    session_key: str,
    exit_stack: AsyncExitStack,
    stored_loop: asyncio.AbstractEventLoop,
):
  """Cleans up a session, handling different event loops safely.

  Three cases are distinguished by comparing ``stored_loop`` with the
  running loop:

  * same loop: ``exit_stack.aclose()`` is awaited directly;
  * different loop that is closed: nothing can be awaited on it, so a
    warning is logged and the resources are abandoned;
  * different loop that is not closed: ``exit_stack.aclose()`` is scheduled
    on ``stored_loop`` with ``asyncio.run_coroutine_threadsafe`` and a
    callback logs any error; this method does not wait for completion.

  Errors never propagate to the caller; they are logged. In every case the
  entry for ``session_key`` is removed from ``self._sessions`` (a missing
  key is ignored).

  Args:
    session_key: The session key to clean up.
    exit_stack: The AsyncExitStack managing the session resources.
    stored_loop: The event loop on which the session was created.
  """
  current_loop = asyncio.get_running_loop()
  try:
    if stored_loop is current_loop:
      await exit_stack.aclose()
    elif stored_loop.is_closed():
      logger.warning(
          f'Error cleaning up session {session_key}: original event loop'
          ' is closed, resources may be leaked.'
      )
    else:
      # NOTE(review): only is_closed() is checked here. A loop that is open
      # but not currently running would accept the coroutine without ever
      # executing it - confirm whether callers can hit that state.
      logger.info(
          f'Scheduling cleanup of session {session_key} on its original'
          ' event loop.'
      )
      future = asyncio.run_coroutine_threadsafe(
          exit_stack.aclose(), stored_loop
      )

      # Attach a callback so errors don't go unnoticed. The argument is the
      # concurrent.futures.Future returned by run_coroutine_threadsafe, not
      # an asyncio.Future.
      def cleanup_done(f):
        try:
          error = f.exception()
          if error:
            logger.warning(
                f'Error cleaning up session {session_key} on original'
                f' loop: {error}'
            )
        except Exception as e:
          # f.exception() itself raises if the future was cancelled.
          logger.warning(
              f'Failed to check cleanup status for {session_key}: {e}'
          )

      future.add_done_callback(cleanup_done)
  except Exception as e:
    logger.warning(
        f'Error during session cleanup for {session_key}: {e}',
        exc_info=True,
    )
  finally:
    # Always forget the session, even if closing it failed.
    self._sessions.pop(session_key, None)
or replacing disconnected one) exit_stack = AsyncExitStack() @@ -409,8 +480,12 @@ class MCPSessionManager: timeout=timeout_in_seconds, ) - # Store session and exit stack in the pool - self._sessions[session_key] = (session, exit_stack) + # Store session, exit stack, and loop in the pool + self._sessions[session_key] = ( + session, + exit_stack, + asyncio.get_running_loop(), + ) logger.debug('Created new session: %s', session_key) return session @@ -429,17 +504,8 @@ class MCPSessionManager: """Closes all sessions and cleans up resources.""" async with self._session_lock: for session_key in list(self._sessions.keys()): - _, exit_stack = self._sessions[session_key] - try: - await exit_stack.aclose() - except Exception as e: - # Log the error but don't re-raise to avoid blocking shutdown - logger.warning( - f'Error during MCP session cleanup for {session_key}', - exc_info=True, - ) - finally: - del self._sessions[session_key] + _, exit_stack, stored_loop = self._sessions[session_key] + await self._cleanup_session(session_key, exit_stack, stored_loop) SseServerParams = SseConnectionParams diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index cc2cb487..ba9a7e80 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -18,10 +18,12 @@ import hashlib from io import StringIO import json import sys +from unittest.mock import ANY from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from google.adk.platform import thread as platform_thread from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams @@ -298,6 +300,10 @@ class TestMCPSessionManager: assert session == mock_session assert 
len(manager._sessions) == 1 assert "stdio_session" in manager._sessions + session_data = manager._sessions["stdio_session"] + assert len(session_data) == 3 + assert session_data[0] == mock_session + assert session_data[2] == asyncio.get_running_loop() # Verify SessionContext was created mock_session_context_class.assert_called_once() @@ -312,7 +318,11 @@ class TestMCPSessionManager: # Create mock existing session existing_session = MockClientSession() existing_exit_stack = MockAsyncExitStack() - manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + manager._sessions["stdio_session"] = ( + existing_session, + existing_exit_stack, + asyncio.get_running_loop(), + ) # Session is connected existing_session._read_stream._closed = False @@ -377,8 +387,16 @@ class TestMCPSessionManager: session2 = MockClientSession() exit_stack2 = MockAsyncExitStack() - manager._sessions["session1"] = (session1, exit_stack1) - manager._sessions["session2"] = (session2, exit_stack2) + manager._sessions["session1"] = ( + session1, + exit_stack1, + asyncio.get_running_loop(), + ) + manager._sessions["session2"] = ( + session2, + exit_stack2, + asyncio.get_running_loop(), + ) await manager.close() @@ -401,8 +419,16 @@ class TestMCPSessionManager: session2 = MockClientSession() exit_stack2 = MockAsyncExitStack() - manager._sessions["session1"] = (session1, exit_stack1) - manager._sessions["session2"] = (session2, exit_stack2) + manager._sessions["session1"] = ( + session1, + exit_stack1, + asyncio.get_running_loop(), + ) + manager._sessions["session2"] = ( + session2, + exit_stack2, + asyncio.get_running_loop(), + ) # Should not raise exception await manager.close() @@ -414,7 +440,7 @@ class TestMCPSessionManager: # Error should be logged via logger.warning mock_logger.warning.assert_called_once() args, kwargs = mock_logger.warning.call_args - assert "Error during MCP session cleanup for session1" in args[0] + assert "Error during session cleanup for session1: Close 
@pytest.mark.asyncio
async def test_session_lock_different_loops(self):
  """Verify that _session_lock returns different locks for different loops."""

  manager = MCPSessionManager(self.mock_stdio_connection_params)

  # Access in current loop
  lock1 = manager._session_lock
  assert isinstance(lock1, asyncio.Lock)
  # A second access from the same loop must hand back the same lock.
  assert manager._session_lock is lock1

  # Access in a different loop (in a separate thread)
  lock_container = []

  def run_in_thread():
    loop2 = asyncio.new_event_loop()
    asyncio.set_event_loop(loop2)
    try:

      async def get_lock():
        return manager._session_lock

      lock_container.append(loop2.run_until_complete(get_lock()))
    finally:
      loop2.close()

  thread = platform_thread.create_thread(target=run_in_thread)
  thread.start()
  thread.join()

  assert lock_container
  lock2 = lock_container[0]
  assert isinstance(lock2, asyncio.Lock)
  assert lock1 is not lock2


@pytest.mark.asyncio
async def test_cleanup_session_cross_loop(self):
  """Verify that _cleanup_session uses run_coroutine_threadsafe for different loops."""
  manager = MCPSessionManager(self.mock_stdio_connection_params)
  mock_exit_stack = MockAsyncExitStack()

  # A second, open loop stands in for a loop owned by another thread. It is
  # never actually run; run_coroutine_threadsafe is patched below.
  loop2 = asyncio.new_event_loop()
  try:
    with patch(
        "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.run_coroutine_threadsafe"
    ) as mock_run_threadsafe:
      with patch(
          "google.adk.tools.mcp_tool.mcp_session_manager.logger"
      ) as mock_logger:
        # We need to mock the return value of run_coroutine_threadsafe to be a future
        mock_future = Mock()
        mock_run_threadsafe.return_value = mock_future

        await manager._cleanup_session("test_session", mock_exit_stack, loop2)

        # Verify run_coroutine_threadsafe was called
        # ANY is used because a new coroutine object is created each time
        mock_run_threadsafe.assert_called_once_with(ANY, loop2)

        mock_logger.info.assert_any_call(
            "Scheduling cleanup of session test_session on its original"
            " event loop."
        )
        mock_future.add_done_callback.assert_called_once()
  finally:
    loop2.close()


@pytest.mark.asyncio
async def test_create_session_cleans_up_without_aclose_if_loop_is_different(
    self,
):
  """Verify that sessions from different loops are cleaned up without calling aclose()."""
  manager = MCPSessionManager(self.mock_stdio_connection_params)

  # 1. Simulate a session created in a "different" loop
  mock_session = MockClientSession()
  mock_exit_stack = MockAsyncExitStack()
  # Use a dummy object as a different loop. State explicitly that it is
  # closed instead of relying on the truthiness of an unconfigured Mock
  # return value to select the "closed loop" branch.
  different_loop = Mock(spec=asyncio.AbstractEventLoop)
  different_loop.is_closed.return_value = True

  manager._sessions["stdio_session"] = (
      mock_session,
      mock_exit_stack,
      different_loop,
  )

  # 2. Mock creation of a new session
  # We need to mock create_client, wait_for, and SessionContext
  with patch.object(manager, "_create_client") as mock_create_client:
    with patch(
        "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for"
    ) as mock_wait_for:
      with patch(
          "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
      ) as mock_session_context_class:
        # Setup mocks for new session creation
        mock_create_client.return_value = AsyncMock()
        new_session = MockClientSession()
        mock_wait_for.return_value = new_session
        mock_session_context_class.return_value = AsyncMock()

        # 3. Call create_session
        session = await manager.create_session()

        # 4. Verify results
        assert session == new_session
        assert len(manager._sessions) == 1
        # Verify that old exit_stack.aclose was NOT called since loop was different
        mock_exit_stack.aclose.assert_not_called()


@pytest.mark.asyncio
async def test_close_skips_aclose_for_different_loop_sessions(self):
  """Verify that close() skips aclose() for sessions from different loops."""
  manager = MCPSessionManager(self.mock_stdio_connection_params)

  # Add one session from same loop and one from different loop
  current_loop = asyncio.get_running_loop()
  # Explicitly closed, so the "skip aclose" path is chosen on purpose rather
  # than through the truthiness of an unconfigured Mock return value.
  different_loop = Mock(spec=asyncio.AbstractEventLoop)
  different_loop.is_closed.return_value = True

  session1 = MockClientSession()
  exit_stack1 = MockAsyncExitStack()
  manager._sessions["session1"] = (session1, exit_stack1, current_loop)

  session2 = MockClientSession()
  exit_stack2 = MockAsyncExitStack()
  manager._sessions["session2"] = (session2, exit_stack2, different_loop)

  await manager.close()

  # exit_stack1 should be closed, exit_stack2 should be skipped
  exit_stack1.aclose.assert_called_once()
  exit_stack2.aclose.assert_not_called()
  assert len(manager._sessions) == 0