From a754c96d3c4fd00f9c2cd924fc428b68cc5115fb Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Fri, 14 Nov 2025 14:33:09 -0800 Subject: [PATCH] fix: Improve logic for checking if a MCP session is disconnected Currently logic to check for a disconnected session only checks for certain headers but doesn't detect all cases, leading to situations where it tries to connect to a session that is down. This adds logic so that we ping the server to check if it is disconnected. Fixes https://github.com/google/adk-python/issues/3321. Co-authored-by: Kathy Wu PiperOrigin-RevId: 832460068 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 24 +++++++- src/google/adk/tools/mcp_tool/mcp_toolset.py | 5 +- .../mcp_tool/test_mcp_session_manager.py | 57 ++++++++++++++++--- 3 files changed, 73 insertions(+), 13 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index d95d48f2..af255d94 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -37,6 +37,7 @@ try: from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client + from mcp.types import EmptyResult except ImportError as e: if sys.version_info < (3, 10): @@ -241,7 +242,7 @@ class MCPSessionManager: return base_headers - def _is_session_disconnected(self, session: ClientSession) -> bool: + async def _is_session_disconnected(self, session: ClientSession) -> bool: """Checks if a session is disconnected or closed. Args: @@ -250,7 +251,24 @@ class MCPSessionManager: Returns: True if the session is disconnected, False otherwise. """ - return session._read_stream._closed or session._write_stream._closed + if session._read_stream._closed or session._write_stream._closed: + return True + + try: + response = await asyncio.wait_for(session.send_ping(), timeout=5.0) + if not isinstance(response, EmptyResult): + logger.info( + 'Session ping returns illegal response %s, treating as' + ' disconnected', + response, + ) + return True + return False + except Exception as e: + logger.info( + 'Session ping failed with error %s, treating as disconnected', e + ) + return True def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): """Creates an MCP client based on the connection parameters. @@ -325,7 +343,7 @@ class MCPSessionManager: session, exit_stack = self._sessions[session_key] # Check if the existing session is still connected - if not self._is_session_disconnected(session): + if not await self._is_session_disconnected(session): # Session is still good, return it return session else: diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index daa88f90..429d63ab 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -175,7 +175,10 @@ class McpToolset(BaseToolset): else None ) # Get session from session manager - session = await self._mcp_session_manager.create_session(headers=headers) + try: + session = await self._mcp_session_manager.create_session(headers=headers) + except Exception as e: + raise ConnectionError(f"Failed to create MCP session") from e # Fetch available tools from the MCP server timeout_in_seconds = ( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 6c001ccf..8eb743eb 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -54,6 +54,7 @@ except ImportError as e: # Import real MCP classes try: from mcp import StdioServerParameters + from mcp.types import EmptyResult except ImportError: # Create a mock if MCP is not available class StdioServerParameters: @@ -62,6 +63,9 @@ except ImportError: self.command = command self.args = args or [] + class EmptyResult: + pass + class MockClientSession: """Mock ClientSession for testing.""" @@ -72,6 +76,7 @@ class MockClientSession: self._read_stream._closed = False self._write_stream._closed = False self.initialize = AsyncMock() + self.send_ping = AsyncMock() class MockAsyncExitStack: @@ -206,19 +211,52 @@ class TestMCPSessionManager: } assert merged == expected - def test_is_session_disconnected(self): - """Test session disconnection detection.""" + @pytest.mark.asyncio + async def test_is_session_disconnected_when_connected(self): + """Test session disconnection detection when session is connected.""" manager = MCPSessionManager(self.mock_stdio_connection_params) - - # Create mock session session = MockClientSession() + session.send_ping.return_value = EmptyResult() + assert not await manager._is_session_disconnected(session) + session.send_ping.assert_called_once() - # Not disconnected - assert not manager._is_session_disconnected(session) - - # Disconnected - read stream closed + @pytest.mark.asyncio + async def test_is_session_disconnected_read_stream_closed(self): + """Test session disconnection detection when read stream is closed.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + session.send_ping.return_value = EmptyResult() session._read_stream._closed = True - assert manager._is_session_disconnected(session) + assert await manager._is_session_disconnected(session) + session.send_ping.assert_not_called() + + @pytest.mark.asyncio + async def test_is_session_disconnected_write_stream_closed(self): + """Test session disconnection detection when write stream is closed.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + session.send_ping.return_value = EmptyResult() + session._write_stream._closed = True + assert await manager._is_session_disconnected(session) + session.send_ping.assert_not_called() + + @pytest.mark.asyncio + async def test_is_session_disconnected_ping_fails(self): + """Test session disconnection detection when ping fails.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + session.send_ping.side_effect = Exception("Ping failed") + assert await manager._is_session_disconnected(session) + session.send_ping.assert_called_once() + + @pytest.mark.asyncio + async def test_is_session_disconnected_ping_returns_wrong_result(self): + """Test session disconnection detection when ping returns wrong result.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + session.send_ping.return_value = "Wrong result" + assert await manager._is_session_disconnected(session) + session.send_ping.assert_called_once() @pytest.mark.asyncio async def test_create_session_stdio_new(self): @@ -271,6 +309,7 @@ class TestMCPSessionManager: # Session is connected existing_session._read_stream._closed = False existing_session._write_stream._closed = False + existing_session.send_ping.return_value = EmptyResult() session = await manager.create_session()