diff --git a/contributing/samples/mcp_stdio_server_agent/agent.py b/contributing/samples/mcp_stdio_server_agent/agent.py index a14ab439..fe8b75c2 100755 --- a/contributing/samples/mcp_stdio_server_agent/agent.py +++ b/contributing/samples/mcp_stdio_server_agent/agent.py @@ -16,8 +16,9 @@ import os from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool import StdioConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import StdioServerParameters +from mcp import StdioServerParameters _allowed_path = os.path.dirname(os.path.abspath(__file__)) @@ -31,13 +32,16 @@ Allowed directory: {_allowed_path} """, tools=[ MCPToolset( - connection_params=StdioServerParameters( - command='npx', - args=[ - '-y', # Arguments for the command - '@modelcontextprotocol/server-filesystem', - _allowed_path, - ], + connection_params=StdioConnectionParams( + server_params=StdioServerParameters( + command='npx', + args=[ + '-y', # Arguments for the command + '@modelcontextprotocol/server-filesystem', + _allowed_path, + ], + ), + timeout=5, ), # don't want agent to do write operation # you can also do below diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py index b849b1f7..bd28c4f4 100644 --- a/src/google/adk/tools/mcp_tool/__init__.py +++ b/src/google/adk/tools/mcp_tool/__init__.py @@ -17,6 +17,9 @@ __all__ = [] try: from .conversion_utils import adk_to_mcp_tool_type from .conversion_utils import gemini_to_json_schema + from .mcp_session_manager import SseConnectionParams + from .mcp_session_manager import StdioConnectionParams + from .mcp_session_manager import StreamableHTTPConnectionParams from .mcp_tool import MCPTool from .mcp_toolset import MCPToolset @@ -25,6 +28,9 @@ try: 'gemini_to_json_schema', 'MCPTool', 'MCPToolset', + 'StdioConnectionParams', + 'SseConnectionParams', + 'StreamableHTTPConnectionParams', ]) except ImportError as e: diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 13a9b612..3a07a6fe 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -47,30 +47,61 @@ except ImportError as e: logger = logging.getLogger('google_adk.' + __name__) -class SseServerParams(BaseModel): +class StdioConnectionParams(BaseModel): + """Parameters for the MCP Stdio connection. + + Attributes: + server_params: Parameters for the MCP Stdio server. + timeout: Timeout in seconds for establishing the connection to the MCP + stdio server. + """ + + server_params: StdioServerParameters + timeout: float = 5.0 + + +class SseConnectionParams(BaseModel): """Parameters for the MCP SSE connection. See MCP SSE Client documentation for more details. https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py + + Attributes: + url: URL for the MCP SSE server. + headers: Headers for the MCP SSE connection. + timeout: Timeout in seconds for establishing the connection to the MCP SSE + server. + sse_read_timeout: Timeout in seconds for reading data from the MCP SSE + server. """ url: str headers: dict[str, Any] | None = None - timeout: float = 5 - sse_read_timeout: float = 60 * 5 + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 -class StreamableHTTPServerParams(BaseModel): +class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP SSE connection. See MCP SSE Client documentation for more details. https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py + + Attributes: + url: URL for the MCP Streamable HTTP server. + headers: Headers for the MCP Streamable HTTP connection. + timeout: Timeout in seconds for establishing the connection to the MCP + Streamable HTTP server. + sse_read_timeout: Timeout in seconds for reading data from the MCP + Streamable HTTP server. + terminate_on_close: Whether to terminate the MCP Streamable HTTP server + when the connection is closed. """ url: str headers: dict[str, Any] | None = None - timeout: float = 5 - sse_read_timeout: float = 60 * 5 + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 terminate_on_close: bool = True @@ -142,7 +173,10 @@ class MCPSessionManager: def __init__( self, connection_params: Union[ - StdioServerParameters, SseServerParams, StreamableHTTPServerParams + StdioServerParameters, + StdioConnectionParams, + SseConnectionParams, + StreamableHTTPConnectionParams, ], errlog: TextIO = sys.stderr, ): @@ -155,7 +189,20 @@ class MCPSessionManager: errlog: (Optional) TextIO stream for error logging. Use only for initializing a local stdio MCP session. """ - self._connection_params = connection_params + if isinstance(connection_params, StdioServerParameters): + # So far timeout is not configurable. Given MCP is still evolving, we + # would expect stdio_client to evolve to accept timeout parameter like + # other client. + logger.warning( + 'StdioServerParameters is not recommended. Please use' + ' StdioConnectionParams.' + ) + self._connection_params = StdioConnectionParams( + server_params=connection_params, + timeout=5, + ) + else: + self._connection_params = connection_params self._errlog = errlog # Each session manager maintains its own exit stack for proper cleanup self._exit_stack: Optional[AsyncExitStack] = None @@ -174,21 +221,19 @@ class MCPSessionManager: self._exit_stack = AsyncExitStack() try: - if isinstance(self._connection_params, StdioServerParameters): - # So far timeout is not configurable. Given MCP is still evolving, we - # would expect stdio_client to evolve to accept timeout parameter like - # other client. + if isinstance(self._connection_params, StdioConnectionParams): client = stdio_client( - server=self._connection_params, errlog=self._errlog + server=self._connection_params.server_params, + errlog=self._errlog, ) - elif isinstance(self._connection_params, SseServerParams): + elif isinstance(self._connection_params, SseConnectionParams): client = sse_client( url=self._connection_params.url, headers=self._connection_params.headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, ) - elif isinstance(self._connection_params, StreamableHTTPServerParams): + elif isinstance(self._connection_params, StreamableHTTPConnectionParams): client = streamablehttp_client( url=self._connection_params.url, headers=self._connection_params.headers, @@ -208,24 +253,13 @@ class MCPSessionManager: transports = await self._exit_stack.enter_async_context(client) # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. - # The StdioServerParameters does not provide a timeout parameter for the - # session, so we need to set a default timeout for it. Other clients - # (SseServerParams and StreamableHTTPServerParams) already provide a - # timeout parameter in their configuration. - if isinstance(self._connection_params, StdioServerParameters): - # Default timeout for MCP session is 5 seconds, same as SseServerParams - # and StreamableHTTPServerParams. - # TODO : - # 1. make timeout configurable - # 2. Add StdioConnectionParams to include StdioServerParameters as a - # field and rename other two params to XXXXConnetionParams. Ohter - # two params are actually connection params, while stdio is - # special, stdio_client takes the resposibility of starting the - # server and working as a client. + if isinstance(self._connection_params, StdioConnectionParams): session = await self._exit_stack.enter_async_context( ClientSession( *transports[:2], - read_timeout_seconds=timedelta(seconds=5), + read_timeout_seconds=timedelta( + seconds=self._connection_params.timeout + ), ) ) else: @@ -257,3 +291,8 @@ class MCPSessionManager: finally: self._exit_stack = None self._session = None + + +SseServerParams = SseConnectionParams + +StreamableHTTPServerParams = StreamableHTTPConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 56c05ba8..8076752b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -27,7 +27,10 @@ from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_closed_resource +from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import SseServerParams +from .mcp_session_manager import StdioConnectionParams +from .mcp_session_manager import StreamableHTTPConnectionParams from .mcp_session_manager import StreamableHTTPServerParams # Attempt to import MCP Tool from the MCP library, and hints user to upgrade @@ -85,9 +88,12 @@ class MCPToolset(BaseToolset): def __init__( self, *, - connection_params: ( - StdioServerParameters | SseServerParams | StreamableHTTPServerParams - ), + connection_params: Union[ + StdioServerParameters, + StdioConnectionParams, + SseConnectionParams, + StreamableHTTPConnectionParams, + ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, ): @@ -95,12 +101,16 @@ class MCPToolset(BaseToolset): Args: connection_params: The connection parameters to the MCP server. Can be: - `StdioServerParameters` for using local mcp server (e.g. using `npx` or - `python3`); or `SseServerParams` for a local/remote SSE server; or - `StreamableHTTPServerParams` for local/remote Streamable http server. - tool_filter: Optional filter to select specific tools. Can be either: - - A list of tool names to include - - A ToolPredicate function for custom filtering logic + `StdioConnectionParams` for using local mcp server (e.g. using `npx` or + `python3`); or `SseConnectionParams` for a local/remote SSE server; or + `StreamableHTTPConnectionParams` for local/remote Streamable http + server. Note, `StdioServerParameters` is also supported for using local + mcp server (e.g. using `npx` or `python3` ), but it does not support + timeout, and we recommend to use `StdioConnectionParams` instead when + timeout is needed. + tool_filter: Optional filter to select specific tools. Can be either: - A + list of tool names to include - A ToolPredicate function for custom + filtering logic errlog: TextIO stream for error logging. """ super().__init__(tool_filter=tool_filter)