// Copyright (c) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Configuration.Provider;
using System.Globalization;
using System.Linq;
using System.Web.Security;
using WebMatrix.WebData.Resources;

namespace WebMatrix.WebData
{
    /// <summary>
    /// A <see cref="RoleProvider"/> that stores roles in the <c>webpages_Roles</c> table and role
    /// memberships in the <c>webpages_UsersInRoles</c> table, resolving users through a configurable
    /// user table (<see cref="UserTableName"/>, <see cref="UserNameColumn"/>, <see cref="UserIdColumn"/>).
    /// While <see cref="InitializeCalled"/> is false, every <see cref="RoleProvider"/> member is
    /// forwarded to the previous provider passed to the constructor.
    /// </summary>
    public class SimpleRoleProvider : RoleProvider
    {
        private readonly RoleProvider _previousProvider;

        /// <summary>Creates a provider with no previous provider to forward to.</summary>
        public SimpleRoleProvider()
            : this(null)
        {
        }

        /// <summary>Creates a provider that forwards to <paramref name="previousProvider"/> until initialized.</summary>
        /// <param name="previousProvider">Provider used while this one has not been initialized; may be null.</param>
        public SimpleRoleProvider(RoleProvider previousProvider)
        {
            _previousProvider = previousProvider;
        }

        /// <summary>
        /// The provider that uninitialized calls are forwarded to.
        /// </summary>
        /// <exception cref="InvalidOperationException">No previous provider was supplied.</exception>
        private RoleProvider PreviousProvider
        {
            get
            {
                if (_previousProvider == null)
                {
                    throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
                }
                return _previousProvider;
            }
        }

        // The Safe* names wrap the configured identifiers in [brackets] for use in SQL text.
        // NOTE(review): an embedded ']' is not escaped; this assumes the names come from trusted configuration.
        private string SafeUserTableName
        {
            get { return "[" + UserTableName + "]"; }
        }

        private string SafeUserNameColumn
        {
            get { return "[" + UserNameColumn + "]"; }
        }

        private string SafeUserIdColumn
        {
            get { return "[" + UserIdColumn + "]"; }
        }

        internal static string RoleTableName
        {
            get { return "webpages_Roles"; }
        }

        internal static string UsersInRoleTableName
        {
            get { return "webpages_UsersInRoles"; }
        }

        // represents the User table for the app
        public string UserTableName { get; set; }

        // represents the User created UserName column, i.e. Email
        public string UserNameColumn { get; set; }

        // Represents the User created id column, i.e. ID;
        // REVIEW: we could get this from the primary key of UserTable in the future
        public string UserIdColumn { get; set; }

        internal DatabaseConnectionInfo ConnectionInfo { get; set; }

        // Presumably set by the initialization code elsewhere in the assembly; when false, calls are forwarded.
        internal bool InitializeCalled { get; set; }

        /// <summary>
        /// Forwarded to the previous provider before initialization; not supported afterwards.
        /// </summary>
        /// <exception cref="NotSupportedException">The provider has been initialized.</exception>
        public override string ApplicationName
        {
            get
            {
                if (InitializeCalled)
                {
                    throw new NotSupportedException();
                }
                return PreviousProvider.ApplicationName;
            }
            set
            {
                if (InitializeCalled)
                {
                    throw new NotSupportedException();
                }
                PreviousProvider.ApplicationName = value;
            }
        }

        /// <summary>Throws if the provider has not been initialized.</summary>
        private void VerifyInitialized()
        {
            if (!InitializeCalled)
            {
                throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
            }
        }

        /// <summary>Opens a new database wrapper over <see cref="ConnectionInfo"/>; callers dispose it.</summary>
        private IDatabase ConnectToDatabase()
        {
            return new DatabaseWrapper(ConnectionInfo.Connect());
        }

        /// <summary>
        /// Creates the roles table and the users-in-roles table when the roles table does not exist.
        /// Only the roles table is probed; the membership table is assumed to exist whenever it does.
        /// </summary>
        internal void CreateTablesIfNeeded()
        {
            using (var db = ConnectToDatabase())
            {
                if (!SimpleMembershipProvider.CheckTableExists(db, RoleTableName))
                {
                    db.Execute(@"CREATE TABLE " + RoleTableName + @" (
                        RoleId           int                 NOT NULL PRIMARY KEY IDENTITY,
                        RoleName         nvarchar(256)       NOT NULL UNIQUE)");

                    db.Execute(@"CREATE TABLE " + UsersInRoleTableName + @" (
                        UserId                     int                 NOT NULL,
                        RoleId                     int                 NOT NULL,
                        PRIMARY KEY (UserId, RoleId),
                        CONSTRAINT fk_UserId FOREIGN KEY (UserId) REFERENCES " + SafeUserTableName + "(" + SafeUserIdColumn + @"),
                        CONSTRAINT fk_RoleId FOREIGN KEY (RoleId) REFERENCES " + RoleTableName + "(RoleId) )");
                }
            }
        }

        /// <summary>Resolves each user name to its id, preserving input order.</summary>
        /// <exception cref="InvalidOperationException">A user name has no matching user row.</exception>
        private List<int> GetUserIdsFromNames(IDatabase db, string[] usernames)
        {
            List<int> userIds = new List<int>(usernames.Length);
            foreach (string username in usernames)
            {
                int id = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username);
                if (id == -1)
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username));
                }
                userIds.Add(id);
            }
            return userIds;
        }

        /// <summary>Resolves each role name to its id, preserving input order.</summary>
        /// <exception cref="InvalidOperationException">A role name does not exist.</exception>
        private static List<int> GetRoleIdsFromNames(IDatabase db, string[] roleNames)
        {
            List<int> roleIds = new List<int>(roleNames.Length);
            foreach (string role in roleNames)
            {
                int id = FindRoleId(db, role);
                if (id == -1)
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_NoRoleFound, role));
                }
                roleIds.Add(id);
            }
            return roleIds;
        }

        /// <summary>
        /// Adds every user in <paramref name="usernames"/> to every role in <paramref name="roleNames"/>.
        /// Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either array is null (initialized path).</exception>
        /// <exception cref="InvalidOperationException">A user or role is unknown, or a user is already in a role.</exception>
        /// <exception cref="ProviderException">An insert did not affect exactly one row.</exception>
        public override void AddUsersToRoles(string[] usernames, string[] roleNames)
        {
            if (!InitializeCalled)
            {
                PreviousProvider.AddUsersToRoles(usernames, roleNames);
                return;
            }
            if (usernames == null)
            {
                throw new ArgumentNullException("usernames");
            }
            if (roleNames == null)
            {
                throw new ArgumentNullException("roleNames");
            }

            using (var db = ConnectToDatabase())
            {
                List<int> userIds = GetUserIdsFromNames(db, usernames);
                List<int> roleIds = GetRoleIdsFromNames(db, roleNames);

                // One INSERT per (user, role) pair. Pairs inserted before a failure are not rolled back.
                for (int uId = 0; uId < usernames.Length; uId++)
                {
                    for (int rId = 0; rId < roleNames.Length; rId++)
                    {
                        if (IsUserInRole(usernames[uId], roleNames[rId]))
                        {
                            throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_UserAlreadyInRole, usernames[uId], roleNames[rId]));
                        }

                        // REVIEW: is there a way to batch up these inserts?
                        int rows = db.Execute("INSERT INTO " + UsersInRoleTableName + " (UserId, RoleId) VALUES (@0, @1)", userIds[uId], roleIds[rId]);
                        if (rows != 1)
                        {
                            throw new ProviderException(WebDataResources.Security_DbFailure);
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Creates a role. Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        /// <exception cref="InvalidOperationException">The role already exists.</exception>
        /// <exception cref="ProviderException">The insert did not affect exactly one row.</exception>
        public override void CreateRole(string roleName)
        {
            if (!InitializeCalled)
            {
                PreviousProvider.CreateRole(roleName);
                return;
            }

            using (var db = ConnectToDatabase())
            {
                int roleId = FindRoleId(db, roleName);
                if (roleId != -1)
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_RoleExists, roleName));
                }

                int rows = db.Execute("INSERT INTO " + RoleTableName + " (RoleName) VALUES (@0)", roleName);
                if (rows != 1)
                {
                    throw new ProviderException(WebDataResources.Security_DbFailure);
                }
            }
        }

        /// <summary>
        /// Deletes a role. Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        /// <param name="roleName">Role to delete.</param>
        /// <param name="throwOnPopulatedRole">
        /// When true, refuse to delete a role that still has members; when false, remove its memberships first.
        /// </param>
        /// <returns>False if the role does not exist; otherwise whether exactly one role row was deleted.</returns>
        /// <exception cref="InvalidOperationException"><paramref name="throwOnPopulatedRole"/> is true and the role has members.</exception>
        public override bool DeleteRole(string roleName, bool throwOnPopulatedRole)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.DeleteRole(roleName, throwOnPopulatedRole);
            }

            using (var db = ConnectToDatabase())
            {
                int roleId = FindRoleId(db, roleName);
                if (roleId == -1)
                {
                    return false;
                }

                if (throwOnPopulatedRole)
                {
                    // Any() stops at the first row instead of counting every membership.
                    bool roleHasUsers = db.Query(@"SELECT * FROM " + UsersInRoleTableName + " WHERE (RoleId = @0)", roleId).Any();
                    if (roleHasUsers)
                    {
                        throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_RolePopulated, roleName));
                    }
                }
                else
                {
                    // Delete any users in this role first
                    db.Execute(@"DELETE FROM " + UsersInRoleTableName + " WHERE (RoleId = @0)", roleId);
                }

                int rows = db.Execute(@"DELETE FROM " + RoleTableName + " WHERE (RoleId = @0)", roleId);
                return (rows == 1); // REVIEW: should this ever be > 1?
            }
        }

        /// <summary>
        /// Returns the names of users in <paramref name="roleName"/> whose user name matches the SQL LIKE
        /// pattern <paramref name="usernameToMatch"/>. Forwarded to the previous provider if this provider
        /// hasn't been initialized.
        /// </summary>
        public override string[] FindUsersInRole(string roleName, string usernameToMatch)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.FindUsersInRole(roleName, usernameToMatch);
            }

            using (var db = ConnectToDatabase())
            {
                string query = @"SELECT u." + SafeUserNameColumn + " FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (r.RoleName = @0 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + " and u." + SafeUserNameColumn + " LIKE @1)";
                return db.Query(query, new object[] { roleName, usernameToMatch }).Select(d => (string)d[0]).ToArray();
            }
        }

        /// <summary>
        /// Returns all role names. Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        public override string[] GetAllRoles()
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.GetAllRoles();
            }

            using (var db = ConnectToDatabase())
            {
                return db.Query(@"SELECT RoleName FROM " + RoleTableName).Select(d => (string)d[0]).ToArray();
            }
        }

        /// <summary>
        /// Returns the distinct role names for a user. Forwarded to the previous provider if this provider
        /// hasn't been initialized.
        /// </summary>
        /// <exception cref="InvalidOperationException">The user does not exist.</exception>
        public override string[] GetRolesForUser(string username)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.GetRolesForUser(username);
            }

            using (var db = ConnectToDatabase())
            {
                int userId = SimpleMembershipProvider.GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, username);
                if (userId == -1)
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.Security_NoUserFound, username));
                }

                string query = @"SELECT r.RoleName FROM " + UsersInRoleTableName + " u, " + RoleTableName + " r Where (u.UserId = @0 and u.RoleId = r.RoleId) GROUP BY RoleName";
                return db.Query(query, new object[] { userId }).Select(d => (string)d[0]).ToArray();
            }
        }

        /// <summary>
        /// Returns the names of all users in a role. Forwarded to the previous provider if this provider
        /// hasn't been initialized.
        /// </summary>
        public override string[] GetUsersInRole(string roleName)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.GetUsersInRole(roleName);
            }

            using (var db = ConnectToDatabase())
            {
                string query = @"SELECT u." + SafeUserNameColumn + " FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (r.RoleName = @0 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + ")";
                return db.Query(query, new object[] { roleName }).Select(d => (string)d[0]).ToArray();
            }
        }

        /// <summary>
        /// Returns true if the named user is a member of the named role. Forwarded to the previous provider
        /// if this provider hasn't been initialized.
        /// </summary>
        public override bool IsUserInRole(string username, string roleName)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.IsUserInRole(username, roleName);
            }

            using (var db = ConnectToDatabase())
            {
                var count = db.QuerySingle("SELECT COUNT(*) FROM " + SafeUserTableName + " u, " + UsersInRoleTableName + " ur, " + RoleTableName + " r Where (u." + SafeUserNameColumn + " = @0 and r.RoleName = @1 and ur.RoleId = r.RoleId and ur.UserId = u." + SafeUserIdColumn + ")", username, roleName);
                // "> 0" rather than "== 1": a user name matching several user rows still reports membership.
                return (count[0] > 0);
            }
        }

        /// <summary>
        /// Removes every user in <paramref name="usernames"/> from every role in <paramref name="roleNames"/>.
        /// All roles must exist and every user must already be in every role before anything is deleted.
        /// Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either array is null (initialized path).</exception>
        /// <exception cref="InvalidOperationException">A role is unknown, or a user is not in a role.</exception>
        /// <exception cref="ProviderException">A delete did not affect exactly one row.</exception>
        public override void RemoveUsersFromRoles(string[] usernames, string[] roleNames)
        {
            if (!InitializeCalled)
            {
                PreviousProvider.RemoveUsersFromRoles(usernames, roleNames);
                return;
            }
            if (usernames == null)
            {
                throw new ArgumentNullException("usernames");
            }
            if (roleNames == null)
            {
                throw new ArgumentNullException("roleNames");
            }

            foreach (string rolename in roleNames)
            {
                if (!RoleExists(rolename))
                {
                    throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvider_NoRoleFound, rolename));
                }
            }

            foreach (string username in usernames)
            {
                foreach (string rolename in roleNames)
                {
                    if (!IsUserInRole(username, rolename))
                    {
                        throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleRoleProvder_UserNotInRole, username, rolename));
                    }
                }
            }

            using (var db = ConnectToDatabase())
            {
                List<int> userIds = GetUserIdsFromNames(db, usernames);
                List<int> roleIds = GetRoleIdsFromNames(db, roleNames);

                foreach (int userId in userIds)
                {
                    foreach (int roleId in roleIds)
                    {
                        // Review: Is there a way to do these all in one query?
                        int rows = db.Execute("DELETE FROM " + UsersInRoleTableName + " WHERE (UserId = @0 and RoleId = @1)", userId, roleId);
                        if (rows != 1)
                        {
                            throw new ProviderException(WebDataResources.Security_DbFailure);
                        }
                    }
                }
            }
        }

        /// <summary>Returns the id of the named role, or -1 if it does not exist.</summary>
        private static int FindRoleId(IDatabase db, string roleName)
        {
            var result = db.QuerySingle(@"SELECT RoleId FROM " + RoleTableName + " WHERE (RoleName = @0)", roleName);
            if (result == null)
            {
                return -1;
            }
            return (int)result[0];
        }

        /// <summary>
        /// Returns true if the role exists. Forwarded to the previous provider if this provider hasn't been initialized.
        /// </summary>
        public override bool RoleExists(string roleName)
        {
            if (!InitializeCalled)
            {
                return PreviousProvider.RoleExists(roleName);
            }

            using (var db = ConnectToDatabase())
            {
                return (FindRoleId(db, roleName) != -1);
            }
        }
    }
}