fix: Improve logic for checking if a MCP session is disconnected

Currently logic to check for a disconnected session only checks for certain headers but doesn't detect all cases, leading to situations where it tries to connect to a session that is down. This adds logic so that we ping the server to check if it is disconnected. Fixes https://github.com/google/adk-python/issues/3321.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 832460068
This commit is contained in:
Kathy Wu
2025-11-14 14:33:09 -08:00
committed by Copybara-Service
parent 29fea7ec1f
commit a754c96d3c
3 changed files with 73 additions and 13 deletions
@@ -37,6 +37,7 @@ try:
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import EmptyResult
except ImportError as e:
if sys.version_info < (3, 10):
@@ -241,7 +242,7 @@ class MCPSessionManager:
return base_headers
def _is_session_disconnected(self, session: ClientSession) -> bool:
async def _is_session_disconnected(self, session: ClientSession) -> bool:
"""Checks if a session is disconnected or closed.
Args:
@@ -250,7 +251,24 @@ class MCPSessionManager:
Returns:
True if the session is disconnected, False otherwise.
"""
return session._read_stream._closed or session._write_stream._closed
if session._read_stream._closed or session._write_stream._closed:
return True
try:
response = await asyncio.wait_for(session.send_ping(), timeout=5.0)
if not isinstance(response, EmptyResult):
logger.info(
'Session ping returns illegal response %s, treating as'
' disconnected',
response,
)
return True
return False
except Exception as e:
logger.info(
'Session ping failed with error %s, treating as disconnected', e
)
return True
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
"""Creates an MCP client based on the connection parameters.
@@ -325,7 +343,7 @@ class MCPSessionManager:
session, exit_stack = self._sessions[session_key]
# Check if the existing session is still connected
if not self._is_session_disconnected(session):
if not await self._is_session_disconnected(session):
# Session is still good, return it
return session
else:
+4 -1
View File
@@ -175,7 +175,10 @@ class McpToolset(BaseToolset):
else None
)
# Get session from session manager
session = await self._mcp_session_manager.create_session(headers=headers)
try:
session = await self._mcp_session_manager.create_session(headers=headers)
except Exception as e:
raise ConnectionError(f"Failed to create MCP session") from e
# Fetch available tools from the MCP server
timeout_in_seconds = (
@@ -54,6 +54,7 @@ except ImportError as e:
# Import real MCP classes
try:
from mcp import StdioServerParameters
from mcp.types import EmptyResult
except ImportError:
# Create a mock if MCP is not available
class StdioServerParameters:
@@ -62,6 +63,9 @@ except ImportError:
self.command = command
self.args = args or []
class EmptyResult:
pass
class MockClientSession:
"""Mock ClientSession for testing."""
@@ -72,6 +76,7 @@ class MockClientSession:
self._read_stream._closed = False
self._write_stream._closed = False
self.initialize = AsyncMock()
self.send_ping = AsyncMock()
class MockAsyncExitStack:
@@ -206,19 +211,52 @@ class TestMCPSessionManager:
}
assert merged == expected
def test_is_session_disconnected(self):
"""Test session disconnection detection."""
@pytest.mark.asyncio
async def test_is_session_disconnected_when_connected(self):
"""Test session disconnection detection when session is connected."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
# Create mock session
session = MockClientSession()
session.send_ping.return_value = EmptyResult()
assert not await manager._is_session_disconnected(session)
session.send_ping.assert_called_once()
# Not disconnected
assert not manager._is_session_disconnected(session)
# Disconnected - read stream closed
@pytest.mark.asyncio
async def test_is_session_disconnected_read_stream_closed(self):
"""Test session disconnection detection when read stream is closed."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
session = MockClientSession()
session.send_ping.return_value = EmptyResult()
session._read_stream._closed = True
assert manager._is_session_disconnected(session)
assert await manager._is_session_disconnected(session)
session.send_ping.assert_not_called()
@pytest.mark.asyncio
async def test_is_session_disconnected_write_stream_closed(self):
"""Test session disconnection detection when write stream is closed."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
session = MockClientSession()
session.send_ping.return_value = EmptyResult()
session._write_stream._closed = True
assert await manager._is_session_disconnected(session)
session.send_ping.assert_not_called()
@pytest.mark.asyncio
async def test_is_session_disconnected_ping_fails(self):
"""Test session disconnection detection when ping fails."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
session = MockClientSession()
session.send_ping.side_effect = Exception("Ping failed")
assert await manager._is_session_disconnected(session)
session.send_ping.assert_called_once()
@pytest.mark.asyncio
async def test_is_session_disconnected_ping_returns_wrong_result(self):
"""Test session disconnection detection when ping returns wrong result."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
session = MockClientSession()
session.send_ping.return_value = "Wrong result"
assert await manager._is_session_disconnected(session)
session.send_ping.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_stdio_new(self):
@@ -271,6 +309,7 @@ class TestMCPSessionManager:
# Session is connected
existing_session._read_stream._closed = False
existing_session._write_stream._closed = False
existing_session.send_ping.return_value = EmptyResult()
session = await manager.create_session()