Jo Shields a575963da9 Imported Upstream version 3.6.0
Former-commit-id: da6be194a6b1221998fc28233f2503bd61dd9d14
2014-08-13 10:39:27 +01:00

983 lines
23 KiB
C#

// Transport Security Layer (TLS)
// Copyright (c) 2003-2004 Carlos Guzman Alvarez
// Copyright (C) 2006-2007 Novell, Inc (http://www.novell.com)
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
using System;
using System.Collections;
using System.IO;
using System.Threading;
using Mono.Security.Protocol.Tls.Handshake;
namespace Mono.Security.Protocol.Tls
{
internal abstract class RecordProtocol
{
#region Fields
private static ManualResetEvent record_processing = new ManualResetEvent (true);
protected Stream innerStream;
protected Context context;
#endregion
#region Properties
public Context Context
{
get { return this.context; }
set { this.context = value; }
}
#endregion
#region Constructors
public RecordProtocol(Stream innerStream, Context context)
{
this.innerStream = innerStream;
this.context = context;
this.context.RecordProtocol = this;
}
#endregion
#region Abstract Methods
public virtual void SendRecord(HandshakeType type)
{
IAsyncResult ar = this.BeginSendRecord(type, null, null);
this.EndSendRecord(ar);
}
protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
protected virtual void ProcessChangeCipherSpec ()
{
Context ctx = this.Context;
// Reset sequence numbers
ctx.ReadSequenceNumber = 0;
if (ctx is ClientContext) {
ctx.EndSwitchingSecurityParameters (true);
} else {
ctx.StartSwitchingSecurityParameters (false);
}
}
public virtual HandshakeMessage GetMessage(HandshakeType type)
{
throw new NotSupportedException();
}
#endregion
#region Receive Record Async Result
private class ReceiveRecordAsyncResult : IAsyncResult
{
private object locker = new object ();
private AsyncCallback _userCallback;
private object _userState;
private Exception _asyncException;
private ManualResetEvent handle;
private byte[] _resultingBuffer;
private Stream _record;
private bool completed;
private byte[] _initialBuffer;
public ReceiveRecordAsyncResult(AsyncCallback userCallback, object userState, byte[] initialBuffer, Stream record)
{
_userCallback = userCallback;
_userState = userState;
_initialBuffer = initialBuffer;
_record = record;
}
public Stream Record
{
get { return _record; }
}
public byte[] ResultingBuffer
{
get { return _resultingBuffer; }
}
public byte[] InitialBuffer
{
get { return _initialBuffer; }
}
public object AsyncState
{
get { return _userState; }
}
public Exception AsyncException
{
get { return _asyncException; }
}
public bool CompletedWithError
{
get {
if (!IsCompleted)
return false; // Perhaps throw InvalidOperationExcetion?
return null != _asyncException;
}
}
public WaitHandle AsyncWaitHandle
{
get {
lock (locker) {
if (handle == null)
handle = new ManualResetEvent (completed);
}
return handle;
}
}
public bool CompletedSynchronously
{
get { return false; }
}
public bool IsCompleted
{
get {
lock (locker) {
return completed;
}
}
}
private void SetComplete(Exception ex, byte[] resultingBuffer)
{
lock (locker) {
if (completed)
return;
completed = true;
_asyncException = ex;
_resultingBuffer = resultingBuffer;
if (handle != null)
handle.Set ();
if (_userCallback != null)
_userCallback.BeginInvoke (this, null, null);
}
}
public void SetComplete(Exception ex)
{
SetComplete(ex, null);
}
public void SetComplete(byte[] resultingBuffer)
{
SetComplete(null, resultingBuffer);
}
public void SetComplete()
{
SetComplete(null, null);
}
}
#endregion
#region Receive Record Async Result
private class SendRecordAsyncResult : IAsyncResult
{
private object locker = new object ();
private AsyncCallback _userCallback;
private object _userState;
private Exception _asyncException;
private ManualResetEvent handle;
private HandshakeMessage _message;
private bool completed;
public SendRecordAsyncResult(AsyncCallback userCallback, object userState, HandshakeMessage message)
{
_userCallback = userCallback;
_userState = userState;
_message = message;
}
public HandshakeMessage Message
{
get { return _message; }
}
public object AsyncState
{
get { return _userState; }
}
public Exception AsyncException
{
get { return _asyncException; }
}
public bool CompletedWithError
{
get {
if (!IsCompleted)
return false; // Perhaps throw InvalidOperationExcetion?
return null != _asyncException;
}
}
public WaitHandle AsyncWaitHandle
{
get {
lock (locker) {
if (handle == null)
handle = new ManualResetEvent (completed);
}
return handle;
}
}
public bool CompletedSynchronously
{
get { return false; }
}
public bool IsCompleted
{
get {
lock (locker) {
return completed;
}
}
}
public void SetComplete(Exception ex)
{
lock (locker) {
if (completed)
return;
completed = true;
if (handle != null)
handle.Set ();
if (_userCallback != null)
_userCallback.BeginInvoke (this, null, null);
_asyncException = ex;
}
}
public void SetComplete()
{
SetComplete(null);
}
}
#endregion
#region Reveive Record Methods
public IAsyncResult BeginReceiveRecord(Stream record, AsyncCallback callback, object state)
{
if (this.context.ReceivedConnectionEnd)
{
throw new TlsException(
AlertDescription.InternalError,
"The session is finished and it's no longer valid.");
}
record_processing.Reset ();
byte[] recordTypeBuffer = new byte[1];
ReceiveRecordAsyncResult internalResult = new ReceiveRecordAsyncResult(callback, state, recordTypeBuffer, record);
record.BeginRead(internalResult.InitialBuffer, 0, internalResult.InitialBuffer.Length, new AsyncCallback(InternalReceiveRecordCallback), internalResult);
return internalResult;
}
private void InternalReceiveRecordCallback(IAsyncResult asyncResult)
{
ReceiveRecordAsyncResult internalResult = asyncResult.AsyncState as ReceiveRecordAsyncResult;
Stream record = internalResult.Record;
try
{
int bytesRead = internalResult.Record.EndRead(asyncResult);
//We're at the end of the stream. Time to bail.
if (bytesRead == 0)
{
internalResult.SetComplete((byte[])null);
return;
}
// Try to read the Record Content Type
int type = internalResult.InitialBuffer[0];
// Set last handshake message received to None
this.context.LastHandshakeMsg = HandshakeType.ClientHello;
ContentType contentType = (ContentType)type;
byte[] buffer = this.ReadRecordBuffer(type, record);
if (buffer == null)
{
// record incomplete (at the moment)
internalResult.SetComplete((byte[])null);
return;
}
// Decrypt message contents if needed
if (contentType == ContentType.Alert && buffer.Length == 2)
{
}
else if ((this.Context.Read != null) && (this.Context.Read.Cipher != null))
{
buffer = this.decryptRecordFragment (contentType, buffer);
DebugHelper.WriteLine ("Decrypted record data", buffer);
}
// Process record
switch (contentType)
{
case ContentType.Alert:
this.ProcessAlert((AlertLevel)buffer [0], (AlertDescription)buffer [1]);
if (record.CanSeek)
{
// don't reprocess that memory block
record.SetLength (0);
}
buffer = null;
break;
case ContentType.ChangeCipherSpec:
this.ProcessChangeCipherSpec();
break;
case ContentType.ApplicationData:
break;
case ContentType.Handshake:
TlsStream message = new TlsStream (buffer);
while (!message.EOF)
{
this.ProcessHandshakeMessage(message);
}
break;
case (ContentType)0x80:
this.context.HandshakeMessages.Write (buffer);
break;
default:
throw new TlsException(
AlertDescription.UnexpectedMessage,
"Unknown record received from server.");
}
internalResult.SetComplete(buffer);
}
catch (Exception ex)
{
internalResult.SetComplete(ex);
}
}
public byte[] EndReceiveRecord(IAsyncResult asyncResult)
{
ReceiveRecordAsyncResult internalResult = asyncResult as ReceiveRecordAsyncResult;
if (null == internalResult)
throw new ArgumentException("Either the provided async result is null or was not created by this RecordProtocol.");
if (!internalResult.IsCompleted)
internalResult.AsyncWaitHandle.WaitOne();
if (internalResult.CompletedWithError)
throw internalResult.AsyncException;
byte[] result = internalResult.ResultingBuffer;
record_processing.Set ();
return result;
}
public byte[] ReceiveRecord(Stream record)
{
IAsyncResult ar = this.BeginReceiveRecord(record, null, null);
return this.EndReceiveRecord(ar);
}
private byte[] ReadRecordBuffer (int contentType, Stream record)
{
switch (contentType)
{
case 0x80:
return this.ReadClientHelloV2(record);
default:
if (!Enum.IsDefined(typeof(ContentType), (ContentType)contentType))
{
throw new TlsException(AlertDescription.DecodeError);
}
return this.ReadStandardRecordBuffer(record);
}
}
private byte[] ReadClientHelloV2 (Stream record)
{
int msgLength = record.ReadByte ();
// process further only if the whole record is available
if (record.CanSeek && (msgLength + 1 > record.Length))
{
return null;
}
byte[] message = new byte[msgLength];
record.Read (message, 0, msgLength);
int msgType = message [0];
if (msgType != 1)
{
throw new TlsException(AlertDescription.DecodeError);
}
int protocol = (message [1] << 8 | message [2]);
int cipherSpecLength = (message [3] << 8 | message [4]);
int sessionIdLength = (message [5] << 8 | message [6]);
int challengeLength = (message [7] << 8 | message [8]);
int length = (challengeLength > 32) ? 32 : challengeLength;
// Read CipherSpecs
byte[] cipherSpecV2 = new byte[cipherSpecLength];
Buffer.BlockCopy (message, 9, cipherSpecV2, 0, cipherSpecLength);
// Read session ID
byte[] sessionId = new byte[sessionIdLength];
Buffer.BlockCopy (message, 9 + cipherSpecLength, sessionId, 0, sessionIdLength);
// Read challenge ID
byte[] challenge = new byte[challengeLength];
Buffer.BlockCopy (message, 9 + cipherSpecLength + sessionIdLength, challenge, 0, challengeLength);
if (challengeLength < 16 || cipherSpecLength == 0 || (cipherSpecLength % 3) != 0)
{
throw new TlsException(AlertDescription.DecodeError);
}
// Updated the Session ID
if (sessionId.Length > 0)
{
this.context.SessionId = sessionId;
}
// Update the protocol version
this.Context.ChangeProtocol((short)protocol);
// Select the Cipher suite
this.ProcessCipherSpecV2Buffer(this.Context.SecurityProtocol, cipherSpecV2);
// Updated the Client Random
this.context.ClientRandom = new byte [32]; // Always 32
// 1. if challenge is bigger than 32 bytes only use the last 32 bytes
// 2. right justify (0) challenge in ClientRandom if less than 32
Buffer.BlockCopy (challenge, challenge.Length - length, this.context.ClientRandom, 32 - length, length);
// Set
this.context.LastHandshakeMsg = HandshakeType.ClientHello;
this.context.ProtocolNegotiated = true;
return message;
}
private byte[] ReadStandardRecordBuffer (Stream record)
{
byte[] header = new byte[4];
if (record.Read (header, 0, 4) != 4)
throw new TlsException ("buffer underrun");
short protocol = (short)((header [0] << 8) | header [1]);
short length = (short)((header [2] << 8) | header [3]);
// process further only if the whole record is available
// note: the first 5 bytes aren't part of the length
if (record.CanSeek && (length + 5 > record.Length))
{
return null;
}
// Read Record data
int totalReceived = 0;
byte[] buffer = new byte[length];
while (totalReceived != length)
{
int justReceived = record.Read(buffer, totalReceived, buffer.Length - totalReceived);
//Make sure we get some data so we don't end up in an infinite loop here before shutdown.
if (0 == justReceived)
{
throw new TlsException(AlertDescription.CloseNotify, "Received 0 bytes from stream. It must be closed.");
}
totalReceived += justReceived;
}
// Check that the message has a valid protocol version
if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
{
throw new TlsException(
AlertDescription.ProtocolVersion, "Invalid protocol version on message received");
}
DebugHelper.WriteLine("Record data", buffer);
return buffer;
}
private void ProcessAlert(AlertLevel alertLevel, AlertDescription alertDesc)
{
switch (alertLevel)
{
case AlertLevel.Fatal:
throw new TlsException(alertLevel, alertDesc);
case AlertLevel.Warning:
default:
switch (alertDesc)
{
case AlertDescription.CloseNotify:
this.context.ReceivedConnectionEnd = true;
break;
}
break;
}
}
#endregion
#region Send Alert Methods
public void SendAlert(AlertDescription description)
{
this.SendAlert(new Alert(description));
}
public void SendAlert(
AlertLevel level,
AlertDescription description)
{
this.SendAlert(new Alert(level, description));
}
public void SendAlert(Alert alert)
{
AlertLevel level;
AlertDescription description;
bool close;
if (alert == null) {
DebugHelper.WriteLine(">>>> Write Alert NULL");
level = AlertLevel.Fatal;
description = AlertDescription.InternalError;
close = true;
} else {
DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
level = alert.Level;
description = alert.Description;
close = alert.IsCloseNotify;
}
// Write record
this.SendRecord (ContentType.Alert, new byte[2] { (byte) level, (byte) description });
if (close) {
this.context.SentConnectionEnd = true;
}
}
#endregion
#region Send Record Methods
public void SendChangeCipherSpec()
{
DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
// Send Change Cipher Spec message with the current cipher
// or as plain text if this is the initial negotiation
this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
Context ctx = this.context;
// Reset sequence numbers
ctx.WriteSequenceNumber = 0;
// all further data sent will be encrypted with the negotiated
// security parameters (now the current parameters)
if (ctx is ClientContext) {
ctx.StartSwitchingSecurityParameters (true);
} else {
ctx.EndSwitchingSecurityParameters (false);
}
}
public IAsyncResult BeginSendRecord(HandshakeType handshakeType, AsyncCallback callback, object state)
{
HandshakeMessage msg = this.GetMessage(handshakeType);
msg.Process();
DebugHelper.WriteLine(">>>> Write handshake record ({0}|{1})", context.Protocol, msg.ContentType);
SendRecordAsyncResult internalResult = new SendRecordAsyncResult(callback, state, msg);
this.BeginSendRecord(msg.ContentType, msg.EncodeMessage(), new AsyncCallback(InternalSendRecordCallback), internalResult);
return internalResult;
}
private void InternalSendRecordCallback(IAsyncResult ar)
{
SendRecordAsyncResult internalResult = ar.AsyncState as SendRecordAsyncResult;
try
{
this.EndSendRecord(ar);
// Update session
internalResult.Message.Update();
// Reset message contents
internalResult.Message.Reset();
internalResult.SetComplete();
}
catch (Exception ex)
{
internalResult.SetComplete(ex);
}
}
public IAsyncResult BeginSendRecord(ContentType contentType, byte[] recordData, AsyncCallback callback, object state)
{
if (this.context.SentConnectionEnd)
{
throw new TlsException(
AlertDescription.InternalError,
"The session is finished and it's no longer valid.");
}
byte[] record = this.EncodeRecord(contentType, recordData);
return this.innerStream.BeginWrite(record, 0, record.Length, callback, state);
}
public void EndSendRecord(IAsyncResult asyncResult)
{
if (asyncResult is SendRecordAsyncResult)
{
SendRecordAsyncResult internalResult = asyncResult as SendRecordAsyncResult;
if (!internalResult.IsCompleted)
internalResult.AsyncWaitHandle.WaitOne();
if (internalResult.CompletedWithError)
throw internalResult.AsyncException;
}
else
{
this.innerStream.EndWrite(asyncResult);
}
}
public void SendRecord(ContentType contentType, byte[] recordData)
{
IAsyncResult ar = this.BeginSendRecord(contentType, recordData, null, null);
this.EndSendRecord(ar);
}
public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
{
return this.EncodeRecord(
contentType,
recordData,
0,
recordData.Length);
}
public byte[] EncodeRecord(
ContentType contentType,
byte[] recordData,
int offset,
int count)
{
if (this.context.SentConnectionEnd)
{
throw new TlsException(
AlertDescription.InternalError,
"The session is finished and it's no longer valid.");
}
TlsStream record = new TlsStream();
int position = offset;
while (position < ( offset + count ))
{
short fragmentLength = 0;
byte[] fragment;
if ((count + offset - position) > Context.MAX_FRAGMENT_SIZE)
{
fragmentLength = Context.MAX_FRAGMENT_SIZE;
}
else
{
fragmentLength = (short)(count + offset - position);
}
// Fill the fragment data
fragment = new byte[fragmentLength];
Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
if ((this.Context.Write != null) && (this.Context.Write.Cipher != null))
{
// Encrypt fragment
fragment = this.encryptRecordFragment (contentType, fragment);
}
// Write tls message
record.Write((byte)contentType);
record.Write(this.context.Protocol);
record.Write((short)fragment.Length);
record.Write(fragment);
DebugHelper.WriteLine("Record data", fragment);
// Update buffer position
position += fragmentLength;
}
return record.ToArray();
}
#endregion
#region Cryptography Methods
private byte[] encryptRecordFragment(
ContentType contentType,
byte[] fragment)
{
byte[] mac = null;
// Calculate message MAC
if (this.Context is ClientContext)
{
mac = this.context.Write.Cipher.ComputeClientRecordMAC(contentType, fragment);
}
else
{
mac = this.context.Write.Cipher.ComputeServerRecordMAC (contentType, fragment);
}
DebugHelper.WriteLine(">>>> Record MAC", mac);
// Encrypt the message
byte[] ecr = this.context.Write.Cipher.EncryptRecord (fragment, mac);
// Update sequence number
this.context.WriteSequenceNumber++;
return ecr;
}
private byte[] decryptRecordFragment(
ContentType contentType,
byte[] fragment)
{
byte[] dcrFragment = null;
byte[] dcrMAC = null;
try
{
this.context.Read.Cipher.DecryptRecord (fragment, out dcrFragment, out dcrMAC);
}
catch
{
if (this.context is ServerContext)
{
this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
}
throw;
}
// Generate record MAC
byte[] mac = null;
if (this.Context is ClientContext)
{
mac = this.context.Read.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
}
else
{
mac = this.context.Read.Cipher.ComputeClientRecordMAC (contentType, dcrFragment);
}
DebugHelper.WriteLine(">>>> Record MAC", mac);
// Check record MAC
if (!Compare (mac, dcrMAC))
{
throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
}
// Update sequence number
this.context.ReadSequenceNumber++;
return dcrFragment;
}
private bool Compare (byte[] array1, byte[] array2)
{
if (array1 == null)
return (array2 == null);
if (array2 == null)
return false;
if (array1.Length != array2.Length)
return false;
for (int i = 0; i < array1.Length; i++) {
if (array1[i] != array2[i])
return false;
}
return true;
}
#endregion
#region CipherSpecV2 processing
private void ProcessCipherSpecV2Buffer (SecurityProtocolType protocol, byte[] buffer)
{
TlsStream codes = new TlsStream(buffer);
string prefix = (protocol == SecurityProtocolType.Ssl3) ? "SSL_" : "TLS_";
while (codes.Position < codes.Length)
{
byte check = codes.ReadByte();
if (check == 0)
{
// SSL/TLS cipher spec
short code = codes.ReadInt16();
int index = this.Context.SupportedCiphers.IndexOf(code);
if (index != -1)
{
this.Context.Negotiating.Cipher = this.Context.SupportedCiphers[index];
break;
}
}
else
{
byte[] tmp = new byte[2];
codes.Read(tmp, 0, tmp.Length);
int tmpCode = ((check & 0xff) << 16) | ((tmp[0] & 0xff) << 8) | (tmp[1] & 0xff);
CipherSuite cipher = this.MapV2CipherCode(prefix, tmpCode);
if (cipher != null)
{
this.Context.Negotiating.Cipher = cipher;
break;
}
}
}
if (this.Context.Negotiating == null)
{
throw new TlsException(AlertDescription.InsuficientSecurity, "Insuficient Security");
}
}
private CipherSuite MapV2CipherCode(string prefix, int code)
{
try
{
switch (code)
{
case 65664:
// TLS_RC4_128_WITH_MD5
return this.Context.SupportedCiphers[prefix + "RSA_WITH_RC4_128_MD5"];
case 131200:
// TLS_RC4_128_EXPORT40_WITH_MD5
return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC4_40_MD5"];
case 196736:
// TLS_RC2_CBC_128_CBC_WITH_MD5
return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC2_CBC_40_MD5"];
case 262272:
// TLS_RC2_CBC_128_CBC_EXPORT40_WITH_MD5
return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC2_CBC_40_MD5"];
case 327808:
// TLS_IDEA_128_CBC_WITH_MD5
return null;
case 393280:
// TLS_DES_64_CBC_WITH_MD5
return null;
case 458944:
// TLS_DES_192_EDE3_CBC_WITH_MD5
return null;
default:
return null;
}
}
catch
{
return null;
}
}
#endregion
}
}