You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Improve logic for checking if a MCP session is disconnected
Currently logic to check for a disconnected session only checks for certain headers but doesn't detect all cases, leading to situations where it tries to connect to a session that is down. This adds logic so that we ping the server to check if it is disconnected. Fixes https://github.com/google/adk-python/issues/3321. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 832460068
This commit is contained in:
committed by
Copybara-Service
parent
29fea7ec1f
commit
a754c96d3c
@@ -37,6 +37,7 @@ try:
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import EmptyResult
|
||||
except ImportError as e:
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
@@ -241,7 +242,7 @@ class MCPSessionManager:
|
||||
|
||||
return base_headers
|
||||
|
||||
def _is_session_disconnected(self, session: ClientSession) -> bool:
|
||||
async def _is_session_disconnected(self, session: ClientSession) -> bool:
|
||||
"""Checks if a session is disconnected or closed.
|
||||
|
||||
Args:
|
||||
@@ -250,7 +251,24 @@ class MCPSessionManager:
|
||||
Returns:
|
||||
True if the session is disconnected, False otherwise.
|
||||
"""
|
||||
return session._read_stream._closed or session._write_stream._closed
|
||||
if session._read_stream._closed or session._write_stream._closed:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(session.send_ping(), timeout=5.0)
|
||||
if not isinstance(response, EmptyResult):
|
||||
logger.info(
|
||||
'Session ping returns illegal response %s, treating as'
|
||||
' disconnected',
|
||||
response,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
'Session ping failed with error %s, treating as disconnected', e
|
||||
)
|
||||
return True
|
||||
|
||||
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
|
||||
"""Creates an MCP client based on the connection parameters.
|
||||
@@ -325,7 +343,7 @@ class MCPSessionManager:
|
||||
session, exit_stack = self._sessions[session_key]
|
||||
|
||||
# Check if the existing session is still connected
|
||||
if not self._is_session_disconnected(session):
|
||||
if not await self._is_session_disconnected(session):
|
||||
# Session is still good, return it
|
||||
return session
|
||||
else:
|
||||
|
||||
@@ -175,7 +175,10 @@ class McpToolset(BaseToolset):
|
||||
else None
|
||||
)
|
||||
# Get session from session manager
|
||||
session = await self._mcp_session_manager.create_session(headers=headers)
|
||||
try:
|
||||
session = await self._mcp_session_manager.create_session(headers=headers)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to create MCP session") from e
|
||||
|
||||
# Fetch available tools from the MCP server
|
||||
timeout_in_seconds = (
|
||||
|
||||
@@ -54,6 +54,7 @@ except ImportError as e:
|
||||
# Import real MCP classes
|
||||
try:
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import EmptyResult
|
||||
except ImportError:
|
||||
# Create a mock if MCP is not available
|
||||
class StdioServerParameters:
|
||||
@@ -62,6 +63,9 @@ except ImportError:
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
|
||||
class EmptyResult:
|
||||
pass
|
||||
|
||||
|
||||
class MockClientSession:
|
||||
"""Mock ClientSession for testing."""
|
||||
@@ -72,6 +76,7 @@ class MockClientSession:
|
||||
self._read_stream._closed = False
|
||||
self._write_stream._closed = False
|
||||
self.initialize = AsyncMock()
|
||||
self.send_ping = AsyncMock()
|
||||
|
||||
|
||||
class MockAsyncExitStack:
|
||||
@@ -206,19 +211,52 @@ class TestMCPSessionManager:
|
||||
}
|
||||
assert merged == expected
|
||||
|
||||
def test_is_session_disconnected(self):
|
||||
"""Test session disconnection detection."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_disconnected_when_connected(self):
|
||||
"""Test session disconnection detection when session is connected."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
|
||||
# Create mock session
|
||||
session = MockClientSession()
|
||||
session.send_ping.return_value = EmptyResult()
|
||||
assert not await manager._is_session_disconnected(session)
|
||||
session.send_ping.assert_called_once()
|
||||
|
||||
# Not disconnected
|
||||
assert not manager._is_session_disconnected(session)
|
||||
|
||||
# Disconnected - read stream closed
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_disconnected_read_stream_closed(self):
|
||||
"""Test session disconnection detection when read stream is closed."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
session = MockClientSession()
|
||||
session.send_ping.return_value = EmptyResult()
|
||||
session._read_stream._closed = True
|
||||
assert manager._is_session_disconnected(session)
|
||||
assert await manager._is_session_disconnected(session)
|
||||
session.send_ping.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_disconnected_write_stream_closed(self):
|
||||
"""Test session disconnection detection when write stream is closed."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
session = MockClientSession()
|
||||
session.send_ping.return_value = EmptyResult()
|
||||
session._write_stream._closed = True
|
||||
assert await manager._is_session_disconnected(session)
|
||||
session.send_ping.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_disconnected_ping_fails(self):
|
||||
"""Test session disconnection detection when ping fails."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
session = MockClientSession()
|
||||
session.send_ping.side_effect = Exception("Ping failed")
|
||||
assert await manager._is_session_disconnected(session)
|
||||
session.send_ping.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_disconnected_ping_returns_wrong_result(self):
|
||||
"""Test session disconnection detection when ping returns wrong result."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
session = MockClientSession()
|
||||
session.send_ping.return_value = "Wrong result"
|
||||
assert await manager._is_session_disconnected(session)
|
||||
session.send_ping.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_stdio_new(self):
|
||||
@@ -271,6 +309,7 @@ class TestMCPSessionManager:
|
||||
# Session is connected
|
||||
existing_session._read_stream._closed = False
|
||||
existing_session._write_stream._closed = False
|
||||
existing_session.send_ping.return_value = EmptyResult()
|
||||
|
||||
session = await manager.create_session()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user