Jo Shields a575963da9 Imported Upstream version 3.6.0
Former-commit-id: da6be194a6b1221998fc28233f2503bd61dd9d14
2014-08-13 10:39:27 +01:00

422 lines
17 KiB
C#

// Copyright (c) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Configuration.Provider;
using System.Globalization;
using System.Linq;
using System.Web.Security;
using WebMatrix.WebData.Resources;
namespace WebMatrix.WebData
{
public class SimpleRoleProvider : RoleProvider
{
    // Provider that handles every call made before this provider is initialized.
    // Only assigned in the constructor.
    private readonly RoleProvider _previousProvider;

    /// <summary>
    /// Creates a provider with no previous provider. Until <see cref="InitializeCalled"/> is set,
    /// every forwarded call throws <see cref="InvalidOperationException"/>.
    /// </summary>
    public SimpleRoleProvider()
        : this(null)
    {
    }

    /// <summary>
    /// Creates a provider that forwards calls to <paramref name="previousProvider"/> until this
    /// provider has been initialized.
    /// </summary>
    /// <param name="previousProvider">Provider to delegate to before initialization; may be null.</param>
    public SimpleRoleProvider(RoleProvider previousProvider)
    {
        _previousProvider = previousProvider;
    }

    // The provider used while InitializeCalled is false.
    // Throws InvalidOperationException when no previous provider was supplied.
    private RoleProvider PreviousProvider
    {
        get
        {
            if (_previousProvider == null)
            {
                throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
            }
            return _previousProvider;
        }
    }

    // Bracket-quoted forms of the application-supplied user table/column names, for use in SQL text.
    // NOTE(review): a ']' inside a configured name is not escaped here.
    private string SafeUserTableName
    {
        get { return "[" + UserTableName + "]"; }
    }

    private string SafeUserNameColumn
    {
        get { return "[" + UserNameColumn + "]"; }
    }

    private string SafeUserIdColumn
    {
        get { return "[" + UserIdColumn + "]"; }
    }

    // Name of the table holding (RoleId, RoleName) rows.
    internal static string RoleTableName
    {
        get { return "webpages_Roles"; }
    }

    // Name of the join table holding (UserId, RoleId) membership rows.
    internal static string UsersInRoleTableName
    {
        get { return "webpages_UsersInRoles"; }
    }

    /// <summary>Name of the application's user table.</summary>
    public string UserTableName { get; set; }

    /// <summary>Name of the user-name column in the user table (for example, an e-mail column).</summary>
    public string UserNameColumn { get; set; }

    /// <summary>Name of the user-id column in the user table.</summary>
    // REVIEW: we could get this from the primary key of UserTable in the future
    public string UserIdColumn { get; set; }

    // Connection settings used by ConnectToDatabase.
    internal DatabaseConnectionInfo ConnectionInfo { get; set; }

    // When false, every RoleProvider override forwards to PreviousProvider.
    // Presumably set by the code that initializes this provider (not visible here).
    internal bool InitializeCalled { get; set; }

    /// <summary>
    /// Forwarded to the previous provider before initialization.
    /// </summary>
    /// <exception cref="NotSupportedException">This provider has been initialized.</exception>
    public override string ApplicationName
    {
        get
        {
            if (InitializeCalled)
            {
                throw new NotSupportedException();
            }
            return PreviousProvider.ApplicationName;
        }
        set
        {
            if (InitializeCalled)
            {
                throw new NotSupportedException();
            }
            PreviousProvider.ApplicationName = value;
        }
    }

    // Throws InvalidOperationException unless InitializeCalled is true.
    private void VerifyInitialized()
    {
        if (!InitializeCalled)
        {
            throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
        }
    }

    // Opens a new database connection; callers own and dispose it.
    private IDatabase ConnectToDatabase()
    {
        return new DatabaseWrapper(ConnectionInfo.Connect());
    }

    // Creates the roles table and the users-in-roles join table when the roles table is absent.
    // Only the roles table is checked; both tables are created together.
    internal void CreateTablesIfNeeded()
    {
        using (var db = ConnectToDatabase())
        {
            if (!SimpleMembershipProvider.CheckTableExists(db, RoleTableName))
            {
                db.Execute(@"CREATE TABLE " + RoleTableName + @" (
                    RoleId int NOT NULL PRIMARY KEY IDENTITY,
                    RoleName nvarchar(256) NOT NULL UNIQUE)");
                db.Execute(@"CREATE TABLE " + UsersInRoleTableName + @" (
                    UserId int NOT NULL,
                    RoleId int NOT NULL,
                    PRIMARY KEY (UserId, RoleId),
                    CONSTRAINT fk_UserId FOREIGN KEY (UserId) REFERENCES " + SafeUserTableName + "(" + SafeUserIdColumn + @"),
                    CONSTRAINT fk_RoleId FOREIGN KEY (RoleId) REFERENCES " + RoleTableName + "(RoleId) )");
            }
        }
    }

    // Resolves each user name to its id, in input order.
    // Throws InvalidOperationException for the first name that has no matching user.
    private List<int> GetUserIdsFromNames(IDatabase db, string[] usernames)
    {
        List<int> userIds = new List<int>(usernames.Length);
        foreach (string username in usernames)
        {
            int id = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username);
            if (id == -1)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username));
            }
            userIds.Add(id);
        }
        return userIds;
    }

    // Resolves each role name to its id, in input order.
    // Throws InvalidOperationException for the first name that has no matching role.
    private static List<int> GetRoleIdsFromNames(IDatabase db, string[] roleNames)
    {
        List<int> roleIds = new List<int>(roleNames.Length);
        foreach (string role in roleNames)
        {
            int id = FindRoleId(db, role);
            if (id == -1)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_NoRoleFound, role));
            }
            roleIds.Add(id);
        }
        return roleIds;
    }

    // Returns true when the users-in-roles table already links userId to roleId.
    // Uses the caller's open connection.
    private static bool IsUserIdInRoleId(IDatabase db, int userId, int roleId)
    {
        var result = db.QuerySingle("SELECT COUNT(*) FROM " + UsersInRoleTableName + " WHERE (UserId = @0 and RoleId = @1)", userId, roleId);
        return result[0] > 0;
    }

    /// <summary>
    /// Adds every user in <paramref name="usernames"/> to every role in <paramref name="roleNames"/>.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    /// <exception cref="ArgumentNullException">Either array is null (initialized provider only).</exception>
    /// <exception cref="InvalidOperationException">
    /// A user or role does not exist, or a user is already in a role or is listed twice for it.
    /// All of these checks complete before any row is inserted.
    /// </exception>
    /// <exception cref="ProviderException">An INSERT did not affect exactly one row.</exception>
    public override void AddUsersToRoles(string[] usernames, string[] roleNames)
    {
        if (!InitializeCalled)
        {
            PreviousProvider.AddUsersToRoles(usernames, roleNames);
            return;
        }
        if (usernames == null)
        {
            throw new ArgumentNullException("usernames");
        }
        if (roleNames == null)
        {
            throw new ArgumentNullException("roleNames");
        }

        using (var db = ConnectToDatabase())
        {
            List<int> userIds = GetUserIdsFromNames(db, usernames);
            List<int> roleIds = GetRoleIdsFromNames(db, roleNames);

            // Validate every (user, role) pair before writing anything, so a conflict on a later pair
            // cannot leave earlier pairs already inserted.
            var seenPairs = new HashSet<Tuple<int, int>>();
            for (int u = 0; u < userIds.Count; u++)
            {
                for (int r = 0; r < roleIds.Count; r++)
                {
                    var pair = Tuple.Create(userIds[u], roleIds[r]);
                    // A pair repeated in the input would violate the (UserId, RoleId) primary key on its
                    // second INSERT, so it is reported the same way as an existing membership.
                    if (!seenPairs.Add(pair) || IsUserIdInRoleId(db, pair.Item1, pair.Item2))
                    {
                        throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_UserAlreadyInRole, usernames[u], roleNames[r]));
                    }
                }
            }

            // REVIEW: is there a way to batch up these inserts? They are not wrapped in a transaction.
            foreach (int userId in userIds)
            {
                foreach (int roleId in roleIds)
                {
                    int rows = db.Execute("INSERT INTO " + UsersInRoleTableName + " (UserId, RoleId) VALUES (@0, @1)", userId, roleId);
                    if (rows != 1)
                    {
                        throw new ProviderException(WebDataResources.Security_DbFailure);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Creates a role. Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    /// <exception cref="InvalidOperationException">A role with this name already exists.</exception>
    /// <exception cref="ProviderException">The INSERT did not affect exactly one row.</exception>
    public override void CreateRole(string roleName)
    {
        if (!InitializeCalled)
        {
            PreviousProvider.CreateRole(roleName);
            return;
        }

        using (var db = ConnectToDatabase())
        {
            if (FindRoleId(db, roleName) != -1)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_RoleExists, roleName));
            }
            int rows = db.Execute("INSERT INTO " + RoleTableName + " (RoleName) VALUES (@0)", roleName);
            if (rows != 1)
            {
                throw new ProviderException(WebDataResources.Security_DbFailure);
            }
        }
    }

    /// <summary>
    /// Deletes a role. Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    /// <param name="roleName">Role to delete.</param>
    /// <param name="throwOnPopulatedRole">
    /// When true, refuse to delete a role that still has members. When false, the role's memberships
    /// are deleted first.
    /// </param>
    /// <returns>False if the role does not exist; otherwise true when exactly one role row was deleted.</returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="throwOnPopulatedRole"/> is true and the role has members.
    /// </exception>
    public override bool DeleteRole(string roleName, bool throwOnPopulatedRole)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.DeleteRole(roleName, throwOnPopulatedRole);
        }

        using (var db = ConnectToDatabase())
        {
            int roleId = FindRoleId(db, roleName);
            if (roleId == -1)
            {
                return false;
            }

            if (throwOnPopulatedRole)
            {
                // COUNT(*) avoids materializing every membership row just to see whether any exist.
                var result = db.QuerySingle(@"SELECT COUNT(*) FROM " + UsersInRoleTableName + " WHERE (RoleId = @0)", roleId);
                bool populated = result[0] > 0;
                if (populated)
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_RolePopulated, roleName));
                }
            }
            else
            {
                // Delete any users in this role first.
                db.Execute(@"DELETE FROM " + UsersInRoleTableName + " WHERE (RoleId = @0)", roleId);
            }

            int rows = db.Execute(@"DELETE FROM " + RoleTableName + " WHERE (RoleId = @0)", roleId);
            return (rows == 1); // REVIEW: should this ever be > 1?
        }
    }

    /// <summary>
    /// Returns the names of users in <paramref name="roleName"/> whose user name matches
    /// <paramref name="usernameToMatch"/> using SQL <c>LIKE</c>. The pattern is passed through as-is,
    /// so callers supply any wildcards.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    public override string[] FindUsersInRole(string roleName, string usernameToMatch)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.FindUsersInRole(roleName, usernameToMatch);
        }

        using (var db = ConnectToDatabase())
        {
            string query = @"SELECT u." + SafeUserNameColumn + " FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (r.RoleName = @0 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + " and u." + SafeUserNameColumn + " LIKE @1)";
            return db.Query(query, new object[] { roleName, usernameToMatch }).Select<dynamic, string>(d => (string)d[0]).ToArray();
        }
    }

    /// <summary>
    /// Returns the names of all roles.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    public override string[] GetAllRoles()
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.GetAllRoles();
        }

        using (var db = ConnectToDatabase())
        {
            return db.Query(@"SELECT RoleName FROM " + RoleTableName).Select<dynamic, string>(d => (string)d[0]).ToArray();
        }
    }

    /// <summary>
    /// Returns the distinct role names assigned to <paramref name="username"/>.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    /// <exception cref="InvalidOperationException">No user named <paramref name="username"/> exists.</exception>
    public override string[] GetRolesForUser(string username)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.GetRolesForUser(username);
        }

        using (var db = ConnectToDatabase())
        {
            int userId = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username);
            if (userId == -1)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username));
            }
            string query = @"SELECT r.RoleName FROM " + UsersInRoleTableName + " u, " + RoleTableName + " r Where (u.UserId = @0 and u.RoleId = r.RoleId) GROUP BY RoleName";
            return db.Query(query, new object[] { userId }).Select<dynamic, string>(d => (string)d[0]).ToArray();
        }
    }

    /// <summary>
    /// Returns the names of users in <paramref name="roleName"/>. The result is empty when the role
    /// does not exist. Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    public override string[] GetUsersInRole(string roleName)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.GetUsersInRole(roleName);
        }

        using (var db = ConnectToDatabase())
        {
            string query = @"SELECT u." + SafeUserNameColumn + " FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (r.RoleName = @0 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + ")";
            return db.Query(query, new object[] { roleName }).Select<dynamic, string>(d => (string)d[0]).ToArray();
        }
    }

    /// <summary>
    /// Returns true when <paramref name="username"/> is a member of <paramref name="roleName"/>.
    /// An unknown user or role yields false.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    public override bool IsUserInRole(string username, string roleName)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.IsUserInRole(username, roleName);
        }

        using (var db = ConnectToDatabase())
        {
            var count = db.QuerySingle("SELECT COUNT(*) FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (u." + SafeUserNameColumn + " = @0 and r.RoleName = @1 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + ")", username, roleName);
            // "At least one" rather than "exactly one", so the check still holds if several user rows
            // share the same name.
            return (count[0] > 0);
        }
    }

    /// <summary>
    /// Removes every user in <paramref name="usernames"/> from every role in <paramref name="roleNames"/>.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    /// <exception cref="ArgumentNullException">Either array is null (initialized provider only).</exception>
    /// <exception cref="InvalidOperationException">
    /// A role does not exist, or a user is not in one of the roles. Both checks complete before any
    /// row is deleted.
    /// </exception>
    /// <exception cref="ProviderException">A DELETE did not affect exactly one row.</exception>
    public override void RemoveUsersFromRoles(string[] usernames, string[] roleNames)
    {
        if (!InitializeCalled)
        {
            PreviousProvider.RemoveUsersFromRoles(usernames, roleNames);
            return;
        }
        if (usernames == null)
        {
            throw new ArgumentNullException("usernames");
        }
        if (roleNames == null)
        {
            throw new ArgumentNullException("roleNames");
        }

        foreach (string rolename in roleNames)
        {
            if (!RoleExists(rolename))
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_NoRoleFound, rolename));
            }
        }
        foreach (string username in usernames)
        {
            foreach (string rolename in roleNames)
            {
                if (!IsUserInRole(username, rolename))
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_UserNotInRole, username, rolename));
                }
            }
        }

        using (var db = ConnectToDatabase())
        {
            // Distinct: a user or role listed twice would otherwise make the repeated DELETE affect
            // zero rows. That would be reported as a database failure after earlier deletes had run.
            List<int> userIds = GetUserIdsFromNames(db, usernames).Distinct().ToList();
            List<int> roleIds = GetRoleIdsFromNames(db, roleNames).Distinct().ToList();
            foreach (int userId in userIds)
            {
                foreach (int roleId in roleIds)
                {
                    // Review: Is there a way to do these all in one query?
                    int rows = db.Execute("DELETE FROM " + UsersInRoleTableName + " WHERE (UserId = @0 and RoleId = @1)", userId, roleId);
                    if (rows != 1)
                    {
                        throw new ProviderException(WebDataResources.Security_DbFailure);
                    }
                }
            }
        }
    }

    // Returns the id of the role named roleName, or -1 when no such role exists.
    private static int FindRoleId(IDatabase db, string roleName)
    {
        var result = db.QuerySingle(@"SELECT RoleId FROM " + RoleTableName + " WHERE (RoleName = @0)", roleName);
        if (result == null)
        {
            return -1;
        }
        return (int)result[0];
    }

    /// <summary>
    /// Returns true when a role named <paramref name="roleName"/> exists.
    /// Forwarded to the previous provider if this provider has not been initialized.
    /// </summary>
    public override bool RoleExists(string roleName)
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.RoleExists(roleName);
        }

        using (var db = ConnectToDatabase())
        {
            return (FindRoleId(db, roleName) != -1);
        }
    }
}
}