From adbc37fea1ae783416c62f74f954dd4b7e249fa1 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Thu, 5 Feb 2026 11:04:01 -0800 Subject: [PATCH] feat: Add progress_callback support to MCPTool and MCPToolset Fixes: https://github.com/google/adk-python/issues/3811 Co-authored-by: Xuan Yang PiperOrigin-RevId: 866025995 --- .../mcp_progress_callback_agent/__init__.py | 15 ++ .../mcp_progress_callback_agent/agent.py | 166 ++++++++++++++++++ .../mock_progress_server.py | 161 +++++++++++++++++ src/google/adk/tools/mcp_tool/mcp_tool.py | 132 +++++++++++++- src/google/adk/tools/mcp_tool/mcp_toolset.py | 45 +++-- .../unittests/tools/mcp_tool/test_mcp_tool.py | 100 ++++++++++- .../tools/mcp_tool/test_mcp_toolset.py | 95 ++++++++++ 7 files changed, 695 insertions(+), 19 deletions(-) create mode 100644 contributing/samples/mcp_progress_callback_agent/__init__.py create mode 100644 contributing/samples/mcp_progress_callback_agent/agent.py create mode 100644 contributing/samples/mcp_progress_callback_agent/mock_progress_server.py diff --git a/contributing/samples/mcp_progress_callback_agent/__init__.py b/contributing/samples/mcp_progress_callback_agent/__init__.py new file mode 100644 index 00000000..4015e47d --- /dev/null +++ b/contributing/samples/mcp_progress_callback_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/mcp_progress_callback_agent/agent.py b/contributing/samples/mcp_progress_callback_agent/agent.py new file mode 100644 index 00000000..756d646f --- /dev/null +++ b/contributing/samples/mcp_progress_callback_agent/agent.py @@ -0,0 +1,166 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample agent demonstrating MCP progress callback feature. + +This sample shows how to use the progress_callback parameter in McpToolset +to receive progress notifications from MCP servers during long-running tool +executions. + +There are two ways to use progress callbacks: + +1. Simple callback (shared by all tools): + Pass a ProgressFnT callback that receives (progress, total, message). + +2. Factory function (per-tool callbacks with runtime context): + Pass a ProgressCallbackFactory that takes (tool_name, callback_context, **kwargs) + and returns a ProgressFnT or None. This allows different tools to have different + progress handling logic, and the factory can access and modify session state + via the CallbackContext. The **kwargs ensures forward compatibility for future + parameters. + +IMPORTANT: Progress callbacks only work when the MCP server actually sends +progress notifications. Most simple MCP servers (like the filesystem server) +do not send progress updates. This sample uses a mock server that demonstrates +progress reporting. + +Usage: + adk run contributing/samples/mcp_progress_callback_agent + +Then try: + "Run the long running task with 5 steps" + "Process these items: apple, banana, cherry" +""" + +import os +import sys +from typing import Any + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool import McpToolset +from google.adk.tools.mcp_tool import StdioConnectionParams +from mcp import StdioServerParameters +from mcp.shared.session import ProgressFnT + +_current_dir = os.path.dirname(os.path.abspath(__file__)) +_mock_server_path = os.path.join(_current_dir, "mock_progress_server.py") + + +# Option 1: Simple shared callback +async def simple_progress_callback( + progress: float, + total: float | None, + message: str | None, +) -> None: + """Handle progress notifications from MCP server. + + This callback is shared by all tools in the toolset. + """ + if total is not None: + percentage = (progress / total) * 100 + bar_length = 20 + filled = int(bar_length * progress / total) + bar = "=" * filled + "-" * (bar_length - filled) + print(f"[{bar}] {percentage:.0f}% ({progress}/{total}) {message or ''}") + else: + print(f"Progress: {progress} {f'- {message}' if message else ''}") + + +# Option 2: Factory function for per-tool callbacks with runtime context +def progress_callback_factory( + tool_name: str, + *, + callback_context: CallbackContext | None = None, + **kwargs: Any, +) -> ProgressFnT | None: + """Create a progress callback for a specific tool. + + This factory allows different tools to have different progress handling. + It receives a CallbackContext for accessing and modifying runtime information + like session state. The **kwargs parameter ensures forward compatibility. + + Args: + tool_name: The name of the MCP tool. + callback_context: The callback context providing access to session, + state, artifacts, and other runtime information. Allows modifying + state via ctx.state['key'] = value. May be None if not available. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A progress callback function, or None if no callback is needed. + """ + # Example: Access session info from context (if available) + session_id = "unknown" + if callback_context and callback_context.session: + session_id = callback_context.session.id + + async def callback( + progress: float, + total: float | None, + message: str | None, + ) -> None: + # Include tool name and session info in the progress output + prefix = f"[{tool_name}][session:{session_id}]" + if total is not None: + percentage = (progress / total) * 100 + bar_length = 20 + filled = int(bar_length * progress / total) + bar = "=" * filled + "-" * (bar_length - filled) + print(f"{prefix} [{bar}] {percentage:.0f}% {message or ''}") + # Example: Store progress in state (callback_context allows modification) + if callback_context: + callback_context.state["last_progress"] = progress + callback_context.state["last_total"] = total + else: + print( + f"{prefix} Progress: {progress} {f'- {message}' if message else ''}" + ) + + return callback + + +root_agent = LlmAgent( + model="gemini-2.5-flash", + name="progress_demo_agent", + instruction="""\ +You are a helpful assistant that can run long-running tasks. + +Available tools: +- long_running_task: Simulates a task with multiple steps. You can specify + the number of steps and delay between them. +- process_items: Processes a list of items one by one with progress updates. + +When the user asks you to run a task, use these tools and the progress +will be logged automatically. + +Example requests: +- "Run a long task with 5 steps" +- "Process these items: apple, banana, cherry, date" + """, + tools=[ + McpToolset( + connection_params=StdioConnectionParams( + server_params=StdioServerParameters( + command=sys.executable, # Use current Python interpreter + args=[_mock_server_path], + ), + timeout=60, + ), + # Use factory function for per-tool callbacks (Option 2) + # Or use simple_progress_callback for shared callback (Option 1) + progress_callback=progress_callback_factory, + ) + ], +) diff --git a/contributing/samples/mcp_progress_callback_agent/mock_progress_server.py b/contributing/samples/mcp_progress_callback_agent/mock_progress_server.py new file mode 100644 index 00000000..948522d6 --- /dev/null +++ b/contributing/samples/mcp_progress_callback_agent/mock_progress_server.py @@ -0,0 +1,161 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mock MCP server that sends progress notifications. + +This server demonstrates how MCP servers can send progress updates +during long-running tool execution. + +Run this server directly: + python mock_progress_server.py + +Or use it with the sample agent: + See agent_with_mock_server.py +""" + +import asyncio + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.types import TextContent +from mcp.types import Tool + +server = Server("mock-progress-server") + + +@server.list_tools() +async def list_tools() -> list[Tool]: + """List available tools.""" + return [ + Tool( + name="long_running_task", + description=( + "A simulated long-running task that reports progress. " + "Use this to test progress callback functionality." + ), + inputSchema={ + "type": "object", + "properties": { + "steps": { + "type": "integer", + "description": "Number of steps to simulate (default: 5)", + "default": 5, + }, + "delay": { + "type": "number", + "description": ( + "Delay in seconds between steps (default: 0.5)" + ), + "default": 0.5, + }, + }, + }, + ), + Tool( + name="process_items", + description="Process a list of items with progress reporting.", + inputSchema={ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"}, + "description": "List of items to process", + }, + }, + "required": ["items"], + }, + ), + ] + + +@server.call_tool() +async def call_tool(name: str, arguments: dict) -> list[TextContent]: + """Handle tool calls with progress reporting.""" + ctx = server.request_context + + if name == "long_running_task": + steps = arguments.get("steps", 5) + delay = arguments.get("delay", 0.5) + + # Get progress token from request metadata + progress_token = None + if ctx.meta and hasattr(ctx.meta, "progressToken"): + progress_token = ctx.meta.progressToken + + for i in range(steps): + # Simulate work + await asyncio.sleep(delay) + + # Send progress notification if client supports it + if progress_token is not None: + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=i + 1, + total=steps, + message=f"Completed step {i + 1} of {steps}", + ) + + return [ + TextContent( + type="text", + text=f"Successfully completed {steps} steps!", + ) + ] + + elif name == "process_items": + items = arguments.get("items", []) + total = len(items) + + progress_token = None + if ctx.meta and hasattr(ctx.meta, "progressToken"): + progress_token = ctx.meta.progressToken + + results = [] + for i, item in enumerate(items): + # Simulate processing + await asyncio.sleep(0.3) + results.append(f"Processed: {item}") + + # Send progress + if progress_token is not None: + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=i + 1, + total=total, + message=f"Processing item: {item}", + ) + + return [ + TextContent( + type="text", + text="\n".join(results), + ) + ] + + return [TextContent(type="text", text=f"Unknown tool: {name}")] + + +async def main(): + """Run the MCP server.""" + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 719ed662..8c36a9c8 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import base64 import inspect import logging @@ -21,14 +22,18 @@ from typing import Any from typing import Callable from typing import Dict from typing import Optional +from typing import Protocol +from typing import runtime_checkable from typing import Union import warnings from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration +from mcp.shared.session import ProgressFnT from mcp.types import Tool as McpBaseTool from typing_extensions import override +from ...agents.callback_context import CallbackContext from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme @@ -37,7 +42,6 @@ from ...features import FeatureName from ...features import is_feature_enabled from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool -# import from ..tool_context import ToolContext from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors @@ -45,6 +49,68 @@ from .mcp_session_manager import retry_on_errors logger = logging.getLogger("google_adk." + __name__) +@runtime_checkable +class ProgressCallbackFactory(Protocol): + """Factory protocol for creating per-tool progress callbacks. + + This protocol allows users to create different progress callbacks for + different tools based on tool name and runtime context. The factory receives + the tool name, a CallbackContext for accessing and modifying session state, + and additional keyword arguments for forward compatibility. + + Example usage:: + + def my_callback_factory( + tool_name: str, + *, + callback_context: CallbackContext | None = None, + **kwargs + ) -> ProgressFnT | None: + session_id = callback_context.session.id if callback_context else "N/A" + + async def callback(progress, total, message): + print(f"[{tool_name}] Session {session_id}: {progress}/{total}") + # Can modify state in the callback + if callback_context: + callback_context.state['last_progress'] = progress + + return callback + + toolset = McpToolset( + connection_params=..., + progress_callback=my_callback_factory, + ) + + Note: + The **kwargs parameter is required for forward compatibility. Future + versions may pass additional parameters. Implementations should accept + **kwargs even if they don't use them. + """ + + def __call__( + self, + tool_name: str, + *, + callback_context: Optional[CallbackContext] = None, + **kwargs: Any, + ) -> Optional[ProgressFnT]: + """Create a progress callback for a specific tool. + + Args: + tool_name: The name of the MCP tool. + callback_context: The callback context providing access to session, + state, artifacts, and other runtime information. Allows modifying + state via ctx.state['key'] = value. May be None if not available. + **kwargs: Additional keyword arguments for future extensibility. + Implementations should accept **kwargs for forward compatibility. + + Returns: + A progress callback function, or None if no callback is needed + for this tool. + """ + ... + + class McpTool(BaseAuthenticatedTool): """Turns an MCP Tool into an ADK Tool. @@ -66,6 +132,9 @@ class McpTool(BaseAuthenticatedTool): header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + progress_callback: Optional[ + Union[ProgressFnT, ProgressCallbackFactory] + ] = None, ): """Initializes an McpTool. @@ -81,6 +150,17 @@ class McpTool(BaseAuthenticatedTool): or a callable that takes the function's arguments and returns a boolean. If the callable returns True, the tool will require confirmation from the user. + header_provider: Optional function to provide dynamic headers. + progress_callback: Optional callback to receive progress notifications + from MCP server during long-running tool execution. Can be either: + + - A ``ProgressFnT`` callback that receives (progress, total, message). + This callback will be used for all invocations. + + - A ``ProgressCallbackFactory`` that creates per-invocation callbacks. + The factory receives (tool_name, callback_context, **kwargs) and + returns a ProgressFnT or None. This allows callbacks to access + and modify runtime context like session state. Raises: ValueError: If mcp_tool or mcp_session_manager is None. @@ -98,6 +178,7 @@ class McpTool(BaseAuthenticatedTool): self._mcp_session_manager = mcp_session_manager self._require_confirmation = require_confirmation self._header_provider = header_provider + self._progress_callback = progress_callback @override def _get_declaration(self) -> FunctionDeclaration: @@ -237,9 +318,50 @@ class McpTool(BaseAuthenticatedTool): headers=final_headers ) - response = await session.call_tool(self._mcp_tool.name, arguments=args) + # Resolve progress callback (may be a factory that needs runtime context) + resolved_callback = self._resolve_progress_callback(tool_context) + + response = await session.call_tool( + self._mcp_tool.name, + arguments=args, + progress_callback=resolved_callback, + ) return response.model_dump(exclude_none=True, mode="json") + def _resolve_progress_callback( + self, tool_context: ToolContext + ) -> Optional[ProgressFnT]: + """Resolve the progress callback for the current invocation. + + If progress_callback is a ProgressCallbackFactory, call it to create + a callback with runtime context. Otherwise, return the callback directly. + + Args: + tool_context: The tool context for the current invocation. + + Returns: + The resolved progress callback, or None if not configured. + """ + if ( + not hasattr(self, "_progress_callback") + or self._progress_callback is None + ): + return None + + # Determine if callback is a factory by checking if it's a coroutine + # function. ProgressFnT is an async function, while ProgressCallbackFactory + # is a sync function that returns an async function. + if asyncio.iscoroutinefunction(self._progress_callback): + return self._progress_callback + + # If it's a regular callable (not async), treat it as a factory + if callable(self._progress_callback) and not inspect.iscoroutinefunction( + self._progress_callback + ): + return self._progress_callback(self.name, callback_context=tool_context) + + return self._progress_callback + async def _get_headers( self, tool_context: ToolContext, credential: AuthCredential ) -> Optional[dict[str, str]]: @@ -253,7 +375,8 @@ class McpTool(BaseAuthenticatedTool): Dictionary of headers to add to the request, or None if no auth. Raises: - ValueError: If API key authentication is configured for non-header location. + ValueError: If API key authentication is configured for non-header + location. """ headers: Optional[dict[str, str]] = None if credential: @@ -284,7 +407,8 @@ class McpTool(BaseAuthenticatedTool): # Handle other HTTP schemes with token headers = { "Authorization": ( - f"{credential.http.scheme} {credential.http.credentials.token}" + f"{credential.http.scheme}" + f" {credential.http.credentials.token}" ) } elif credential.api_key: diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 598452c7..43641137 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -31,6 +31,7 @@ from typing import Union import warnings from mcp import StdioServerParameters +from mcp.shared.session import ProgressFnT from mcp.types import ListResourcesResult from mcp.types import ListToolsResult from pydantic import model_validator @@ -51,6 +52,7 @@ from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams from .mcp_tool import MCPTool +from .mcp_tool import ProgressCallbackFactory logger = logging.getLogger("google_adk." + __name__) @@ -72,7 +74,8 @@ class McpToolset(BaseToolset): command='npx', args=["-y", "@modelcontextprotocol/server-filesystem"], ), - tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools + tool_filter=['read_file', 'list_directory'] # Optional: filter specific + tools ) # Use in an agent @@ -106,18 +109,21 @@ class McpToolset(BaseToolset): header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + progress_callback: Optional[ + Union[ProgressFnT, ProgressCallbackFactory] + ] = None, ): """Initializes the McpToolset. Args: connection_params: The connection parameters to the MCP server. Can be: - ``StdioConnectionParams`` for using local mcp server (e.g. using ``npx`` or - ``python3``); or ``SseConnectionParams`` for a local/remote SSE server; or - ``StreamableHTTPConnectionParams`` for local/remote Streamable http - server. Note, ``StdioServerParameters`` is also supported for using local - mcp server (e.g. using ``npx`` or ``python3`` ), but it does not support - timeout, and we recommend to use ``StdioConnectionParams`` instead when - timeout is needed. + ``StdioConnectionParams`` for using local mcp server (e.g. using ``npx`` + or ``python3``); or ``SseConnectionParams`` for a local/remote SSE + server; or ``StreamableHTTPConnectionParams`` for local/remote + Streamable http server. Note, ``StdioServerParameters`` is also + supported for using local mcp server (e.g. using ``npx`` or ``python3`` + ), but it does not support timeout, and we recommend to use + ``StdioConnectionParams`` instead when timeout is needed. tool_filter: Optional filter to select specific tools. Can be either: - A list of tool names to include - A ToolPredicate function for custom filtering logic @@ -126,11 +132,22 @@ class McpToolset(BaseToolset): errlog: TextIO stream for error logging. auth_scheme: The auth scheme of the tool for tool calling auth_credential: The auth credential of the tool for tool calling - require_confirmation: Whether tools in this toolset require - confirmation. Can be a single boolean or a callable to apply to all - tools. + require_confirmation: Whether tools in this toolset require confirmation. + Can be a single boolean or a callable to apply to all tools. header_provider: A callable that takes a ReadonlyContext and returns a dictionary of headers to be used for the MCP session. + progress_callback: Optional callback to receive progress notifications + from MCP server during long-running tool execution. Can be either: + + - A ``ProgressFnT`` callback that receives (progress, total, message). + This callback will be shared by all tools in the toolset. + + - A ``ProgressCallbackFactory`` that creates per-tool callbacks. The + factory receives (tool_name, callback_context, **kwargs) and returns + a ProgressFnT or None. This allows different tools to have different + progress handling logic and access/modify session state via the + CallbackContext. The **kwargs parameter allows for future + extensibility. """ super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) @@ -140,6 +157,7 @@ class McpToolset(BaseToolset): self._connection_params = connection_params self._errlog = errlog self._header_provider = header_provider + self._progress_callback = progress_callback # Create the session manager that will handle the MCP connection self._mcp_session_manager = MCPSessionManager( @@ -270,7 +288,7 @@ class McpToolset(BaseToolset): Args: readonly_context: Context used to filter tools available to the agent. - If None, all tools in the toolset are returned. + If None, all tools in the toolset are returned. Returns: List[BaseTool]: A list of tools available under the specified context. @@ -292,6 +310,9 @@ class McpToolset(BaseToolset): auth_credential=self._auth_credential, require_confirmation=self._require_confirmation, header_provider=self._header_provider, + progress_callback=self._progress_callback + if hasattr(self, "_progress_callback") + else None, ) if self._is_tool_selected(mcp_tool, readonly_context): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 09e3529a..64dfac48 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -225,7 +225,7 @@ class TestMCPTool: ) # Fix: call_tool uses 'arguments' parameter, not positional args self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args + "test_tool", arguments=args, progress_callback=None ) @pytest.mark.asyncio @@ -778,7 +778,7 @@ class TestMCPTool: headers=expected_headers ) self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args + "test_tool", arguments=args, progress_callback=None ) @pytest.mark.asyncio @@ -821,5 +821,99 @@ class TestMCPTool: "X-Tenant-ID": "test-tenant", } self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args + "test_tool", arguments=args, progress_callback=None ) + + def test_init_with_progress_callback(self): + """Test initialization with progress_callback.""" + + async def my_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + pass + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + progress_callback=my_progress_callback, + ) + + assert tool._progress_callback == my_progress_callback + + @pytest.mark.asyncio + async def test_run_async_impl_with_progress_callback(self): + """Test running tool with progress_callback.""" + progress_updates = [] + + async def my_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + progress_updates.append((progress, total, message)) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + progress_callback=my_progress_callback, + ) + + # Mock the session response + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + assert result == mcp_response.model_dump(exclude_none=True, mode="json") + self.mock_session_manager.create_session.assert_called_once_with( + headers=None + ) + # Verify progress_callback was passed to call_tool + self.mock_session.call_tool.assert_called_once_with( + "test_tool", arguments=args, progress_callback=my_progress_callback + ) + + @pytest.mark.asyncio + async def test_run_async_impl_with_progress_callback_factory(self): + """Test running tool with progress_callback factory that receives context.""" + factory_calls = [] + + def my_callback_factory(tool_name: str, *, callback_context=None, **kwargs): + factory_calls.append((tool_name, callback_context)) + + async def callback( + progress: float, total: float | None, message: str | None + ) -> None: + pass + + return callback + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + progress_callback=my_callback_factory, + ) + + # Mock the session response + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + # Verify factory was called with tool name and tool_context as callback_context + assert len(factory_calls) == 1 + assert factory_calls[0][0] == "test_tool" + # callback_context is the tool_context itself (ToolContext extends CallbackContext) + assert factory_calls[0][1] is tool_context diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 2c925609..b57ec476 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -360,6 +360,101 @@ class TestMcpToolset: assert tools[0].name == "tool1" assert tools[1].name == "tool2" + def test_init_with_progress_callback(self): + """Test initialization with progress_callback.""" + + async def my_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + pass + + toolset = McpToolset( + connection_params=self.mock_stdio_params, + progress_callback=my_progress_callback, + ) + + assert toolset._progress_callback == my_progress_callback + + @pytest.mark.asyncio + async def test_get_tools_passes_progress_callback_to_mcp_tools(self): + """Test that get_tools passes progress_callback to created MCPTool instances.""" + progress_updates = [] + + async def my_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + progress_updates.append((progress, total, message)) + + mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = McpToolset( + connection_params=self.mock_stdio_params, + progress_callback=my_progress_callback, + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + # Verify each tool has the progress_callback set + for tool in tools: + assert tool._progress_callback == my_progress_callback + + def test_init_with_progress_callback_factory(self): + """Test initialization with a ProgressCallbackFactory.""" + + def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs): + async def callback( + progress: float, total: float | None, message: str | None + ) -> None: + pass + + return callback + + toolset = McpToolset( + connection_params=self.mock_stdio_params, + progress_callback=my_callback_factory, + ) + + assert toolset._progress_callback == my_callback_factory + + @pytest.mark.asyncio + async def test_get_tools_passes_factory_to_mcp_tools(self): + """Test that get_tools passes factory directly to MCPTool instances. + + The factory is resolved at runtime in McpTool._run_async_impl, not at + tool creation time. This allows the factory to receive ReadonlyContext. + """ + + def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs): + async def callback( + progress: float, total: float | None, message: str | None + ) -> None: + pass + + return callback + + mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = McpToolset( + connection_params=self.mock_stdio_params, + progress_callback=my_callback_factory, + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + # Factory is passed directly to each tool (resolved at runtime) + for tool in tools: + assert tool._progress_callback == my_callback_factory + @pytest.mark.asyncio async def test_list_resources(self): """Test listing resources."""