//------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// balnee
// krishnib
//------------------------------------------------------------------------------

namespace System.Data.SqlClient {
    using System;
    using System.Diagnostics;
    using System.Reflection;
    using System.Security;
    using System.Security.Cryptography;
    using System.Text;

    /// <summary>
    /// Static helpers used by the column encryption code paths: hashing, random byte
    /// generation, byte comparison, hex formatting and encrypt/decrypt wrappers that
    /// operate on a <c>SqlCipherMetadata</c> instance.
    /// </summary>
    internal static class SqlSecurityUtility {

        /// <summary>
        /// Computes the HMAC-SHA256 of <paramref name="plainText"/> keyed with <paramref name="key"/>
        /// and copies the leading <c>hash.Length</c> bytes of the result into <paramref name="hash"/>.
        /// </summary>
        /// <param name="plainText">Plain text bytes whose hash has to be computed.</param>
        /// <param name="key">Key used for the HMAC.</param>
        /// <param name="hash">
        /// Output buffer where the computed hash value is stored. Must be between 1 and 32 bytes long;
        /// if it is shorter than 32 bytes, the hash is truncated to fit.
        /// </param>
        internal static void GetHMACWithSHA256(byte[] plainText, byte[] key, byte[] hash) {
            const int MaxSHA256HashBytes = 32;

            Debug.Assert(key != null && plainText != null);
            Debug.Assert(hash.Length != 0 && hash.Length <= MaxSHA256HashBytes);

            using (HMACSHA256 hmac = new HMACSHA256(key)) {
                byte[] computedHash = hmac.ComputeHash(plainText);

                // Truncate the hash if needed
                Buffer.BlockCopy(computedHash, 0, hash, 0, hash.Length);
            }
        }

        /// <summary>
        /// Computes the SHA256 hash of a given input.
        /// </summary>
        /// <param name="input">Input byte array which needs to be hashed.</param>
        /// <returns>The SHA256 hash formatted by <see cref="GetHexString"/>.</returns>
        internal static string GetSHA256Hash(byte[] input) {
            Debug.Assert(input != null);

            using (SHA256 sha256 = SHA256Cng.Create()) {
                byte[] hashValue = sha256.ComputeHash(input);
                return GetHexString(hashValue);
            }
        }

        /// <summary>
        /// Fills the supplied buffer with cryptographically random bytes.
        /// </summary>
        /// <param name="randomBytes">Buffer to fill; its length determines how many bytes are generated.</param>
        internal static void GenerateRandomBytes(byte[] randomBytes) {
            // Generate random bytes cryptographically. The provider is IDisposable,
            // so release it deterministically instead of leaving it to the finalizer.
            using (RNGCryptoServiceProvider rngCsp = new RNGCryptoServiceProvider()) {
                rngCsp.GetBytes(randomBytes);
            }
        }

        /// <summary>
        /// Compares <paramref name="buffer1"/> (from index 0) against <paramref name="buffer2"/>
        /// (from <paramref name="buffer2Index"/>) for up to <paramref name="lengthToCompare"/> bytes.
        /// </summary>
        /// <param name="buffer1">Input buffer.</param>
        /// <param name="buffer2">Another buffer to be compared against.</param>
        /// <param name="buffer2Index">Offset in <paramref name="buffer2"/> at which the comparison starts.</param>
        /// <param name="lengthToCompare">Number of bytes to compare.</param>
        /// <returns>
        /// false if either buffer is null, if <paramref name="buffer2"/> has fewer than
        /// <paramref name="lengthToCompare"/> bytes after <paramref name="buffer2Index"/>, or if any
        /// compared byte differs; otherwise true.
        /// </returns>
        internal static bool CompareBytes(byte[] buffer1, byte[] buffer2, int buffer2Index, int lengthToCompare) {
            if (null == buffer1 || null == buffer2) {
                return false;
            }

            Debug.Assert(buffer2Index > -1 && buffer2Index < buffer2.Length, "invalid index");// bounds on buffer2Index
            if ((buffer2.Length - buffer2Index) < lengthToCompare) {
                return false;
            }

            // NOTE(review): if buffer1 is shorter than lengthToCompare only buffer1.Length bytes are
            // compared and the method can still return true. Behavior kept as-is for existing callers.
            for (int index = 0; index < buffer1.Length && index < lengthToCompare; ++index) {
                if (buffer1[index] != buffer2[buffer2Index + index]) {
                    return false;
                }
            }

            return true;
        }

        /// <summary>
        /// Gets the hex representation of a byte array (two upper-case hex digits per byte, no separators).
        /// </summary>
        /// <param name="input">Input byte array.</param>
        /// <returns>The hex string; empty when <paramref name="input"/> is empty.</returns>
        internal static string GetHexString(byte[] input) {
            Debug.Assert(input != null);

            // Every byte yields exactly two characters, so size the builder up front.
            StringBuilder str = new StringBuilder(input.Length * 2);
            foreach (byte b in input) {
                // Append, not AppendFormat: the hex text is data, not a composite format string.
                str.Append(b.ToString(@"X2"));
            }

            return str.ToString();
        }

        /// <summary>
        /// Returns the caller's function name in the format of [ClassName].[FunctionName].
        /// </summary>
        /// <returns>The declaring type name and method name of the frame one level above this method.</returns>
        internal static string GetCurrentFunctionName() {
            StackTrace stackTrace = new StackTrace();
            StackFrame stackFrame = stackTrace.GetFrame(1); // frame 0 is this method; frame 1 is the caller
            MethodBase methodBase = stackFrame.GetMethod();
            return string.Format(@"{0}.{1}", methodBase.DeclaringType.Name, methodBase.Name);
        }

        /// <summary>
        /// Returns the algorithm name mapped to an Id.
        /// </summary>
        /// <param name="cipherAlgorithmId">Algorithm id to map.</param>
        /// <param name="cipherAlgorithmName">
        /// Algorithm name; only consulted (and required to be non-null) when
        /// <paramref name="cipherAlgorithmId"/> is the custom algorithm id.
        /// </param>
        /// <returns>The algorithm name for the given id.</returns>
        private static string ValidateAndGetEncryptionAlgorithmName(byte cipherAlgorithmId, string cipherAlgorithmName) {
            if (TdsEnums.CustomCipherAlgorithmId == cipherAlgorithmId) {
                if (null == cipherAlgorithmName) {
                    throw SQL.NullColumnEncryptionAlgorithm(SqlClientEncryptionAlgorithmFactoryList.GetInstance().GetRegisteredCipherAlgorithmNames());
                }

                return cipherAlgorithmName;
            }
            else if (TdsEnums.AEAD_AES_256_CBC_HMAC_SHA256 == cipherAlgorithmId) {
                return SqlAeadAes256CbcHmac256Algorithm.AlgorithmName;
            }
            else if (TdsEnums.AES_256_CBC == cipherAlgorithmId) {
                return SqlAes256CbcAlgorithm.AlgorithmName;
            }
            else {
                throw SQL.UnknownColumnEncryptionAlgorithmId(cipherAlgorithmId, GetRegisteredCipherAlgorithmIds());
            }
        }

        /// <summary>
        /// Retrieves a string with comma separated list of registered algorithm Ids (enclosed in quotes).
        /// </summary>
        private static string GetRegisteredCipherAlgorithmIds() {
            return @"'1', '2'";
        }

        /// <summary>
        /// Encrypts the plaintext, initializing the cipher algorithm in <paramref name="md"/> first if needed.
        /// </summary>
        /// <param name="plainText">Bytes to encrypt.</param>
        /// <param name="md">Cipher metadata holding (or receiving) the initialized algorithm.</param>
        /// <param name="serverName">Server name passed through to <see cref="DecryptSymmetricKey"/>; must not be null.</param>
        /// <returns>The non-empty cipher text.</returns>
        internal static byte[] EncryptWithKey(byte[] plainText, SqlCipherMetadata md, string serverName) {
            Debug.Assert(serverName != null, @"serverName should not be null in EncryptWithKey.");

            // Initialize cipherAlgo if not already done.
            if (!md.IsAlgorithmInitialized()) {
                SqlSecurityUtility.DecryptSymmetricKey(md, serverName);
            }

            Debug.Assert(md.IsAlgorithmInitialized(), "Encryption Algorithm is not initialized");
            byte[] cipherText = md.CipherAlgorithm.EncryptData(plainText); // this call succeeds or throws.
            if (null == cipherText || 0 == cipherText.Length) {
                // Throw the exception instead of discarding it, mirroring the
                // 'throw SQL.NullPlainText()' pattern used by DecryptWithKey.
                throw SQL.NullCipherText();
            }

            return cipherText;
        }

        /// <summary>
        /// Gets a string with first/last <paramref name="countOfBytes"/> bytes in the buff (useful for exception handling).
        /// </summary>
        /// <param name="buff">Source buffer.</param>
        /// <param name="fLast">true to take the trailing bytes, false to take the leading bytes.</param>
        /// <param name="countOfBytes">Maximum number of bytes to include; capped at the buffer length.</param>
        /// <returns>The selected bytes formatted by <see cref="BitConverter.ToString(byte[], int, int)"/>.</returns>
        internal static string GetBytesAsString(byte[] buff, bool fLast, int countOfBytes) {
            int count = (buff.Length > countOfBytes) ? countOfBytes : buff.Length;
            int startIndex = 0;
            if (fLast) {
                startIndex = buff.Length - count;
                Debug.Assert(startIndex >= 0);
            }

            return BitConverter.ToString(buff, startIndex, count);
        }

        /// <summary>
        /// Decrypts the ciphertext, initializing the cipher algorithm in <paramref name="md"/> first if needed.
        /// </summary>
        /// <param name="cipherText">Bytes to decrypt.</param>
        /// <param name="md">Cipher metadata holding (or receiving) the initialized algorithm.</param>
        /// <param name="serverName">Server name passed through to <see cref="DecryptSymmetricKey"/>; must not be null.</param>
        /// <returns>The decrypted bytes.</returns>
        internal static byte[] DecryptWithKey(byte[] cipherText, SqlCipherMetadata md, string serverName) {
            Debug.Assert(serverName != null, @"serverName should not be null in DecryptWithKey.");

            // Initialize cipherAlgo if not already done.
            if (!md.IsAlgorithmInitialized()) {
                SqlSecurityUtility.DecryptSymmetricKey(md, serverName);
            }

            Debug.Assert(md.IsAlgorithmInitialized(), "Decryption Algorithm is not initialized");
            try {
                byte[] plainText = md.CipherAlgorithm.DecryptData(cipherText); // this call succeeds or throws.
                if (null == plainText) {
                    throw SQL.NullPlainText();
                }

                return plainText;
            }
            catch (Exception e) {
                // compute the strings to pass: last 10 bytes of the encrypted key, first 10 bytes of the cipher text
                string keyStr = GetBytesAsString(md.EncryptionKeyInfo.Value.encryptedKey, fLast: true, countOfBytes: 10);
                string valStr = GetBytesAsString(cipherText, fLast: false, countOfBytes: 10);
                throw SQL.ThrowDecryptionFailed(keyStr, valStr, e);
            }
        }

        /// <summary>
        /// Decrypts the symmetric key and saves it in metadata. In addition, initializes
        /// the SqlClientEncryptionAlgorithm for rapid decryption.
        /// </summary>
        /// <param name="md">Cipher metadata whose <c>CipherAlgorithm</c> and <c>EncryptionKeyInfo</c> are set on success.</param>
        /// <param name="serverName">Server name passed to the symmetric key cache lookup; must not be null.</param>
        internal static void DecryptSymmetricKey(SqlCipherMetadata md, string serverName) {
            Debug.Assert(serverName != null, @"serverName should not be null in DecryptSymmetricKey.");
            Debug.Assert(md != null, "md should not be null in DecryptSymmetricKey.");
            Debug.Assert(md.EncryptionInfo.HasValue, "md.EncryptionInfo should not be null in DecryptSymmetricKey.");
            Debug.Assert(md.EncryptionInfo.Value.ColumnEncryptionKeyValues != null, "md.EncryptionInfo.ColumnEncryptionKeyValues should not be null in DecryptSymmetricKey.");

            SqlClientSymmetricKey symKey = null;
            SqlEncryptionKeyInfo? encryptionkeyInfoChosen = null;
            SqlSymmetricKeyCache cache = SqlSymmetricKeyCache.GetInstance();
            Exception lastException = null;

            // Try each candidate key in order; the first one the cache can resolve wins.
            // Failures are remembered so the last one can be surfaced if no key works.
            foreach (SqlEncryptionKeyInfo keyInfo in md.EncryptionInfo.Value.ColumnEncryptionKeyValues) {
                try {
                    if (cache.GetKey(keyInfo, serverName, out symKey)) {
                        encryptionkeyInfoChosen = keyInfo;
                        break;
                    }
                }
                catch (Exception e) {
                    lastException = e;
                }
            }

            if (null == symKey) {
                // NOTE(review): in builds without Debug.Assert, a null lastException here would
                // surface as a NullReferenceException from the throw statement.
                Debug.Assert(null != lastException, "CEK decryption failed without raising exceptions");
                throw lastException;
            }

            Debug.Assert(encryptionkeyInfoChosen.HasValue, "encryptionkeyInfoChosen must have a value.");

            // Given the symmetric key instantiate a SqlClientEncryptionAlgorithm object and cache it in metadata
            md.CipherAlgorithm = null;
            SqlClientEncryptionAlgorithm cipherAlgorithm = null;
            string algorithmName = ValidateAndGetEncryptionAlgorithmName(md.CipherAlgorithmId, md.CipherAlgorithmName); // may throw
            SqlClientEncryptionAlgorithmFactoryList.GetInstance().GetAlgorithm(symKey, md.EncryptionType, algorithmName, out cipherAlgorithm); // will validate algorithm name and type
            Debug.Assert(cipherAlgorithm != null);
            md.CipherAlgorithm = cipherAlgorithm;
            md.EncryptionKeyInfo = encryptionkeyInfoChosen;
            return;
        }

        /// <summary>
        /// Calculates the length of the Base64 string used to represent a byte[] with the specified length.
        /// </summary>
        /// <param name="byteLength">Length of the byte array, expected to be at most <see cref="UInt16.MaxValue"/>.</param>
        /// <returns>A length computed as (byteLength * 4 / 3, truncated) plus 4.</returns>
        internal static int GetBase64LengthFromByteLength(int byteLength) {
            Debug.Assert(byteLength <= UInt16.MaxValue, @"Encrypted column encryption key cannot be larger than 65536 bytes");

            // Base64 encoding uses 1 character to encode 6 bits which means 4 characters for 3 bytes and pads to 4 byte multiples.
            return (int)((double)byteLength * 4 / 3) + 4;
        }
    }
}