feat: Add progress_callback support to MCPTool and MCPToolset

Fixes: https://github.com/google/adk-python/issues/3811

Co-authored-by: Xuan Yang <xygoogle@google.com>
PiperOrigin-RevId: 866025995
This commit is contained in:
Xuan Yang
2026-02-05 11:04:01 -08:00
committed by Copybara-Service
parent 9b112e2d13
commit adbc37fea1
7 changed files with 695 additions and 19 deletions
@@ -0,0 +1,15 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,166 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample agent demonstrating MCP progress callback feature.
This sample shows how to use the progress_callback parameter in McpToolset
to receive progress notifications from MCP servers during long-running tool
executions.
There are two ways to use progress callbacks:
1. Simple callback (shared by all tools):
Pass a ProgressFnT callback that receives (progress, total, message).
2. Factory function (per-tool callbacks with runtime context):
Pass a ProgressCallbackFactory that takes (tool_name, callback_context, **kwargs)
and returns a ProgressFnT or None. This allows different tools to have different
progress handling logic, and the factory can access and modify session state
via the CallbackContext. The **kwargs ensures forward compatibility for future
parameters.
IMPORTANT: Progress callbacks only work when the MCP server actually sends
progress notifications. Most simple MCP servers (like the filesystem server)
do not send progress updates. This sample uses a mock server that demonstrates
progress reporting.
Usage:
adk run contributing/samples/mcp_progress_callback_agent
Then try:
"Run the long running task with 5 steps"
"Process these items: apple, banana, cherry"
"""
import os
import sys
from typing import Any
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool import StdioConnectionParams
from mcp import StdioServerParameters
from mcp.shared.session import ProgressFnT
_current_dir = os.path.dirname(os.path.abspath(__file__))
_mock_server_path = os.path.join(_current_dir, "mock_progress_server.py")
# Option 1: Simple shared callback
async def simple_progress_callback(
progress: float,
total: float | None,
message: str | None,
) -> None:
"""Handle progress notifications from MCP server.
This callback is shared by all tools in the toolset.
"""
if total is not None:
percentage = (progress / total) * 100
bar_length = 20
filled = int(bar_length * progress / total)
bar = "=" * filled + "-" * (bar_length - filled)
print(f"[{bar}] {percentage:.0f}% ({progress}/{total}) {message or ''}")
else:
print(f"Progress: {progress} {f'- {message}' if message else ''}")
# Option 2: Factory function for per-tool callbacks with runtime context
def progress_callback_factory(
tool_name: str,
*,
callback_context: CallbackContext | None = None,
**kwargs: Any,
) -> ProgressFnT | None:
"""Create a progress callback for a specific tool.
This factory allows different tools to have different progress handling.
It receives a CallbackContext for accessing and modifying runtime information
like session state. The **kwargs parameter ensures forward compatibility.
Args:
tool_name: The name of the MCP tool.
callback_context: The callback context providing access to session,
state, artifacts, and other runtime information. Allows modifying
state via ctx.state['key'] = value. May be None if not available.
**kwargs: Additional keyword arguments for future extensibility.
Returns:
A progress callback function, or None if no callback is needed.
"""
# Example: Access session info from context (if available)
session_id = "unknown"
if callback_context and callback_context.session:
session_id = callback_context.session.id
async def callback(
progress: float,
total: float | None,
message: str | None,
) -> None:
# Include tool name and session info in the progress output
prefix = f"[{tool_name}][session:{session_id}]"
if total is not None:
percentage = (progress / total) * 100
bar_length = 20
filled = int(bar_length * progress / total)
bar = "=" * filled + "-" * (bar_length - filled)
print(f"{prefix} [{bar}] {percentage:.0f}% {message or ''}")
# Example: Store progress in state (callback_context allows modification)
if callback_context:
callback_context.state["last_progress"] = progress
callback_context.state["last_total"] = total
else:
print(
f"{prefix} Progress: {progress} {f'- {message}' if message else ''}"
)
return callback
root_agent = LlmAgent(
model="gemini-2.5-flash",
name="progress_demo_agent",
instruction="""\
You are a helpful assistant that can run long-running tasks.
Available tools:
- long_running_task: Simulates a task with multiple steps. You can specify
the number of steps and delay between them.
- process_items: Processes a list of items one by one with progress updates.
When the user asks you to run a task, use these tools and the progress
will be logged automatically.
Example requests:
- "Run a long task with 5 steps"
- "Process these items: apple, banana, cherry, date"
""",
tools=[
McpToolset(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
command=sys.executable, # Use current Python interpreter
args=[_mock_server_path],
),
timeout=60,
),
# Use factory function for per-tool callbacks (Option 2)
# Or use simple_progress_callback for shared callback (Option 1)
progress_callback=progress_callback_factory,
)
],
)
@@ -0,0 +1,161 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mock MCP server that sends progress notifications.
This server demonstrates how MCP servers can send progress updates
during long-running tool execution.
Run this server directly:
python mock_progress_server.py
Or use it with the sample agent:
See agent_with_mock_server.py
"""
import asyncio
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent
from mcp.types import Tool
server = Server("mock-progress-server")
@server.list_tools()
async def list_tools() -> list[Tool]:
"""List available tools."""
return [
Tool(
name="long_running_task",
description=(
"A simulated long-running task that reports progress. "
"Use this to test progress callback functionality."
),
inputSchema={
"type": "object",
"properties": {
"steps": {
"type": "integer",
"description": "Number of steps to simulate (default: 5)",
"default": 5,
},
"delay": {
"type": "number",
"description": (
"Delay in seconds between steps (default: 0.5)"
),
"default": 0.5,
},
},
},
),
Tool(
name="process_items",
description="Process a list of items with progress reporting.",
inputSchema={
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {"type": "string"},
"description": "List of items to process",
},
},
"required": ["items"],
},
),
]
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle tool calls with progress reporting."""
ctx = server.request_context
if name == "long_running_task":
steps = arguments.get("steps", 5)
delay = arguments.get("delay", 0.5)
# Get progress token from request metadata
progress_token = None
if ctx.meta and hasattr(ctx.meta, "progressToken"):
progress_token = ctx.meta.progressToken
for i in range(steps):
# Simulate work
await asyncio.sleep(delay)
# Send progress notification if client supports it
if progress_token is not None:
await ctx.session.send_progress_notification(
progress_token=progress_token,
progress=i + 1,
total=steps,
message=f"Completed step {i + 1} of {steps}",
)
return [
TextContent(
type="text",
text=f"Successfully completed {steps} steps!",
)
]
elif name == "process_items":
items = arguments.get("items", [])
total = len(items)
progress_token = None
if ctx.meta and hasattr(ctx.meta, "progressToken"):
progress_token = ctx.meta.progressToken
results = []
for i, item in enumerate(items):
# Simulate processing
await asyncio.sleep(0.3)
results.append(f"Processed: {item}")
# Send progress
if progress_token is not None:
await ctx.session.send_progress_notification(
progress_token=progress_token,
progress=i + 1,
total=total,
message=f"Processing item: {item}",
)
return [
TextContent(
type="text",
text="\n".join(results),
)
]
return [TextContent(type="text", text=f"Unknown tool: {name}")]
async def main():
"""Run the MCP server."""
async with stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
)
if __name__ == "__main__":
asyncio.run(main())
+128 -4
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
import asyncio
import base64
import inspect
import logging
@@ -21,14 +22,18 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Protocol
from typing import runtime_checkable
from typing import Union
import warnings
from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
from mcp.shared.session import ProgressFnT
from mcp.types import Tool as McpBaseTool
from typing_extensions import override
from ...agents.callback_context import CallbackContext
from ...agents.readonly_context import ReadonlyContext
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
@@ -37,7 +42,6 @@ from ...features import FeatureName
from ...features import is_feature_enabled
from .._gemini_schema_util import _to_gemini_schema
from ..base_authenticated_tool import BaseAuthenticatedTool
# import
from ..tool_context import ToolContext
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
@@ -45,6 +49,68 @@ from .mcp_session_manager import retry_on_errors
logger = logging.getLogger("google_adk." + __name__)
@runtime_checkable
class ProgressCallbackFactory(Protocol):
"""Factory protocol for creating per-tool progress callbacks.
This protocol allows users to create different progress callbacks for
different tools based on tool name and runtime context. The factory receives
the tool name, a CallbackContext for accessing and modifying session state,
and additional keyword arguments for forward compatibility.
Example usage::
def my_callback_factory(
tool_name: str,
*,
callback_context: CallbackContext | None = None,
**kwargs
) -> ProgressFnT | None:
session_id = callback_context.session.id if callback_context else "N/A"
async def callback(progress, total, message):
print(f"[{tool_name}] Session {session_id}: {progress}/{total}")
# Can modify state in the callback
if callback_context:
callback_context.state['last_progress'] = progress
return callback
toolset = McpToolset(
connection_params=...,
progress_callback=my_callback_factory,
)
Note:
The **kwargs parameter is required for forward compatibility. Future
versions may pass additional parameters. Implementations should accept
**kwargs even if they don't use them.
"""
def __call__(
self,
tool_name: str,
*,
callback_context: Optional[CallbackContext] = None,
**kwargs: Any,
) -> Optional[ProgressFnT]:
"""Create a progress callback for a specific tool.
Args:
tool_name: The name of the MCP tool.
callback_context: The callback context providing access to session,
state, artifacts, and other runtime information. Allows modifying
state via ctx.state['key'] = value. May be None if not available.
**kwargs: Additional keyword arguments for future extensibility.
Implementations should accept **kwargs for forward compatibility.
Returns:
A progress callback function, or None if no callback is needed
for this tool.
"""
...
class McpTool(BaseAuthenticatedTool):
"""Turns an MCP Tool into an ADK Tool.
@@ -66,6 +132,9 @@ class McpTool(BaseAuthenticatedTool):
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
progress_callback: Optional[
Union[ProgressFnT, ProgressCallbackFactory]
] = None,
):
"""Initializes an McpTool.
@@ -81,6 +150,17 @@ class McpTool(BaseAuthenticatedTool):
or a callable that takes the function's arguments and returns a
boolean. If the callable returns True, the tool will require
confirmation from the user.
header_provider: Optional function to provide dynamic headers.
progress_callback: Optional callback to receive progress notifications
from MCP server during long-running tool execution. Can be either:
- A ``ProgressFnT`` callback that receives (progress, total, message).
This callback will be used for all invocations.
- A ``ProgressCallbackFactory`` that creates per-invocation callbacks.
The factory receives (tool_name, callback_context, **kwargs) and
returns a ProgressFnT or None. This allows callbacks to access
and modify runtime context like session state.
Raises:
ValueError: If mcp_tool or mcp_session_manager is None.
@@ -98,6 +178,7 @@ class McpTool(BaseAuthenticatedTool):
self._mcp_session_manager = mcp_session_manager
self._require_confirmation = require_confirmation
self._header_provider = header_provider
self._progress_callback = progress_callback
@override
def _get_declaration(self) -> FunctionDeclaration:
@@ -237,9 +318,50 @@ class McpTool(BaseAuthenticatedTool):
headers=final_headers
)
response = await session.call_tool(self._mcp_tool.name, arguments=args)
# Resolve progress callback (may be a factory that needs runtime context)
resolved_callback = self._resolve_progress_callback(tool_context)
response = await session.call_tool(
self._mcp_tool.name,
arguments=args,
progress_callback=resolved_callback,
)
return response.model_dump(exclude_none=True, mode="json")
def _resolve_progress_callback(
self, tool_context: ToolContext
) -> Optional[ProgressFnT]:
"""Resolve the progress callback for the current invocation.
If progress_callback is a ProgressCallbackFactory, call it to create
a callback with runtime context. Otherwise, return the callback directly.
Args:
tool_context: The tool context for the current invocation.
Returns:
The resolved progress callback, or None if not configured.
"""
if (
not hasattr(self, "_progress_callback")
or self._progress_callback is None
):
return None
# Determine if callback is a factory by checking if it's a coroutine
# function. ProgressFnT is an async function, while ProgressCallbackFactory
# is a sync function that returns an async function.
if asyncio.iscoroutinefunction(self._progress_callback):
return self._progress_callback
# If it's a regular callable (not async), treat it as a factory
if callable(self._progress_callback) and not inspect.iscoroutinefunction(
self._progress_callback
):
return self._progress_callback(self.name, callback_context=tool_context)
return self._progress_callback
async def _get_headers(
self, tool_context: ToolContext, credential: AuthCredential
) -> Optional[dict[str, str]]:
@@ -253,7 +375,8 @@ class McpTool(BaseAuthenticatedTool):
Dictionary of headers to add to the request, or None if no auth.
Raises:
ValueError: If API key authentication is configured for non-header location.
ValueError: If API key authentication is configured for non-header
location.
"""
headers: Optional[dict[str, str]] = None
if credential:
@@ -284,7 +407,8 @@ class McpTool(BaseAuthenticatedTool):
# Handle other HTTP schemes with token
headers = {
"Authorization": (
f"{credential.http.scheme} {credential.http.credentials.token}"
f"{credential.http.scheme}"
f" {credential.http.credentials.token}"
)
}
elif credential.api_key:
+33 -12
View File
@@ -31,6 +31,7 @@ from typing import Union
import warnings
from mcp import StdioServerParameters
from mcp.shared.session import ProgressFnT
from mcp.types import ListResourcesResult
from mcp.types import ListToolsResult
from pydantic import model_validator
@@ -51,6 +52,7 @@ from .mcp_session_manager import SseConnectionParams
from .mcp_session_manager import StdioConnectionParams
from .mcp_session_manager import StreamableHTTPConnectionParams
from .mcp_tool import MCPTool
from .mcp_tool import ProgressCallbackFactory
logger = logging.getLogger("google_adk." + __name__)
@@ -72,7 +74,8 @@ class McpToolset(BaseToolset):
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
),
tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools
tool_filter=['read_file', 'list_directory'] # Optional: filter specific
tools
)
# Use in an agent
@@ -106,18 +109,21 @@ class McpToolset(BaseToolset):
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
progress_callback: Optional[
Union[ProgressFnT, ProgressCallbackFactory]
] = None,
):
"""Initializes the McpToolset.
Args:
connection_params: The connection parameters to the MCP server. Can be:
``StdioConnectionParams`` for using local mcp server (e.g. using ``npx`` or
``python3``); or ``SseConnectionParams`` for a local/remote SSE server; or
``StreamableHTTPConnectionParams`` for local/remote Streamable http
server. Note, ``StdioServerParameters`` is also supported for using local
mcp server (e.g. using ``npx`` or ``python3`` ), but it does not support
timeout, and we recommend to use ``StdioConnectionParams`` instead when
timeout is needed.
``StdioConnectionParams`` for using local mcp server (e.g. using ``npx``
or ``python3``); or ``SseConnectionParams`` for a local/remote SSE
server; or ``StreamableHTTPConnectionParams`` for local/remote
Streamable http server. Note, ``StdioServerParameters`` is also
supported for using local mcp server (e.g. using ``npx`` or ``python3``
), but it does not support timeout, and we recommend to use
``StdioConnectionParams`` instead when timeout is needed.
tool_filter: Optional filter to select specific tools. Can be either: - A
list of tool names to include - A ToolPredicate function for custom
filtering logic
@@ -126,11 +132,22 @@ class McpToolset(BaseToolset):
errlog: TextIO stream for error logging.
auth_scheme: The auth scheme of the tool for tool calling
auth_credential: The auth credential of the tool for tool calling
require_confirmation: Whether tools in this toolset require
confirmation. Can be a single boolean or a callable to apply to all
tools.
require_confirmation: Whether tools in this toolset require confirmation.
Can be a single boolean or a callable to apply to all tools.
header_provider: A callable that takes a ReadonlyContext and returns a
dictionary of headers to be used for the MCP session.
progress_callback: Optional callback to receive progress notifications
from MCP server during long-running tool execution. Can be either:
- A ``ProgressFnT`` callback that receives (progress, total, message).
This callback will be shared by all tools in the toolset.
- A ``ProgressCallbackFactory`` that creates per-tool callbacks. The
factory receives (tool_name, callback_context, **kwargs) and returns
a ProgressFnT or None. This allows different tools to have different
progress handling logic and access/modify session state via the
CallbackContext. The **kwargs parameter allows for future
extensibility.
"""
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
@@ -140,6 +157,7 @@ class McpToolset(BaseToolset):
self._connection_params = connection_params
self._errlog = errlog
self._header_provider = header_provider
self._progress_callback = progress_callback
# Create the session manager that will handle the MCP connection
self._mcp_session_manager = MCPSessionManager(
@@ -270,7 +288,7 @@ class McpToolset(BaseToolset):
Args:
readonly_context: Context used to filter tools available to the agent.
If None, all tools in the toolset are returned.
If None, all tools in the toolset are returned.
Returns:
List[BaseTool]: A list of tools available under the specified context.
@@ -292,6 +310,9 @@ class McpToolset(BaseToolset):
auth_credential=self._auth_credential,
require_confirmation=self._require_confirmation,
header_provider=self._header_provider,
progress_callback=self._progress_callback
if hasattr(self, "_progress_callback")
else None,
)
if self._is_tool_selected(mcp_tool, readonly_context):
@@ -225,7 +225,7 @@ class TestMCPTool:
)
# Fix: call_tool uses 'arguments' parameter, not positional args
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args
"test_tool", arguments=args, progress_callback=None
)
@pytest.mark.asyncio
@@ -778,7 +778,7 @@ class TestMCPTool:
headers=expected_headers
)
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args
"test_tool", arguments=args, progress_callback=None
)
@pytest.mark.asyncio
@@ -821,5 +821,99 @@ class TestMCPTool:
"X-Tenant-ID": "test-tenant",
}
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args
"test_tool", arguments=args, progress_callback=None
)
def test_init_with_progress_callback(self):
"""Test initialization with progress_callback."""
async def my_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
pass
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
progress_callback=my_progress_callback,
)
assert tool._progress_callback == my_progress_callback
@pytest.mark.asyncio
async def test_run_async_impl_with_progress_callback(self):
"""Test running tool with progress_callback."""
progress_updates = []
async def my_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
progress_updates.append((progress, total, message))
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
progress_callback=my_progress_callback,
)
# Mock the session response
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
result = await tool._run_async_impl(
args=args, tool_context=tool_context, credential=None
)
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
self.mock_session_manager.create_session.assert_called_once_with(
headers=None
)
# Verify progress_callback was passed to call_tool
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=my_progress_callback
)
@pytest.mark.asyncio
async def test_run_async_impl_with_progress_callback_factory(self):
"""Test running tool with progress_callback factory that receives context."""
factory_calls = []
def my_callback_factory(tool_name: str, *, callback_context=None, **kwargs):
factory_calls.append((tool_name, callback_context))
async def callback(
progress: float, total: float | None, message: str | None
) -> None:
pass
return callback
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
progress_callback=my_callback_factory,
)
# Mock the session response
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
await tool._run_async_impl(
args=args, tool_context=tool_context, credential=None
)
# Verify factory was called with tool name and tool_context as callback_context
assert len(factory_calls) == 1
assert factory_calls[0][0] == "test_tool"
# callback_context is the tool_context itself (ToolContext extends CallbackContext)
assert factory_calls[0][1] is tool_context
@@ -360,6 +360,101 @@ class TestMcpToolset:
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"
def test_init_with_progress_callback(self):
"""Test initialization with progress_callback."""
async def my_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
pass
toolset = McpToolset(
connection_params=self.mock_stdio_params,
progress_callback=my_progress_callback,
)
assert toolset._progress_callback == my_progress_callback
@pytest.mark.asyncio
async def test_get_tools_passes_progress_callback_to_mcp_tools(self):
"""Test that get_tools passes progress_callback to created MCPTool instances."""
progress_updates = []
async def my_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
progress_updates.append((progress, total, message))
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
toolset = McpToolset(
connection_params=self.mock_stdio_params,
progress_callback=my_progress_callback,
)
toolset._mcp_session_manager = self.mock_session_manager
tools = await toolset.get_tools()
assert len(tools) == 2
# Verify each tool has the progress_callback set
for tool in tools:
assert tool._progress_callback == my_progress_callback
def test_init_with_progress_callback_factory(self):
"""Test initialization with a ProgressCallbackFactory."""
def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs):
async def callback(
progress: float, total: float | None, message: str | None
) -> None:
pass
return callback
toolset = McpToolset(
connection_params=self.mock_stdio_params,
progress_callback=my_callback_factory,
)
assert toolset._progress_callback == my_callback_factory
@pytest.mark.asyncio
async def test_get_tools_passes_factory_to_mcp_tools(self):
"""Test that get_tools passes factory directly to MCPTool instances.
The factory is resolved at runtime in McpTool._run_async_impl, not at
tool creation time. This allows the factory to receive ReadonlyContext.
"""
def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs):
async def callback(
progress: float, total: float | None, message: str | None
) -> None:
pass
return callback
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
toolset = McpToolset(
connection_params=self.mock_stdio_params,
progress_callback=my_callback_factory,
)
toolset._mcp_session_manager = self.mock_session_manager
tools = await toolset.get_tools()
assert len(tools) == 2
# Factory is passed directly to each tool (resolved at runtime)
for tool in tools:
assert tool._progress_callback == my_callback_factory
@pytest.mark.asyncio
async def test_list_resources(self):
"""Test listing resources."""