e79aa3c0ed
Former-commit-id: a2155e9bd80020e49e72e86c44da02a8ac0e57a4
325 lines
14 KiB
C#
325 lines
14 KiB
C#
//------------------------------------------------------------------------------
|
|
// <copyright file="SqlConnectionHelper.cs" company="Microsoft">
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace System.Web.DataAccess {
|
|
|
|
using System;
|
|
using System.Collections.Specialized;
|
|
using System.Configuration;
|
|
using System.Configuration.Provider;
|
|
using System.Data;
|
|
using System.Data.SqlClient;
|
|
using System.Diagnostics;
|
|
using System.Globalization;
|
|
using System.IO;
|
|
using System.Security.Permissions;
|
|
using System.Threading;
|
|
using System.Web.Configuration;
|
|
using System.Web.Hosting;
|
|
using System.Web.Management;
|
|
using System.Web.Util;
|
|
|
|
internal static class SqlConnectionHelper {
    internal const string s_strDataDir = "DataDirectory";
    internal const string s_strUpperDataDirWithToken = "|DATADIRECTORY|";
    internal const string s_strSqlExprFileExt = ".MDF";
    internal const string s_strUpperUserInstance = "USER INSTANCE";
    private const string s_localDbName = "(LOCALDB)";

    // Serializes the "does the .MDF exist? if not, create it" step in EnsureDBFile so two
    // threads in this AppDomain do not both try to create the same file.
    private static readonly object s_lock = new object();

    /// <summary>
    /// Throws if <paramref name="connectionString"/> enables User Instance.
    /// GetConnection calls this only for LocalDB connection strings.
    /// </summary>
    /// <exception cref="ProviderException">The connection string sets User Instance=true.</exception>
    internal static void EnsureNoUserInstance(string connectionString) {
        SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString);
        if (builder.UserInstance) {
            throw new ProviderException(SR.GetString(SR.LocalDB_cannot_have_userinstance_flag));
        }
    }

    /// <summary>
    /// Creates and opens a <see cref="SqlConnectionHolder"/> for <paramref name="connectionString"/>.
    /// </summary>
    /// <remarks>
    /// If the connection string references |DataDirectory|, EnsureDBFile first creates the
    /// database file when it is missing. LocalDB connection strings must not enable User Instance.
    /// </remarks>
    /// <param name="connectionString">Connection string to open.</param>
    /// <param name="revertImpersonation">Forwarded to <see cref="SqlConnectionHolder.Open"/>.</param>
    /// <returns>An opened holder. If opening fails, the exception propagates.</returns>
    internal static SqlConnectionHolder GetConnection(string connectionString, bool revertImpersonation) {
        string strTempConnection = connectionString.ToUpperInvariant();
        if (strTempConnection.Contains(s_strUpperDataDirWithToken)) {
            EnsureDBFile(connectionString);
        }

        // Only block UserInstance for LocalDB connections
        if (strTempConnection.Contains(s_localDbName)) {
            EnsureNoUserInstance(connectionString);
        }

        SqlConnectionHolder holder = new SqlConnectionHolder(connectionString);
        bool closeConn = true;
        try {
            try {
                holder.Open(null, revertImpersonation);
                closeConn = false;
            }
            finally {
                if (closeConn) {
                    holder.Close();
                    holder = null;
                }
            }
        }
        catch {
            // NOTE(review): looks deliberate. A catch-all that rethrows ends the first
            // exception-handling pass in this frame. As a result, exception filters further
            // up the stack cannot run before the impersonation context used inside
            // SqlConnectionHolder.Open has been disposed. Confirm before removing.
            throw;
        }
        return holder;
    }

    /// <summary>
    /// Resolves the connection string to use.
    /// </summary>
    /// <param name="specifiedConnectionString">
    /// Either a literal connection string, or the name of a &lt;connectionStrings&gt; entry.
    /// </param>
    /// <param name="lookupConnectionString">
    /// True to treat <paramref name="specifiedConnectionString"/> as a &lt;connectionStrings&gt; name.
    /// </param>
    /// <param name="appLevel">True to read the application-level config instead of the current config.</param>
    /// <returns>
    /// The resolved connection string, or null when the input is null/empty or the named entry
    /// is missing or has no connection string.
    /// </returns>
    internal static string GetConnectionString(string specifiedConnectionString, bool lookupConnectionString, bool appLevel) {
        System.Web.Util.Debug.Assert((specifiedConnectionString != null) && (specifiedConnectionString.Length != 0));
        if (string.IsNullOrEmpty(specifiedConnectionString)) {
            return null;
        }

        if (!lookupConnectionString) {
            return specifiedConnectionString;
        }

        // Check the <connectionStrings> config section for this name.
        RuntimeConfig config = appLevel ? RuntimeConfig.GetAppConfig() : RuntimeConfig.GetConfig();
        ConnectionStringSettings connObj = config.ConnectionStrings.ConnectionStrings[specifiedConnectionString];
        return (connObj != null) ? connObj.ConnectionString : null;
    }

    /// <summary>
    /// Returns the directory that |DataDirectory| refers to.
    /// </summary>
    /// <remarks>
    /// - When hosted, this is AppDomainAppPath combined with HttpRuntime.DataDirectoryName.
    /// - Otherwise, it is the AppDomain's "DataDirectory" value if set.
    /// - If that is not set, it is computed from the main module's directory (or the current
    ///   directory as a fallback) and stored back into the AppDomain.
    /// </remarks>
    [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
    internal static string GetDataDirectory() {
        if (HostingEnvironment.IsHosted) {
            return Path.Combine(HttpRuntime.AppDomainAppPath, HttpRuntime.DataDirectoryName);
        }

        string dataDir = AppDomain.CurrentDomain.GetData(s_strDataDir) as string;
        if (string.IsNullOrEmpty(dataDir)) {
            string appPath = null;

#if !FEATURE_PAL // FEATURE_PAL does not support ProcessModule
            // Process is IDisposable; release it once the main module's path has been read.
            using (Process p = Process.GetCurrentProcess()) {
                ProcessModule pm = (p != null ? p.MainModule : null);
                string exeName = (pm != null ? pm.FileName : null);
                if (!string.IsNullOrEmpty(exeName)) {
                    appPath = Path.GetDirectoryName(exeName);
                }
            }
#endif // !FEATURE_PAL

            if (string.IsNullOrEmpty(appPath)) {
                appPath = Environment.CurrentDirectory;
            }

            dataDir = Path.Combine(appPath, HttpRuntime.DataDirectoryName);
            AppDomain.CurrentDomain.SetData(s_strDataDir, dataDir, new FileIOPermission(FileIOPermissionAccess.PathDiscovery, dataDir));
        }

        return dataDir;
    }

    /// <summary>
    /// Creates the .MDF file named after |DataDirectory| in the connection string when it does
    /// not exist yet.
    /// </summary>
    /// <remarks>
    /// Nothing is done in these cases:
    /// - the connection string does not set User Instance=true (LocalDB connection strings are
    ///   exempt from this requirement);
    /// - the file already exists.
    /// Otherwise the file is created through a rewritten connection string that:
    /// - replaces the |DataDirectory| segment with Pooling=false;
    /// - targets master;
    /// - adds a 45 second Connect Timeout if none was specified.
    /// </remarks>
    /// <exception cref="ProviderException">
    /// Thrown when no usable file name follows |DataDirectory| (empty, or containing ".."), or
    /// when the trust level does not allow creating the file.
    /// </exception>
    private static void EnsureDBFile(string connectionString) {
        string fullFileName = null;
        string dataDir = GetDataDirectory();
        bool lookingForDataDir = true;
        bool lookingForDB = true;
        string[] splitedConnStr = connectionString.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
        bool lookingForUserInstance = !connectionString.ToUpperInvariant().Contains(s_localDbName); // We don't require UserInstance=True for LocalDb
        bool lookingForTimeout = true;

        // Inspect every segment, with no early exit. User Instance and Connect Timeout may appear
        // after the data-directory and database segments, and both must still be detected.
        foreach (string str in splitedConnStr) {
            string strUpper = str.ToUpper(CultureInfo.InvariantCulture).Trim();

            if (lookingForDataDir && strUpper.Contains(s_strUpperDataDirWithToken)) {
                lookingForDataDir = false;

                // Replace the AttachDBFilename part with "Pooling=false"
                connectionString = connectionString.Replace(str, "Pooling=false");

                // Extract the file name that follows the token. Leading separators are stripped
                // so Path.Combine cannot treat the name as rooted and discard dataDir.
                int startPos = strUpper.IndexOf(s_strUpperDataDirWithToken, StringComparison.Ordinal) + s_strUpperDataDirWithToken.Length;
                string partialFileName = strUpper.Substring(startPos).Trim().TrimStart('\\', '/');

                // Reject an empty name and anything that could traverse up out of the data directory.
                if (partialFileName.Length != 0 && !partialFileName.Contains("..")) {
                    fullFileName = Path.Combine(dataDir, partialFileName);
                }
            }
            else if (lookingForDB && (strUpper.StartsWith("INITIAL CATALOG", StringComparison.Ordinal) || strUpper.StartsWith("DATABASE", StringComparison.Ordinal))) {
                lookingForDB = false;
                connectionString = connectionString.Replace(str, "Database=master");
            }
            else if (lookingForUserInstance && strUpper.StartsWith(s_strUpperUserInstance, StringComparison.Ordinal)) {
                lookingForUserInstance = false;
                int pos = strUpper.IndexOf('=');
                if (pos < 0) {
                    return;
                }
                string strTemp = strUpper.Substring(pos + 1).Trim();
                if (strTemp != "TRUE") {
                    // Only User Instance=true connections get their database file auto-created.
                    return;
                }
            }
            else if (lookingForTimeout && strUpper.StartsWith("CONNECT TIMEOUT", StringComparison.Ordinal)) {
                lookingForTimeout = false;
            }
        }

        if (lookingForUserInstance) {
            // Not LocalDB and no User Instance setting: do not auto-create.
            return;
        }

        if (fullFileName == null) {
            throw new ProviderException(SR.GetString(SR.SqlExpress_file_not_found_in_connection_string));
        }

        if (File.Exists(fullFileName)) {
            return;
        }

        if (!HttpRuntime.HasAspNetHostingPermission(AspNetHostingPermissionLevel.High)) {
            throw new ProviderException(SR.GetString(SR.Provider_can_not_create_file_in_this_trust_level));
        }

        // The target database does not exist yet, so the creation connection targets master.
        if (!connectionString.Contains("Database=master")) {
            connectionString += ";Database=master";
        }
        if (lookingForTimeout) {
            connectionString += ";Connect Timeout=45";
        }

        using (new ApplicationImpersonationContext()) {
            lock (s_lock) {
                // Re-check under the lock: another thread may have created the file meanwhile.
                if (!File.Exists(fullFileName)) {
                    CreateMdfFile(fullFileName, dataDir, connectionString);
                }
            }
        }
    }

    /// <summary>
    /// Creates <paramref name="fullFileName"/> as a new database file.
    /// </summary>
    /// <remarks>
    /// Creates <paramref name="dataDir"/> if needed. The database is installed under a temporary
    /// "_TMP.MDF" name next to the target, detached, and then moved (or copied) into place.
    /// </remarks>
    /// <exception cref="HttpException">
    /// Thrown with an error formatter when a request context exists and custom errors are off.
    /// Otherwise the original exception is rethrown.
    /// </exception>
    [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
    private static void CreateMdfFile(string fullFileName, string dataDir, string connectionString) {
        bool creatingDir = false;
        string databaseName = null;
        HttpContext context = HttpContext.Current;
        string tempFileName = null;

        try {
            if (!Directory.Exists(dataDir)) {
                creatingDir = true;
                Directory.CreateDirectory(dataDir);
                creatingDir = false;
                try {
                    if (context != null) {
                        HttpRuntime.RestrictIISFolders(context);
                    }
                }
                catch {
                    // Best effort: a failure here does not block creating the database.
                }
            }

            fullFileName = fullFileName.ToUpper(CultureInfo.InvariantCulture);

            // Build the database name from the file name:
            // - non-alphanumerics become '_';
            // - the stem is capped at 30 characters;
            // - a GUID suffix makes the name unique.
            char[] strippedFileNameChars = Path.GetFileNameWithoutExtension(fullFileName).ToCharArray();
            for (int iter = 0; iter < strippedFileNameChars.Length; iter++) {
                if (!char.IsLetterOrDigit(strippedFileNameChars[iter])) {
                    strippedFileNameChars[iter] = '_';
                }
            }
            string strippedFileName = new string(strippedFileNameChars);
            string databaseStem = (strippedFileName.Length > 30) ? strippedFileName.Substring(0, 30) : strippedFileName;
            databaseName = databaseStem + "_" + Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture);

            tempFileName = Path.Combine(Path.GetDirectoryName(fullFileName), strippedFileName + "_TMP" + s_strSqlExprFileExt);

            // Auto create the temporary database
            SqlServices.Install(databaseName, tempFileName, connectionString);
            DetachDB(databaseName, connectionString);
            try {
                File.Move(tempFileName, fullFileName);
            }
            catch {
                // Move failed: fall back to copying, unless the target file already exists.
                if (!File.Exists(fullFileName)) {
                    File.Copy(tempFileName, fullFileName);
                }
                // Best effort: the temporary .MDF is no longer needed in either case.
                try {
                    File.Delete(tempFileName);
                }
                catch { }
            }

            // Best effort: remove the temporary log file (presumably created by SqlServices.Install).
            try {
                File.Delete(tempFileName.Replace("_TMP.MDF", "_TMP_log.LDF"));
            }
            catch { }
        }
        catch (Exception e) {
            // Without a request context, or when custom errors are enabled, rethrow unchanged.
            if (context == null || context.IsCustomErrorEnabled) {
                throw;
            }

            // Otherwise wrap the exception in an HttpException whose formatter describes the
            // SQL Express failure.
            HttpException httpExec = new HttpException(e.Message, e);
            if (e is UnauthorizedAccessException) {
                httpExec.SetFormatter(new SqlExpressConnectionErrorFormatter(creatingDir ? DataConnectionErrorEnum.CanNotCreateDataDir : DataConnectionErrorEnum.CanNotWriteToDataDir));
            }
            else {
                httpExec.SetFormatter(new SqlExpressDBFileAutoCreationErrorFormatter(e));
            }
            throw httpExec;
        }
    }

    /// <summary>
    /// Detaches <paramref name="databaseName"/> via sp_detach_db (with @skipchecks = 'true').
    /// </summary>
    /// <remarks>
    /// Failures after the connection is constructed are deliberately swallowed.
    /// NOTE(review): if the database stays attached, the file move/copy that follows in
    /// CreateMdfFile will likely fail and report the problem.
    /// </remarks>
    private static void DetachDB(string databaseName, string connectionString) {
        using (SqlConnection connection = new SqlConnection(connectionString)) {
            try {
                connection.Open();
                using (SqlCommand useMaster = new SqlCommand("USE master", connection)) {
                    useMaster.ExecuteNonQuery();
                }
                using (SqlCommand detach = new SqlCommand("sp_detach_db", connection)) {
                    detach.CommandType = CommandType.StoredProcedure;
                    detach.Parameters.AddWithValue("@dbname", databaseName);
                    detach.Parameters.AddWithValue("@skipchecks", "true");
                    detach.ExecuteNonQuery();
                }
            }
            catch {
                // Best effort; see remarks.
            }
        }
    }
}
|
|
|
|
internal sealed class SqlConnectionHolder {
    // The wrapped connection. It is exposed both as this internal field and through Connection.
    internal SqlConnection _Connection;

    // True after a successful Open and until Close; makes Open and Close idempotent.
    private bool _isOpen;

    /// <summary>The SqlConnection created by the constructor.</summary>
    internal SqlConnection Connection {
        get {
            return _Connection;
        }
    }

    /// <summary>
    /// Wraps a new, unopened SqlConnection built from <paramref name="connectionString"/>.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// SqlConnection rejected the connection string. The original exception is the inner exception.
    /// </exception>
    internal SqlConnectionHolder(string connectionString) {
        SqlConnection created;
        try {
            created = new SqlConnection(connectionString);
        }
        catch (ArgumentException e) {
            throw new ArgumentException(SR.GetString(SR.SqlError_Connection_String), "connectionString", e);
        }

        _Connection = created;
        System.Web.Util.Debug.Assert(_Connection != null);
    }

    /// <summary>
    /// Opens the connection unless it is already open.
    /// </summary>
    /// <param name="context">Not used by this implementation.</param>
    /// <param name="revertImpersonate">
    /// True to open inside an ApplicationImpersonationContext (presumably reverting to the
    /// application identity for the duration of the open).
    /// </param>
    internal void Open(HttpContext context, bool revertImpersonate) {
        if (!_isOpen) {
            if (!revertImpersonate) {
                _Connection.Open();
            }
            else {
                using (new ApplicationImpersonationContext()) {
                    _Connection.Open();
                }
            }

            // Only reached when Open did not throw.
            _isOpen = true;
        }
    }

    /// <summary>Closes the connection if Open previously succeeded; otherwise does nothing.</summary>
    internal void Close() {
        if (_isOpen) {
            _Connection.Close();
            _isOpen = false;
        }
    }
}
|
|
}
|
|
|