You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: support to customize timeout for mcpstdio connections
Fixes https://github.com/google/adk-python/issues/643 PiperOrigin-RevId: 767788472
This commit is contained in:
committed by
Copybara-Service
parent
fe1de7b103
commit
54367dcc56
@@ -16,8 +16,9 @@
|
||||
import os
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.tools.mcp_tool import StdioConnectionParams
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import StdioServerParameters
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
_allowed_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
@@ -31,13 +32,16 @@ Allowed directory: {_allowed_path}
|
||||
""",
|
||||
tools=[
|
||||
MCPToolset(
|
||||
connection_params=StdioServerParameters(
|
||||
command='npx',
|
||||
args=[
|
||||
'-y', # Arguments for the command
|
||||
'@modelcontextprotocol/server-filesystem',
|
||||
_allowed_path,
|
||||
],
|
||||
connection_params=StdioConnectionParams(
|
||||
server_params=StdioServerParameters(
|
||||
command='npx',
|
||||
args=[
|
||||
'-y', # Arguments for the command
|
||||
'@modelcontextprotocol/server-filesystem',
|
||||
_allowed_path,
|
||||
],
|
||||
),
|
||||
timeout=5,
|
||||
),
|
||||
# don't want agent to do write operation
|
||||
# you can also do below
|
||||
|
||||
@@ -17,6 +17,9 @@ __all__ = []
|
||||
try:
|
||||
from .conversion_utils import adk_to_mcp_tool_type
|
||||
from .conversion_utils import gemini_to_json_schema
|
||||
from .mcp_session_manager import SseConnectionParams
|
||||
from .mcp_session_manager import StdioConnectionParams
|
||||
from .mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from .mcp_tool import MCPTool
|
||||
from .mcp_toolset import MCPToolset
|
||||
|
||||
@@ -25,6 +28,9 @@ try:
|
||||
'gemini_to_json_schema',
|
||||
'MCPTool',
|
||||
'MCPToolset',
|
||||
'StdioConnectionParams',
|
||||
'SseConnectionParams',
|
||||
'StreamableHTTPConnectionParams',
|
||||
])
|
||||
|
||||
except ImportError as e:
|
||||
|
||||
@@ -47,30 +47,61 @@ except ImportError as e:
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
class StdioConnectionParams(BaseModel):
|
||||
"""Parameters for the MCP Stdio connection.
|
||||
|
||||
Attributes:
|
||||
server_params: Parameters for the MCP Stdio server.
|
||||
timeout: Timeout in seconds for establishing the connection to the MCP
|
||||
stdio server.
|
||||
"""
|
||||
|
||||
server_params: StdioServerParameters
|
||||
timeout: float = 5.0
|
||||
|
||||
|
||||
class SseConnectionParams(BaseModel):
|
||||
"""Parameters for the MCP SSE connection.
|
||||
|
||||
See MCP SSE Client documentation for more details.
|
||||
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
||||
|
||||
Attributes:
|
||||
url: URL for the MCP SSE server.
|
||||
headers: Headers for the MCP SSE connection.
|
||||
timeout: Timeout in seconds for establishing the connection to the MCP SSE
|
||||
server.
|
||||
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
|
||||
server.
|
||||
"""
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
timeout: float = 5.0
|
||||
sse_read_timeout: float = 60 * 5.0
|
||||
|
||||
|
||||
class StreamableHTTPServerParams(BaseModel):
|
||||
class StreamableHTTPConnectionParams(BaseModel):
|
||||
"""Parameters for the MCP SSE connection.
|
||||
|
||||
See MCP SSE Client documentation for more details.
|
||||
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py
|
||||
|
||||
Attributes:
|
||||
url: URL for the MCP Streamable HTTP server.
|
||||
headers: Headers for the MCP Streamable HTTP connection.
|
||||
timeout: Timeout in seconds for establishing the connection to the MCP
|
||||
Streamable HTTP server.
|
||||
sse_read_timeout: Timeout in seconds for reading data from the MCP
|
||||
Streamable HTTP server.
|
||||
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
|
||||
when the connection is closed.
|
||||
"""
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
timeout: float = 5.0
|
||||
sse_read_timeout: float = 60 * 5.0
|
||||
terminate_on_close: bool = True
|
||||
|
||||
|
||||
@@ -142,7 +173,10 @@ class MCPSessionManager:
|
||||
def __init__(
|
||||
self,
|
||||
connection_params: Union[
|
||||
StdioServerParameters, SseServerParams, StreamableHTTPServerParams
|
||||
StdioServerParameters,
|
||||
StdioConnectionParams,
|
||||
SseConnectionParams,
|
||||
StreamableHTTPConnectionParams,
|
||||
],
|
||||
errlog: TextIO = sys.stderr,
|
||||
):
|
||||
@@ -155,7 +189,20 @@ class MCPSessionManager:
|
||||
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||
initializing a local stdio MCP session.
|
||||
"""
|
||||
self._connection_params = connection_params
|
||||
if isinstance(connection_params, StdioServerParameters):
|
||||
# So far timeout is not configurable. Given MCP is still evolving, we
|
||||
# would expect stdio_client to evolve to accept timeout parameter like
|
||||
# other client.
|
||||
logger.warning(
|
||||
'StdioServerParameters is not recommended. Please use'
|
||||
' StdioConnectionParams.'
|
||||
)
|
||||
self._connection_params = StdioConnectionParams(
|
||||
server_params=connection_params,
|
||||
timeout=5,
|
||||
)
|
||||
else:
|
||||
self._connection_params = connection_params
|
||||
self._errlog = errlog
|
||||
# Each session manager maintains its own exit stack for proper cleanup
|
||||
self._exit_stack: Optional[AsyncExitStack] = None
|
||||
@@ -174,21 +221,19 @@ class MCPSessionManager:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
try:
|
||||
if isinstance(self._connection_params, StdioServerParameters):
|
||||
# So far timeout is not configurable. Given MCP is still evolving, we
|
||||
# would expect stdio_client to evolve to accept timeout parameter like
|
||||
# other client.
|
||||
if isinstance(self._connection_params, StdioConnectionParams):
|
||||
client = stdio_client(
|
||||
server=self._connection_params, errlog=self._errlog
|
||||
server=self._connection_params.server_params,
|
||||
errlog=self._errlog,
|
||||
)
|
||||
elif isinstance(self._connection_params, SseServerParams):
|
||||
elif isinstance(self._connection_params, SseConnectionParams):
|
||||
client = sse_client(
|
||||
url=self._connection_params.url,
|
||||
headers=self._connection_params.headers,
|
||||
timeout=self._connection_params.timeout,
|
||||
sse_read_timeout=self._connection_params.sse_read_timeout,
|
||||
)
|
||||
elif isinstance(self._connection_params, StreamableHTTPServerParams):
|
||||
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
|
||||
client = streamablehttp_client(
|
||||
url=self._connection_params.url,
|
||||
headers=self._connection_params.headers,
|
||||
@@ -208,24 +253,13 @@ class MCPSessionManager:
|
||||
transports = await self._exit_stack.enter_async_context(client)
|
||||
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
|
||||
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
|
||||
# The StdioServerParameters does not provide a timeout parameter for the
|
||||
# session, so we need to set a default timeout for it. Other clients
|
||||
# (SseServerParams and StreamableHTTPServerParams) already provide a
|
||||
# timeout parameter in their configuration.
|
||||
if isinstance(self._connection_params, StdioServerParameters):
|
||||
# Default timeout for MCP session is 5 seconds, same as SseServerParams
|
||||
# and StreamableHTTPServerParams.
|
||||
# TODO :
|
||||
# 1. make timeout configurable
|
||||
# 2. Add StdioConnectionParams to include StdioServerParameters as a
|
||||
# field and rename other two params to XXXXConnetionParams. Ohter
|
||||
# two params are actually connection params, while stdio is
|
||||
# special, stdio_client takes the resposibility of starting the
|
||||
# server and working as a client.
|
||||
if isinstance(self._connection_params, StdioConnectionParams):
|
||||
session = await self._exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
*transports[:2],
|
||||
read_timeout_seconds=timedelta(seconds=5),
|
||||
read_timeout_seconds=timedelta(
|
||||
seconds=self._connection_params.timeout
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -257,3 +291,8 @@ class MCPSessionManager:
|
||||
finally:
|
||||
self._exit_stack = None
|
||||
self._session = None
|
||||
|
||||
|
||||
SseServerParams = SseConnectionParams
|
||||
|
||||
StreamableHTTPServerParams = StreamableHTTPConnectionParams
|
||||
|
||||
@@ -27,7 +27,10 @@ from ..base_toolset import BaseToolset
|
||||
from ..base_toolset import ToolPredicate
|
||||
from .mcp_session_manager import MCPSessionManager
|
||||
from .mcp_session_manager import retry_on_closed_resource
|
||||
from .mcp_session_manager import SseConnectionParams
|
||||
from .mcp_session_manager import SseServerParams
|
||||
from .mcp_session_manager import StdioConnectionParams
|
||||
from .mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from .mcp_session_manager import StreamableHTTPServerParams
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
@@ -85,9 +88,12 @@ class MCPToolset(BaseToolset):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
connection_params: (
|
||||
StdioServerParameters | SseServerParams | StreamableHTTPServerParams
|
||||
),
|
||||
connection_params: Union[
|
||||
StdioServerParameters,
|
||||
StdioConnectionParams,
|
||||
SseConnectionParams,
|
||||
StreamableHTTPConnectionParams,
|
||||
],
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
):
|
||||
@@ -95,12 +101,16 @@ class MCPToolset(BaseToolset):
|
||||
|
||||
Args:
|
||||
connection_params: The connection parameters to the MCP server. Can be:
|
||||
`StdioServerParameters` for using local mcp server (e.g. using `npx` or
|
||||
`python3`); or `SseServerParams` for a local/remote SSE server; or
|
||||
`StreamableHTTPServerParams` for local/remote Streamable http server.
|
||||
tool_filter: Optional filter to select specific tools. Can be either:
|
||||
- A list of tool names to include
|
||||
- A ToolPredicate function for custom filtering logic
|
||||
`StdioConnectionParams` for using local mcp server (e.g. using `npx` or
|
||||
`python3`); or `SseConnectionParams` for a local/remote SSE server; or
|
||||
`StreamableHTTPConnectionParams` for local/remote Streamable http
|
||||
server. Note, `StdioServerParameters` is also supported for using local
|
||||
mcp server (e.g. using `npx` or `python3` ), but it does not support
|
||||
timeout, and we recommend to use `StdioConnectionParams` instead when
|
||||
timeout is needed.
|
||||
tool_filter: Optional filter to select specific tools. Can be either: - A
|
||||
list of tool names to include - A ToolPredicate function for custom
|
||||
filtering logic
|
||||
errlog: TextIO stream for error logging.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
|
||||
Reference in New Issue
Block a user