diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index a9bfcfbc..0e9b9386 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -23,6 +23,7 @@ import hashlib import json import logging import sys +import threading from typing import Any from typing import Dict from typing import Optional @@ -220,11 +221,24 @@ class MCPSessionManager: self._connection_params = connection_params self._errlog = errlog - # Session pool: maps session keys to (session, exit_stack) tuples - self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Session pool: maps session keys to (session, exit_stack, loop) tuples + self._sessions: Dict[ + str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop] + ] = {} - # Lock to prevent race conditions in session creation - self._session_lock = asyncio.Lock() + # Map of event loops to their respective locks to prevent race conditions + # across different event loops in session creation. + self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {} + self._lock_map_lock = threading.Lock() + + @property + def _session_lock(self) -> asyncio.Lock: + """Returns an asyncio.Lock bound to the current event loop.""" + current_loop = asyncio.get_running_loop() + with self._lock_map_lock: + if current_loop not in self._session_lock_map: + self._session_lock_map[current_loop] = asyncio.Lock() + return self._session_lock_map[current_loop] def _generate_session_key( self, merged_headers: Optional[Dict[str, str]] = None @@ -293,6 +307,62 @@ class MCPSessionManager: """ return session._read_stream._closed or session._write_stream._closed + async def _cleanup_session( + self, + session_key: str, + exit_stack: AsyncExitStack, + stored_loop: asyncio.AbstractEventLoop, + ): + """Cleans up a session, handling different event loops safely. + + Args: + session_key: The session key to clean up. + exit_stack: The AsyncExitStack managing the session resources. + stored_loop: The event loop on which the session was created. + """ + current_loop = asyncio.get_running_loop() + try: + if stored_loop is current_loop: + await exit_stack.aclose() + elif stored_loop.is_closed(): + logger.warning( + f'Error cleaning up session {session_key}: original event loop' + ' is closed, resources may be leaked.' + ) + else: + # The old loop is still running in another thread; + # schedule cleanup on it. + logger.info( + f'Scheduling cleanup of session {session_key} on its original' + ' event loop.' + ) + future = asyncio.run_coroutine_threadsafe( + exit_stack.aclose(), stored_loop + ) + + # Attach a callback so errors don't go unnoticed + def cleanup_done(f: asyncio.Future): + try: + if f.exception(): + logger.warning( + f'Error cleaning up session {session_key} on original' + f' loop: {f.exception()}' + ) + except Exception as e: + logger.warning( + f'Failed to check cleanup status for {session_key}: {e}' + ) + + future.add_done_callback(cleanup_done) + except Exception as e: + logger.warning( + f'Error during session cleanup for {session_key}: {e}', + exc_info=True, + ) + finally: + if session_key in self._sessions: + del self._sessions[session_key] + def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): """Creates an MCP client based on the connection parameters. @@ -364,21 +434,22 @@ class MCPSessionManager: async with self._session_lock: # Check if we have an existing session if session_key in self._sessions: - session, exit_stack = self._sessions[session_key] + session, exit_stack, stored_loop = self._sessions[session_key] - # Check if the existing session is still connected - if not self._is_session_disconnected(session): + # Check if the existing session is still connected and bound to the current loop + current_loop = asyncio.get_running_loop() + if stored_loop is current_loop and not self._is_session_disconnected( + session + ): # Session is still good, return it return session else: - # Session is disconnected, clean it up - logger.info('Cleaning up disconnected session: %s', session_key) - try: - await exit_stack.aclose() - except Exception as e: - logger.warning('Error during disconnected session cleanup: %s', e) - finally: - del self._sessions[session_key] + # Session is disconnected or from a different loop, clean it up + logger.info( + 'Cleaning up session (disconnected or different loop): %s', + session_key, + ) + await self._cleanup_session(session_key, exit_stack, stored_loop) # Create a new session (either first time or replacing disconnected one) exit_stack = AsyncExitStack() @@ -409,8 +480,12 @@ class MCPSessionManager: timeout=timeout_in_seconds, ) - # Store session and exit stack in the pool - self._sessions[session_key] = (session, exit_stack) + # Store session, exit stack, and loop in the pool + self._sessions[session_key] = ( + session, + exit_stack, + asyncio.get_running_loop(), + ) logger.debug('Created new session: %s', session_key) return session @@ -429,17 +504,8 @@ class MCPSessionManager: """Closes all sessions and cleans up resources.""" async with self._session_lock: for session_key in list(self._sessions.keys()): - _, exit_stack = self._sessions[session_key] - try: - await exit_stack.aclose() - except Exception as e: - # Log the error but don't re-raise to avoid blocking shutdown - logger.warning( - f'Error during MCP session cleanup for {session_key}', - exc_info=True, - ) - finally: - del self._sessions[session_key] + _, exit_stack, stored_loop = self._sessions[session_key] + await self._cleanup_session(session_key, exit_stack, stored_loop) SseServerParams = SseConnectionParams diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index cc2cb487..ba9a7e80 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -18,10 +18,12 @@ import hashlib from io import StringIO import json import sys +from unittest.mock import ANY from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from google.adk.platform import thread as platform_thread from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams @@ -298,6 +300,10 @@ class TestMCPSessionManager: assert session == mock_session assert len(manager._sessions) == 1 assert "stdio_session" in manager._sessions + session_data = manager._sessions["stdio_session"] + assert len(session_data) == 3 + assert session_data[0] == mock_session + assert session_data[2] == asyncio.get_running_loop() # Verify SessionContext was created mock_session_context_class.assert_called_once() @@ -312,7 +318,11 @@ class TestMCPSessionManager: # Create mock existing session existing_session = MockClientSession() existing_exit_stack = MockAsyncExitStack() - manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + manager._sessions["stdio_session"] = ( + existing_session, + existing_exit_stack, + asyncio.get_running_loop(), + ) # Session is connected existing_session._read_stream._closed = False @@ -377,8 +387,16 @@ class TestMCPSessionManager: session2 = MockClientSession() exit_stack2 = MockAsyncExitStack() - manager._sessions["session1"] = (session1, exit_stack1) - manager._sessions["session2"] = (session2, exit_stack2) + manager._sessions["session1"] = ( + session1, + exit_stack1, + asyncio.get_running_loop(), + ) + manager._sessions["session2"] = ( + session2, + exit_stack2, + asyncio.get_running_loop(), + ) await manager.close() @@ -401,8 +419,16 @@ class TestMCPSessionManager: session2 = MockClientSession() exit_stack2 = MockAsyncExitStack() - manager._sessions["session1"] = (session1, exit_stack1) - manager._sessions["session2"] = (session2, exit_stack2) + manager._sessions["session1"] = ( + session1, + exit_stack1, + asyncio.get_running_loop(), + ) + manager._sessions["session2"] = ( + session2, + exit_stack2, + asyncio.get_running_loop(), + ) # Should not raise exception await manager.close() @@ -414,7 +440,7 @@ class TestMCPSessionManager: # Error should be logged via logger.warning mock_logger.warning.assert_called_once() args, kwargs = mock_logger.warning.call_args - assert "Error during MCP session cleanup for session1" in args[0] + assert "Error during session cleanup for session1: Close error 1" in args[0] assert kwargs.get("exc_info") @pytest.mark.asyncio @@ -447,6 +473,140 @@ class TestMCPSessionManager: # Verify session was closed assert not manager._sessions + @pytest.mark.asyncio + async def test_session_lock_different_loops(self): + """Verify that _session_lock returns different locks for different loops.""" + + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Access in current loop + lock1 = manager._session_lock + assert isinstance(lock1, asyncio.Lock) + + # Access in a different loop (in a separate thread) + lock_container = [] + + def run_in_thread(): + loop2 = asyncio.new_event_loop() + asyncio.set_event_loop(loop2) + try: + + async def get_lock(): + return manager._session_lock + + lock_container.append(loop2.run_until_complete(get_lock())) + finally: + loop2.close() + + thread = platform_thread.create_thread(target=run_in_thread) + thread.start() + thread.join() + + assert lock_container + lock2 = lock_container[0] + assert isinstance(lock2, asyncio.Lock) + assert lock1 is not lock2 + + @pytest.mark.asyncio + async def test_cleanup_session_cross_loop(self): + """Verify that _cleanup_session uses run_coroutine_threadsafe for different loops.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + mock_exit_stack = MockAsyncExitStack() + + # Create a dummy loop that is "running" in another thread + loop2 = asyncio.new_event_loop() + try: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.run_coroutine_threadsafe" + ) as mock_run_threadsafe: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.logger" + ) as mock_logger: + # We need to mock the return value of run_coroutine_threadsafe to be a future + mock_future = Mock() + mock_run_threadsafe.return_value = mock_future + + await manager._cleanup_session("test_session", mock_exit_stack, loop2) + + # Verify run_coroutine_threadsafe was called + # ANY is used because a new coroutine object is created each time + mock_run_threadsafe.assert_called_once_with(ANY, loop2) + + mock_logger.info.assert_any_call( + "Scheduling cleanup of session test_session on its original" + " event loop." + ) + mock_future.add_done_callback.assert_called_once() + finally: + loop2.close() + + @pytest.mark.asyncio + async def test_create_session_cleans_up_without_aclose_if_loop_is_different( + self, + ): + """Verify that sessions from different loops are cleaned up without calling aclose().""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # 1. Simulate a session created in a "different" loop + mock_session = MockClientSession() + mock_exit_stack = MockAsyncExitStack() + # Use a dummy object as a different loop + different_loop = Mock(spec=asyncio.AbstractEventLoop) + + manager._sessions["stdio_session"] = ( + mock_session, + mock_exit_stack, + different_loop, + ) + + # 2. Mock creation of a new session + # We need to mock create_client, wait_for, and SessionContext + with patch.object(manager, "_create_client") as mock_create_client: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for" + ) as mock_wait_for: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext" + ) as mock_session_context_class: + # Setup mocks for new session creation + mock_create_client.return_value = AsyncMock() + new_session = MockClientSession() + mock_wait_for.return_value = new_session + mock_session_context_class.return_value = AsyncMock() + + # 3. Call create_session + session = await manager.create_session() + + # 4. Verify results + assert session == new_session + assert len(manager._sessions) == 1 + # Verify that old exit_stack.aclose was NOT called since loop was different + mock_exit_stack.aclose.assert_not_called() + + @pytest.mark.asyncio + async def test_close_skips_aclose_for_different_loop_sessions(self): + """Verify that close() skips aclose() for sessions from different loops.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add one session from same loop and one from different loop + current_loop = asyncio.get_running_loop() + different_loop = Mock(spec=asyncio.AbstractEventLoop) + + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + manager._sessions["session1"] = (session1, exit_stack1, current_loop) + + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + manager._sessions["session2"] = (session2, exit_stack2, different_loop) + + await manager.close() + + # exit_stack1 should be closed, exit_stack2 should be skipped + exit_stack1.aclose.assert_called_once() + exit_stack2.aclose.assert_not_called() + assert len(manager._sessions) == 0 + @pytest.mark.asyncio async def test_retry_on_errors_decorator():