//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
namespace System.ServiceModel.Channels
{
using System;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
using System.Net.WebSockets;
using System.Runtime;
using System.Runtime.Diagnostics;
using System.Security.Principal;
using System.ServiceModel.Diagnostics;
using System.ServiceModel.Diagnostics.Application;
using System.ServiceModel.Security;
using System.Threading;
using System.Threading.Tasks;
abstract class WebSocketTransportDuplexSessionChannel : TransportDuplexSessionChannel
{
// Shared completion callback passed to MessageEncoder.BeginWriteMessage in StartWritingStreamedMessage.
static AsyncCallback streamedWriteCallback = Fx.ThunkCallback(StreamWriteCallback);
// Underlying WebSocket; assigned once through the WebSocket property setter.
WebSocket webSocket = null;
// Settings copied from the channel listener/factory in the constructors.
WebSocketTransportSettings webSocketSettings;
TransferMode transferMode;
int maxBufferSize;
// Callback/state captured by StartWritingStreamedMessage and invoked by StreamWriteCallback.
WaitCallback waitCallback;
object state;
// Stream of the in-flight streamed write; finished by StreamWriteCallback.
WebSocketStream webSocketStream;
// Buffer exposed via InternalBuffer; returned to bufferPool in OnCleanup.
byte[] internalBuffer;
// Pool used by TakeBuffer() and receiving internalBuffer back on cleanup.
ConnectionBufferPool bufferPool;
// Guards Cleanup() so that OnCleanup() runs at most once.
int cleanupStatus = WebSocketHelper.OperationNotStarted;
ITransportFactorySettings transportFactorySettings;
// Close status/description used by CloseAsync/CloseOutputAsync; exposed through GetProperty for IWebSocketCloseDetails.
WebSocketCloseDetails webSocketCloseDetails = new WebSocketCloseDetails();
// When true, OnCleanup disposes the WebSocket; subclasses can change it via ShouldDisposeWebSocketAfterClosed.
bool shouldDisposeWebSocketAfterClosed = true;
// Error recorded by an asynchronous buffered write or close-output; rethrown (and cleared) by FinishWritingMessage.
Exception pendingWritingMessageException;
/// <summary>
/// Creates the channel from an <see cref="HttpChannelListener"/>, which supplies both the
/// WebSocket settings and the transport factory settings.
/// </summary>
public WebSocketTransportDuplexSessionChannel(HttpChannelListener channelListener, EndpointAddress localAddress, Uri localVia, ConnectionBufferPool bufferPool)
    : base(channelListener, channelListener, localAddress, localVia, EndpointAddress.AnonymousAddress, channelListener.MessageVersion.Addressing.AnonymousUri)
{
    Fx.Assert(channelListener.WebSocketSettings != null, "channelListener.WebSocketTransportSettings should not be null.");

    this.transportFactorySettings = channelListener;
    this.bufferPool = bufferPool;
    this.webSocketSettings = channelListener.WebSocketSettings;
    this.maxBufferSize = channelListener.MaxBufferSize;
    this.transferMode = channelListener.TransferMode;
}
/// <summary>
/// Creates the channel from an <see cref="HttpChannelFactory"/>, which supplies both the
/// WebSocket settings and the transport factory settings.
/// </summary>
public WebSocketTransportDuplexSessionChannel(HttpChannelFactory channelFactory, EndpointAddress remoteAddresss, Uri via, ConnectionBufferPool bufferPool)
    : base(channelFactory, channelFactory, EndpointAddress.AnonymousAddress, channelFactory.MessageVersion.Addressing.AnonymousUri, remoteAddresss, via)
{
    Fx.Assert(channelFactory.WebSocketSettings != null, "channelFactory.WebSocketTransportSettings should not be null.");

    this.transportFactorySettings = channelFactory;
    this.bufferPool = bufferPool;
    this.webSocketSettings = channelFactory.WebSocketSettings;
    this.maxBufferSize = channelFactory.MaxBufferSize;
    this.transferMode = channelFactory.TransferMode;
}
/// <summary>
/// The underlying WebSocket. The setter expects a non-null value and may only be used once.
/// </summary>
protected WebSocket WebSocket
{
    get { return this.webSocket; }
    set
    {
        Fx.Assert(value != null, "value should not be null.");
        Fx.Assert(this.webSocket == null, "webSocket should not be set before this set call.");
        this.webSocket = value;
    }
}
/// <summary>WebSocket transport settings captured at construction time.</summary>
protected WebSocketTransportSettings WebSocketSettings
{
    get
    {
        return this.webSocketSettings;
    }
}
/// <summary>Transfer mode captured from the listener or factory at construction time.</summary>
protected TransferMode TransferMode
{
    get
    {
        return this.transferMode;
    }
}
/// <summary>Maximum buffer size captured from the listener or factory at construction time.</summary>
protected int MaxBufferSize
{
    get { return this.maxBufferSize; }
}
/// <summary>The listener or factory this channel was created from, viewed as transport factory settings.</summary>
protected ITransportFactorySettings TransportFactorySettings
{
    get { return this.transportFactorySettings; }
}
/// <summary>
/// Buffer owned by this channel; OnCleanup returns it to the buffer pool.
/// </summary>
protected byte[] InternalBuffer
{
    get { return this.internalBuffer; }
    set
    {
        // Assigning (even null) is only legal while no buffer is held: overwriting a live non-null
        // buffer causes NullRefs elsewhere. Keep asserting for that case if this check is ever changed.
        Fx.Assert(this.internalBuffer == null, "internalBuffer should not be set twice.");
        this.internalBuffer = value;
    }
}
/// <summary>
/// Controls whether OnCleanup disposes the WebSocket (true unless a subclass sets it otherwise).
/// </summary>
protected bool ShouldDisposeWebSocketAfterClosed
{
    set { this.shouldDisposeWebSocketAfterClosed = value; }
}
/// <summary>
/// Traces the abort (when enabled) and releases the channel's resources via Cleanup().
/// </summary>
protected override void OnAbort()
{
    if (TD.WebSocketConnectionAbortedIsEnabled())
    {
        // The WebSocket may still be null here, in which case -1 is traced as its id.
        int webSocketId = -1;
        if (this.WebSocket != null)
        {
            webSocketId = this.WebSocket.GetHashCode();
        }

        TD.WebSocketConnectionAborted(this.EventTraceActivity, webSocketId);
    }

    this.Cleanup();
}
/// <summary>
/// Exposes the channel's close details for IWebSocketCloseDetails requests; every other
/// property request is delegated to the base implementation.
/// </summary>
public override T GetProperty()
{
    if (typeof(T) != typeof(IWebSocketCloseDetails))
    {
        return base.GetProperty();
    }

    return this.webSocketCloseDetails as T;
}
/// <summary>
/// Performs the full WebSocket close via CloseAsync and waits on it for up to <paramref name="timeout"/>,
/// emitting trace events before and after when enabled.
/// </summary>
protected override void CompleteClose(TimeSpan timeout)
{
    if (TD.WebSocketCloseSentIsEnabled())
    {
        int webSocketId = this.WebSocket.GetHashCode();
        string closeStatus = this.webSocketCloseDetails.OutputCloseStatus.ToString();
        string remote = this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty;
        TD.WebSocketCloseSent(webSocketId, closeStatus, remote);
    }

    this.CloseAsync().Wait(timeout, WebSocketHelper.ThrowCorrectException, WebSocketHelper.CloseOperation);

    if (TD.WebSocketConnectionClosedIsEnabled())
    {
        TD.WebSocketConnectionClosed(this.WebSocket.GetHashCode());
    }
}
/// <summary>Takes a buffer from the connection buffer pool supplied at construction.</summary>
protected byte[] TakeBuffer()
{
    Fx.Assert(this.bufferPool != null, "'bufferPool' MUST NOT be NULL.");
    ConnectionBufferPool pool = this.bufferPool;
    return pool.Take();
}
/// <summary>
/// Closes this side's output via CloseOutputAsync and waits on it for up to <paramref name="timeout"/>,
/// tracing the close first when enabled.
/// </summary>
protected override void CloseOutputSessionCore(TimeSpan timeout)
{
    if (TD.WebSocketCloseOutputSentIsEnabled())
    {
        int webSocketId = this.WebSocket.GetHashCode();
        string closeStatus = this.webSocketCloseDetails.OutputCloseStatus.ToString();
        string remote = this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty;
        TD.WebSocketCloseOutputSent(webSocketId, closeStatus, remote);
    }

    // No token-based cancellation here; the wait itself is bounded by `timeout`.
    Task closeOutputTask = this.CloseOutputAsync(CancellationToken.None);
    closeOutputTask.Wait(timeout, WebSocketHelper.ThrowCorrectException, WebSocketHelper.CloseOperation);
}
/// <summary>
/// Runs the base close and then releases the channel's resources, even if the base close throws.
/// </summary>
protected override void OnClose(TimeSpan timeout)
{
    try
    {
        base.OnClose(timeout);
    }
    finally
    {
        // Cleanup() is idempotent, so this is safe even if OnAbort already ran it.
        this.Cleanup();
    }
}
/// <summary>
/// Intentionally a no-op: this channel has no connection to hand back on close or abort.
/// </summary>
protected override void ReturnConnectionIfNecessary(bool abort, TimeSpan timeout)
{
    // Nothing to do.
}
/// <summary>
/// Starts an asynchronous send of an already-encoded message as a single, complete WebSocket message.
/// </summary>
/// <remarks>
/// Always returns <see cref="AsyncCompletionResult.Queued"/>. Once the send task finishes, <paramref name="callback"/>
/// is invoked with <paramref name="state"/> whether or not the send succeeded; a failure is stored in
/// pendingWritingMessageException and rethrown later by FinishWritingMessage.
/// </remarks>
protected override AsyncCompletionResult StartWritingBufferedMessage(Message message, ArraySegment messageData, bool allowOutputBatching, TimeSpan timeout, Threading.WaitCallback callback, object state)
{
    Fx.Assert(callback != null, "callback should not be null.");
    TimeoutHelper helper = new TimeoutHelper(timeout);
    WebSocketMessageType outgoingMessageType = GetWebSocketMessageType(message);
    IOThreadCancellationTokenSource cancellationTokenSource = new IOThreadCancellationTokenSource(helper.RemainingTime());

    Task task;
    bool sendStarted = false;
    try
    {
        if (TD.WebSocketAsyncWriteStartIsEnabled())
        {
            TD.WebSocketAsyncWriteStart(
                this.WebSocket.GetHashCode(),
                messageData.Count,
                this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty);
        }

        task = this.WebSocket.SendAsync(messageData, outgoingMessageType, true, cancellationTokenSource.Token);
        sendStarted = true;
    }
    finally
    {
        // If we throw before the send task exists, the continuation below (which normally disposes
        // the token source) is never attached, so release the token source here instead.
        if (!sendStarted)
        {
            cancellationTokenSource.Dispose();
        }
    }

    Fx.Assert(this.pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point.");

    task.ContinueWith(t =>
    {
        try
        {
            if (TD.WebSocketAsyncWriteStopIsEnabled())
            {
                TD.WebSocketAsyncWriteStop(this.webSocket.GetHashCode());
            }

            cancellationTokenSource.Dispose();
            WebSocketHelper.ThrowExceptionOnTaskFailure(t, timeout, WebSocketHelper.SendOperation);
        }
        catch (Exception error)
        {
            // Intentionally not following the usual pattern to rethrow fatal exceptions.
            // Any rethrown exception would just be swallowed, because nobody awaits the
            // Task returned from ContinueWith in this case.
            FxTrace.Exception.TraceHandledException(error, TraceEventType.Information);
            this.pendingWritingMessageException = error;
        }
        finally
        {
            callback.Invoke(state);
        }
    }, CancellationToken.None);

    return AsyncCompletionResult.Queued;
}
/// <summary>
/// Rethrows (and clears) any error recorded by an earlier asynchronous write or close-output,
/// then completes the write in the base class.
/// </summary>
protected override void FinishWritingMessage()
{
    Exception pending = this.pendingWritingMessageException;
    if (pending != null)
    {
        this.pendingWritingMessageException = null;
        throw FxTrace.Exception.AsError(pending);
    }

    base.FinishWritingMessage();
}
/// <summary>
/// Starts encoding <paramref name="message"/> directly into a WebSocketStream. Always returns Queued;
/// the end-of-message write and <paramref name="callback"/> are handled either here (synchronous encoder
/// completion) or by StreamWriteCallback (asynchronous completion).
/// </summary>
protected override AsyncCompletionResult StartWritingStreamedMessage(Message message, TimeSpan timeout, WaitCallback callback, object state)
{
    TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
    WebSocketMessageType messageType = GetWebSocketMessageType(message);
    WebSocketStream stream = new WebSocketStream(this.WebSocket, messageType, timeoutHelper.RemainingTime());

    // Saved for StreamWriteCallback, which finishes the write when the encoder completes asynchronously.
    this.waitCallback = callback;
    this.state = state;
    this.webSocketStream = stream;

    IAsyncResult writeResult = this.MessageEncoder.BeginWriteMessage(message, new TimeoutStream(stream, ref timeoutHelper), streamedWriteCallback, this);
    if (writeResult.CompletedSynchronously)
    {
        // StreamWriteCallback ignores synchronous completions, so finish the write here.
        this.MessageEncoder.EndWriteMessage(writeResult);
        stream.WriteEndOfMessageAsync(timeoutHelper.RemainingTime(), callback, state);
    }

    return AsyncCompletionResult.Queued;
}
/// <summary>
/// Starts an asynchronous close of this side's output, cancelled after <paramref name="timeout"/>.
/// </summary>
/// <remarks>
/// Always returns <see cref="AsyncCompletionResult.Queued"/>. When the close task finishes, <paramref name="callback"/>
/// is invoked with <paramref name="state"/> regardless of outcome; a failure is stored in
/// pendingWritingMessageException and rethrown later by FinishWritingMessage.
/// </remarks>
protected override AsyncCompletionResult BeginCloseOutput(TimeSpan timeout, Threading.WaitCallback callback, object state)
{
    Fx.Assert(callback != null, "callback should not be null.");
    IOThreadCancellationTokenSource cancellationTokenSource = new IOThreadCancellationTokenSource(timeout);

    Task task;
    bool closeStarted = false;
    try
    {
        task = this.CloseOutputAsync(cancellationTokenSource.Token);
        closeStarted = true;
    }
    finally
    {
        // CloseOutputAsync can throw synchronously; the continuation that normally disposes the
        // token source is then never attached, so release it here.
        if (!closeStarted)
        {
            cancellationTokenSource.Dispose();
        }
    }

    Fx.Assert(this.pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point.");

    task.ContinueWith(t =>
    {
        try
        {
            cancellationTokenSource.Dispose();
            WebSocketHelper.ThrowExceptionOnTaskFailure(t, timeout, WebSocketHelper.CloseOperation);
        }
        catch (Exception error)
        {
            // Intentionally not following the usual pattern to rethrow fatal exceptions.
            // Any rethrown exception would just be swallowed, because nobody awaits the
            // Task returned from ContinueWith in this case.
            FxTrace.Exception.TraceHandledException(error, TraceEventType.Information);
            this.pendingWritingMessageException = error;
        }
        finally
        {
            callback.Invoke(state);
        }
    }, CancellationToken.None);

    return AsyncCompletionResult.Queued;
}
/// <summary>
/// Synchronously sends <paramref name="message"/> within <paramref name="timeout"/>, either streamed
/// through a WebSocketStream or buffered as one complete WebSocket message.
/// </summary>
protected override void OnSendCore(Message message, TimeSpan timeout)
{
    Fx.Assert(message != null, "message should not be null.");
    TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
    WebSocketMessageType messageType = GetWebSocketMessageType(message);

    if (this.IsStreamedOutput)
    {
        // Streamed: the encoder writes straight into the socket through a timeout-enforcing wrapper.
        WebSocketStream stream = new WebSocketStream(this.WebSocket, messageType, timeoutHelper.RemainingTime());
        TimeoutStream boundedStream = new TimeoutStream(stream, ref timeoutHelper);
        this.MessageEncoder.WriteMessage(message, boundedStream);
        stream.WriteEndOfMessage(timeoutHelper.RemainingTime());
        return;
    }

    // Buffered: encode into a BufferManager buffer, send it in one shot, then hand the buffer back.
    ArraySegment encoded = this.EncodeMessage(message);
    bool sent = false;
    try
    {
        if (TD.WebSocketAsyncWriteStartIsEnabled())
        {
            TD.WebSocketAsyncWriteStart(
                this.WebSocket.GetHashCode(),
                encoded.Count,
                this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty);
        }

        Task sendTask = this.WebSocket.SendAsync(encoded, messageType, true, CancellationToken.None);
        sendTask.Wait(timeoutHelper.RemainingTime(), WebSocketHelper.ThrowCorrectException, WebSocketHelper.SendOperation);

        if (TD.WebSocketAsyncWriteStopIsEnabled())
        {
            TD.WebSocketAsyncWriteStop(this.webSocket.GetHashCode());
        }

        sent = true;
    }
    finally
    {
        try
        {
            this.BufferManager.ReturnBuffer(encoded.Array);
        }
        catch (Exception returnError)
        {
            // A ReturnBuffer failure is only swallowed (and traced) when the send already failed,
            // so that it does not mask the send's own exception.
            if (Fx.IsFatal(returnError) || sent)
            {
                throw;
            }

            FxTrace.Exception.TraceUnhandledException(returnError);
        }
    }
}
/// <summary>
/// Encodes <paramref name="message"/> into a buffer from BufferManager with no size cap and no leading offset.
/// </summary>
protected override ArraySegment EncodeMessage(Message message)
{
    const int maxMessageSize = int.MaxValue;
    const int messageOffset = 0;
    return this.MessageEncoder.WriteMessage(message, maxMessageSize, this.BufferManager, messageOffset);
}
/// <summary>
/// Runs OnCleanup() at most once, no matter how many times or from how many threads this is called.
/// </summary>
protected void Cleanup()
{
    int previousStatus = Interlocked.CompareExchange(
        ref this.cleanupStatus,
        WebSocketHelper.OperationFinished,
        WebSocketHelper.OperationNotStarted);

    if (previousStatus != WebSocketHelper.OperationNotStarted)
    {
        // Someone else already claimed the cleanup.
        return;
    }

    this.OnCleanup();
}
/// <summary>
/// Releases the channel's resources: disposes the WebSocket (unless disabled via
/// ShouldDisposeWebSocketAfterClosed) and returns the internal buffer to the pool.
/// Must only be reached through Cleanup().
/// </summary>
protected virtual void OnCleanup()
{
    Fx.Assert(this.cleanupStatus == WebSocketHelper.OperationFinished,
        "This method should only be called by this.Cleanup(). Make sure that you never call overriden OnCleanup()-methods directly in subclasses");

    WebSocket socket = this.webSocket;
    if (socket != null && this.shouldDisposeWebSocketAfterClosed)
    {
        socket.Dispose();
    }

    byte[] buffer = this.internalBuffer;
    if (buffer != null)
    {
        this.bufferPool.Return(buffer);
        this.internalBuffer = null;
    }
}
/// <summary>
/// If <paramref name="pendingException"/> holds an exception, clears the slot and throws it
/// (wrapped by FxTrace); otherwise does nothing.
/// </summary>
private static void ThrowOnPendingException(ref Exception pendingException)
{
    Exception captured = pendingException;
    if (captured == null)
    {
        return;
    }

    pendingException = null;
    throw FxTrace.Exception.AsError(captured);
}
[System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule, Justification = "The exceptions thrown here are already wrapped.")]
/// <summary>
/// Starts a full WebSocket close using the configured output close status and description.
/// Non-fatal synchronous failures are converted via WebSocketHelper.ConvertAndTraceException.
/// </summary>
private Task CloseAsync()
{
    try
    {
        return this.WebSocket.CloseAsync(
            this.webSocketCloseDetails.OutputCloseStatus,
            this.webSocketCloseDetails.OutputCloseStatusDescription,
            CancellationToken.None);
    }
    catch (Exception e)
    {
        if (!Fx.IsFatal(e))
        {
            throw WebSocketHelper.ConvertAndTraceException(e);
        }

        throw;
    }
}
[System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule, Justification = "The exceptions thrown here are already wrapped.")]
/// <summary>
/// Starts closing this side's output using the configured output close status and description.
/// Non-fatal synchronous failures are converted via WebSocketHelper.ConvertAndTraceException.
/// </summary>
private Task CloseOutputAsync(CancellationToken cancellationToken)
{
    try
    {
        return this.WebSocket.CloseOutputAsync(
            this.webSocketCloseDetails.OutputCloseStatus,
            this.webSocketCloseDetails.OutputCloseStatusDescription,
            cancellationToken);
    }
    catch (Exception e)
    {
        if (!Fx.IsFatal(e))
        {
            throw WebSocketHelper.ConvertAndTraceException(e);
        }

        throw;
    }
}
/// <summary>
/// Returns the message type from the message's WebSocketMessageProperty if present,
/// otherwise WebSocketDefaults.DefaultWebSocketMessageType.
/// </summary>
static WebSocketMessageType GetWebSocketMessageType(Message message)
{
    WebSocketMessageProperty property;
    return message.Properties.TryGetValue(WebSocketMessageProperty.Name, out property)
        ? property.MessageType
        : WebSocketDefaults.DefaultWebSocketMessageType;
}
// Completion callback for the MessageEncoder.BeginWriteMessage call in StartWritingStreamedMessage.
// Synchronous completions are finished by StartWritingStreamedMessage itself, so they are ignored here.
// On success: ends the encoder write, calls WriteEndOfMessage on the saved stream, then invokes the
// saved waitCallback with the saved state. Non-fatal failures are passed to AddPendingException.
// NOTE(review): on the failure path waitCallback is NOT invoked, unlike StartWritingBufferedMessage,
// which always invokes its callback in a finally block. Confirm that AddPendingException (defined
// elsewhere) unblocks whoever is waiting for this write to complete.
static void StreamWriteCallback(IAsyncResult ar)
{
if (ar.CompletedSynchronously)
{
return;
}
// AsyncState is the channel instance passed as state to BeginWriteMessage.
WebSocketTransportDuplexSessionChannel thisPtr = (WebSocketTransportDuplexSessionChannel)ar.AsyncState;
try
{
thisPtr.MessageEncoder.EndWriteMessage(ar);
// We are governed here by the TimeoutStream, no need to pass a CancellationToken here.
thisPtr.webSocketStream.WriteEndOfMessage(TimeSpan.MaxValue);
thisPtr.waitCallback.Invoke(thisPtr.state);
}
catch (Exception ex)
{
if (Fx.IsFatal(ex))
{
throw;
}
thisPtr.AddPendingException(ex);
}
}
protected class WebSocketMessageSource : IMessageSource
{
static readonly Action