//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
namespace System.ServiceModel.Channels
{
    using System;
    using System.IdentityModel.Selectors;
    using System.IO;
    using System.Net;
    using System.Net.WebSockets;
    using System.Runtime;
    using System.ServiceModel.Diagnostics.Application;
    using System.ServiceModel.Security;
    using System.ServiceModel.Security.Tokens;

    // Client-side WebSocket duplex session channel. Opening the channel sends an HTTP upgrade
    // request, validates the handshake response and wraps the response stream in a WebSocket
    // (either the built-in client WebSocket or one produced by a ClientWebSocketFactory).
    class ClientWebSocketTransportDuplexSessionChannel : WebSocketTransportDuplexSessionChannel
    {
        // Optional factory; when non-null it supplies both the protocol version header and the WebSocket instance.
        readonly ClientWebSocketFactory connectionFactory;

        // NOTE(review): HttpChannelFactory/HttpsChannelFactory appear here without type arguments;
        // confirm against the declaring file whether generic arguments are expected.
        HttpChannelFactory channelFactory;

        // Response stream of the upgrade request; closed in OnCleanup.
        Stream connection;
        SecurityTokenProviderContainer webRequestTokenProvider;
        SecurityTokenProviderContainer webRequestProxyTokenProvider;

        // Request used by the asynchronous open path (OnBeginOpen/OnEndOpen).
        HttpWebRequest httpWebRequest;

        // Sec-WebSocket-Key sent with the request; used to compute the expected accept header.
        string webSocketKey;

        // Set by OnCleanup; read by HandleHttpWebResponse to detect a cleanup racing with open.
        volatile bool cleanupStarted;

        // True while an identity mapping added by CreateHttpWebRequest still has to be removed.
        volatile bool cleanupIdentity;

        static ClientWebSocketTransportDuplexSessionChannel()
        {
            WebSocket.RegisterPrefixes();
        }

        public ClientWebSocketTransportDuplexSessionChannel(HttpChannelFactory channelFactory, ClientWebSocketFactory connectionFactory, EndpointAddress remoteAddresss, Uri via, ConnectionBufferPool bufferPool)
            : base(channelFactory, remoteAddresss, via, bufferPool)
        {
            this.channelFactory = channelFactory;
            this.connectionFactory = connectionFactory;
        }

        // Output is streamed when the configured transfer mode streams requests.
        protected override bool IsStreamedOutput
        {
            get { return TransferModeHelper.IsRequestStreamed(this.TransferMode); }
        }

        // Starts the asynchronous open: builds the upgrade request and begins waiting for its response.
        // On any failure the token providers and the partially created request are cleaned up.
        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            bool success = false;
            try
            {
                if (TD.WebSocketConnectionRequestSendStartIsEnabled())
                {
                    TD.WebSocketConnectionRequestSendStart(
                        this.EventTraceActivity,
                        this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty);
                }

                this.httpWebRequest = this.CreateHttpWebRequest(timeout);
                IAsyncResult result = this.httpWebRequest.BeginGetResponse(callback, state);
                success = true;
                return result;
            }
            catch (WebException ex)
            {
                if (TD.WebSocketConnectionFailedIsEnabled())
                {
                    TD.WebSocketConnectionFailed(this.EventTraceActivity, ex.Message);
                }

                // Throws a more specific CommunicationException when the response identifies a known mismatch.
                TryConvertAndThrow(ex);
                throw FxTrace.Exception.AsError(HttpChannelUtilities.CreateRequestWebException(ex, this.httpWebRequest, HttpAbortReason.None));
            }
            finally
            {
                if (!success)
                {
                    this.CleanupTokenProviders();
                    this.CleanupOnError(this.httpWebRequest, null);
                }
            }
        }

        // Completes the asynchronous open: obtains the response, validates the handshake and
        // creates the WebSocket. Token providers are always released; on failure the request
        // and response are cleaned up as well.
        protected override void OnEndOpen(IAsyncResult result)
        {
            bool success = false;
            HttpWebResponse response = null;
            try
            {
                response = (HttpWebResponse)this.httpWebRequest.EndGetResponse(result);
                this.HandleHttpWebResponse(this.httpWebRequest, response);
                this.RemoveIdentityMapping(false);
                success = true;

                if (TD.WebSocketConnectionRequestSendStopIsEnabled())
                {
                    TD.WebSocketConnectionRequestSendStop(
                        this.EventTraceActivity,
                        this.WebSocket != null ? this.WebSocket.GetHashCode() : -1);
                }
            }
            catch (WebException ex)
            {
                if (TD.WebSocketConnectionFailedIsEnabled())
                {
                    TD.WebSocketConnectionFailed(this.EventTraceActivity, ex.Message);
                }

                TryConvertAndThrow(ex);
                throw FxTrace.Exception.AsError(HttpChannelUtilities.CreateRequestWebException(ex, this.httpWebRequest, HttpAbortReason.None));
            }
            finally
            {
                this.CleanupTokenProviders();
                if (!success)
                {
                    this.CleanupOnError(this.httpWebRequest, response);
                }
            }
        }

        // Synchronous open: same sequence as OnBeginOpen/OnEndOpen, using a local request
        // instead of the httpWebRequest field.
        protected override void OnOpen(TimeSpan timeout)
        {
            TimeoutHelper helper = new TimeoutHelper(timeout);
            HttpWebRequest request = null;
            HttpWebResponse response = null;
            bool success = false;
            try
            {
                if (TD.WebSocketConnectionRequestSendStartIsEnabled())
                {
                    TD.WebSocketConnectionRequestSendStart(
                        this.EventTraceActivity,
                        this.RemoteAddress != null ? this.RemoteAddress.ToString() : string.Empty);
                }

                request = this.CreateHttpWebRequest(helper.RemainingTime());
                response = (HttpWebResponse)request.GetResponse();
                this.HandleHttpWebResponse(request, response);
                this.RemoveIdentityMapping(false);
                success = true;

                if (TD.WebSocketConnectionRequestSendStopIsEnabled())
                {
                    TD.WebSocketConnectionRequestSendStop(
                        this.EventTraceActivity,
                        this.WebSocket != null ? this.WebSocket.GetHashCode() : -1);
                }
            }
            catch (WebException ex)
            {
                if (TD.WebSocketConnectionFailedIsEnabled())
                {
                    TD.WebSocketConnectionFailed(this.EventTraceActivity, ex.Message);
                }

                TryConvertAndThrow(ex);
                throw FxTrace.Exception.AsError(HttpChannelUtilities.CreateRequestWebException(ex, request, HttpAbortReason.None));
            }
            finally
            {
                this.CleanupTokenProviders();
                if (!success)
                {
                    this.CleanupOnError(request, response);
                }
            }
        }

        // Marks cleanup as started (observed by HandleHttpWebResponse), runs the base cleanup
        // and closes the response stream if one was obtained.
        protected override void OnCleanup()
        {
            this.cleanupStarted = true;
            base.OnCleanup();
            if (this.connection != null)
            {
                this.connection.Close();
            }
        }

        // Verifies that the response carries headerKey with the expected value; throws a
        // CommunicationException wrapping a WebSocketException when the header is missing or differs.
        static void CheckResponseHeader(HttpWebResponse response, string headerKey, string expectedValue, bool ignoreCase)
        {
            string actualValue = response.Headers[headerKey];
            if (actualValue == null)
            {
                throw FxTrace.Exception.AsError(new CommunicationException(
                    SR.GetString(SR.WebSocketTransportError),
                    new WebSocketException(SR.GetString(
                        SR.WebSocketUpgradeFailedHeaderMissingError, headerKey))));
            }

            StringComparison comparisonType = ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal;
            if (!actualValue.Equals(expectedValue, comparisonType))
            {
                throw FxTrace.Exception.AsError(new CommunicationException(
                    SR.GetString(SR.WebSocketTransportError),
                    new WebSocketException(SR.GetString(
                        SR.WebSocketUpgradeFailedWrongHeaderError, headerKey, actualValue, expectedValue))));
            }
        }

        // Inspects the failed response attached to the WebException and throws a specific
        // CommunicationException for a 400 carrying the SOAP content-type header or a 426
        // carrying a version/sub-protocol header. Returns normally in every other case so the
        // caller can throw its generic exception.
        static void TryConvertAndThrow(WebException ex)
        {
            if (ex.Response != null)
            {
                HttpWebResponse webResponse = (HttpWebResponse)ex.Response;
                if (webResponse.StatusCode == HttpStatusCode.BadRequest)
                {
                    string serverContentType = webResponse.Headers[WebSocketTransportSettings.SoapContentTypeHeader];
                    if (!string.IsNullOrWhiteSpace(serverContentType))
                    {
                        string serverTransferMode = webResponse.Headers[WebSocketTransportSettings.BinaryEncoderTransferModeHeader];
                        if (!string.IsNullOrWhiteSpace(serverTransferMode))
                        {
                            throw FxTrace.Exception.AsError(new CommunicationException(SR.GetString(SR.WebSocketContentTypeAndTransferModeMismatchFromServer), ex));
                        }
                        else
                        {
                            throw FxTrace.Exception.AsError(new CommunicationException(SR.GetString(SR.WebSocketContentTypeMismatchFromServer), ex));
                        }
                    }
                }
                else if (webResponse.StatusCode == HttpStatusCode.UpgradeRequired)
                {
                    string serverVersion = webResponse.Headers[WebSocketHelper.SecWebSocketVersion];
                    if (!string.IsNullOrWhiteSpace(serverVersion))
                    {
                        throw FxTrace.Exception.AsError(new CommunicationException(SR.GetString(SR.WebSocketVersionMismatchFromServer, serverVersion), ex));
                    }

                    string serverSubProtocol = webResponse.Headers[WebSocketHelper.SecWebSocketProtocol];
                    if (!string.IsNullOrWhiteSpace(serverSubProtocol))
                    {
                        throw FxTrace.Exception.AsError(new CommunicationException(SR.GetString(SR.WebSocketSubProtocolMismatchFromServer, serverSubProtocol), ex));
                    }
                }
            }
        }

        // Converts a remaining-time span into the int millisecond value assigned to
        // HttpWebRequest.Timeout. TimeSpan.MaxValue maps to an infinite timeout and other values
        // above int.MaxValue milliseconds are clamped, so the double-to-int cast cannot overflow.
        static int ToRequestTimeoutMilliseconds(TimeSpan remaining)
        {
            if (remaining == TimeSpan.MaxValue)
            {
                return System.Threading.Timeout.Infinite;
            }

            double totalMilliseconds = remaining.TotalMilliseconds;
            if (totalMilliseconds > int.MaxValue)
            {
                return int.MaxValue;
            }

            return (int)totalMilliseconds;
        }

        // Adds the sub-protocol header (when configured) and the WCF-specific handshake headers.
        void ConfigureHttpWebRequestHeader(HttpWebRequest request)
        {
            if (this.WebSocketSettings.SubProtocol != null)
            {
                request.Headers[WebSocketHelper.SecWebSocketProtocol] = this.WebSocketSettings.SubProtocol;
            }

            // These headers were added for WCF specific handshake to avoid encoder or transfermode mismatch between client and server.
            // For BinaryMessageEncoder, since we are using a sessionful channel for websocket, the encoder is actually different when
            // we are using Buffered or Streamed transfermode. So we need an extra header to identify the transfermode we are using, just
            // to make people a little bit easier to diagnose these mismatch issues.
            if (this.channelFactory.MessageVersion != MessageVersion.None)
            {
                request.Headers[WebSocketTransportSettings.SoapContentTypeHeader] = this.channelFactory.WebSocketSoapContentType;

                if (this.channelFactory.MessageEncoderFactory is BinaryMessageEncoderFactory)
                {
                    request.Headers[WebSocketTransportSettings.BinaryEncoderTransferModeHeader] = this.channelFactory.TransferMode.ToString();
                }
            }
        }

        // Failure path shared by the open methods: closes the response, aborts the request,
        // runs channel cleanup and removes any identity mapping in aborting mode.
        void CleanupOnError(HttpWebRequest request, HttpWebResponse response)
        {
            if (response != null)
            {
                response.Close();
            }

            if (request != null)
            {
                request.Abort();
            }

            this.Cleanup();
            this.RemoveIdentityMapping(true);
        }

        // Removes the identity mapping added by CreateHttpWebRequest at most once: the flag is
        // checked, re-checked under ThisLock and cleared before the helper is called.
        void RemoveIdentityMapping(bool aborting)
        {
            if (this.cleanupIdentity)
            {
                lock (this.ThisLock)
                {
                    if (this.cleanupIdentity)
                    {
                        this.cleanupIdentity = false;
                        HttpTransportSecurityHelpers.RemoveIdentityMapping(Via, RemoteAddress, !aborting);
                    }
                }
            }
        }

        // Builds the HTTP upgrade request: registers the identity mapping when required, opens
        // the token providers, obtains the client certificate token for HTTPS factories that
        // require one, then applies the WebSocket headers and the remaining timeout.
        HttpWebRequest CreateHttpWebRequest(TimeSpan timeout)
        {
            TimeoutHelper helper = new TimeoutHelper(timeout);
            ChannelParameterCollection channelParameterCollection = new ChannelParameterCollection();
            HttpWebRequest request;

            if (HttpChannelFactory.MapIdentity(this.RemoteAddress, this.channelFactory.AuthenticationScheme))
            {
                lock (ThisLock)
                {
                    this.cleanupIdentity = HttpTransportSecurityHelpers.AddIdentityMapping(Via, RemoteAddress);
                }
            }

            this.channelFactory.CreateAndOpenTokenProviders(
                this.RemoteAddress,
                this.Via,
                channelParameterCollection,
                helper.RemainingTime(),
                out this.webRequestTokenProvider,
                out this.webRequestProxyTokenProvider);

            SecurityTokenContainer clientCertificateToken = null;
            HttpsChannelFactory httpsChannelFactory = this.channelFactory as HttpsChannelFactory;
            if (httpsChannelFactory != null && httpsChannelFactory.RequireClientCertificate)
            {
                SecurityTokenProvider certificateProvider = httpsChannelFactory.CreateAndOpenCertificateTokenProvider(this.RemoteAddress, this.Via, channelParameterCollection, helper.RemainingTime());
                clientCertificateToken = httpsChannelFactory.GetCertificateSecurityToken(certificateProvider, this.RemoteAddress, this.Via, channelParameterCollection, ref helper);
            }

            request = this.channelFactory.GetWebRequest(this.RemoteAddress, this.Via, this.webRequestTokenProvider, this.webRequestProxyTokenProvider, clientCertificateToken, helper.RemainingTime(), true);

            // If a web socket connection factory is specified (for example, when using web sockets on pre-Win8 OS),
            // we're going to use the protocol version from it. At the moment, on pre-Win8 OS, the HttpWebRequest
            // created above doesn't have the version header specified.
            if (this.connectionFactory != null)
            {
                this.UseWebSocketVersionFromFactory(request);
            }

            // Remember the key so the accept header of the response can be validated later.
            this.webSocketKey = request.Headers[WebSocketHelper.SecWebSocketKey];
            this.ConfigureHttpWebRequestHeader(request);

            // A plain (int) cast of TotalMilliseconds overflows for very large remaining times
            // (e.g. TimeSpan.MaxValue) and yields a value the Timeout setter rejects.
            request.Timeout = ToRequestTimeoutMilliseconds(helper.RemainingTime());
            return request;
        }

        // Aborts and forgets both token provider containers, if present.
        void CleanupTokenProviders()
        {
            if (this.webRequestTokenProvider != null)
            {
                this.webRequestTokenProvider.Abort();
                this.webRequestTokenProvider = null;
            }

            if (this.webRequestProxyTokenProvider != null)
            {
                this.webRequestProxyTokenProvider.Abort();
                this.webRequestProxyTokenProvider = null;
            }
        }

        // Validates the handshake response, creates the WebSocket over the response stream,
        // records the reply security property (when there is one) and installs the message source.
        [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule,
            Justification = "The exception thrown here is already wrapped.")]
        void HandleHttpWebResponse(HttpWebRequest request, HttpWebResponse response)
        {
            this.ValidateHttpWebResponse(response);
            this.connection = response.GetResponseStream();

            WebSocket clientWebSocket = null;
            try
            {
                if (this.connectionFactory != null)
                {
                    this.WebSocket = clientWebSocket = this.CreateWebSocketWithFactory();
                }
                else
                {
                    byte[] internalBuffer = this.TakeBuffer();
                    try
                    {
                        this.WebSocket = clientWebSocket = WebSocket.CreateClientWebSocket(
                            this.connection,
                            this.WebSocketSettings.SubProtocol,
                            WebSocketHelper.GetReceiveBufferSize(this.channelFactory.MaxReceivedMessageSize),
                            WebSocketDefaults.BufferSize,
                            this.WebSocketSettings.GetEffectiveKeepAliveInterval(),
                            this.WebSocketSettings.DisablePayloadMasking,
                            new ArraySegment<byte>(internalBuffer));
                    }
                    finally
                    {
                        // even when setting this.InternalBuffer in the finally block
                        // there is still a potential race condition, which could result
                        // in not returning 'internalBuffer' to the pool.
                        // This is acceptable since it is extremely unlikely, only for
                        // the error case and there is no big harm if the buffers are
                        // occasionally not returned to the pool. WebSocketBufferPool.Take()
                        // will just allocate new buffers;
                        this.InternalBuffer = internalBuffer;
                    }
                }
            }
            finally
            {
                // There is a race condition between OnCleanup and OnOpen that
                // can result in cleaning up while the clientWebSocket instance is
                // created. In this case OnCleanup won't be called anymore and would
                // not clean up the WebSocket instance immediately - only GC would
                // cleanup during finalization.
                // To avoid this we abort the WebSocket (and implicitly this.connection)
                if (clientWebSocket != null && this.cleanupStarted)
                {
                    clientWebSocket.Abort();
                    CommunicationObjectAbortedException communicationObjectAbortedException = new CommunicationObjectAbortedException(
                        new WebSocketException(WebSocketError.ConnectionClosedPrematurely).Message);
                    FxTrace.Exception.AsWarning(communicationObjectAbortedException);
                    throw communicationObjectAbortedException;
                }
            }

            bool inputUseStreaming = TransferModeHelper.IsResponseStreamed(this.TransferMode);

            SecurityMessageProperty handshakeReplySecurityMessageProperty = this.channelFactory.CreateReplySecurityProperty(request, response);
            if (handshakeReplySecurityMessageProperty != null)
            {
                this.RemoteSecurity = handshakeReplySecurityMessageProperty;
            }

            this.SetMessageSource(new WebSocketMessageSource(
                this,
                this.WebSocket,
                inputUseStreaming,
                this));
        }

        // Checks that the response is a valid WebSocket upgrade: status 101, the Connection and
        // Upgrade headers, the accept header computed from the request key, and a sub-protocol
        // header that matches the requested one (or is absent/blank when none was requested).
        void ValidateHttpWebResponse(HttpWebResponse response)
        {
            if (response.StatusCode != HttpStatusCode.SwitchingProtocols)
            {
                throw FxTrace.Exception.AsError(new CommunicationException(
                    SR.GetString(SR.WebSocketTransportError),
                    new WebSocketException(SR.GetString(
                        SR.WebSocketUpgradeFailedError,
                        (int)response.StatusCode,
                        response.StatusDescription,
                        (int)HttpStatusCode.SwitchingProtocols,
                        HttpStatusCode.SwitchingProtocols))));
            }

            CheckResponseHeader(response, HttpTransportDefaults.ConnectionHeader, WebSocketDefaults.WebSocketConnectionHeaderValue, true);
            CheckResponseHeader(response, HttpTransportDefaults.UpgradeHeader, WebSocketDefaults.WebSocketUpgradeHeaderValue, true);

            // The accept header is compared case-sensitively against the value derived from the request key.
            string expectedAcceptHeader = WebSocketHelper.ComputeAcceptHeader(this.webSocketKey);
            CheckResponseHeader(response, WebSocketHelper.SecWebSocketAccept, expectedAcceptHeader, false);

            if (this.WebSocketSettings.SubProtocol != null)
            {
                CheckResponseHeader(response, WebSocketHelper.SecWebSocketProtocol, this.WebSocketSettings.SubProtocol, true);
            }
            else
            {
                string headerValue = response.Headers[WebSocketHelper.SecWebSocketProtocol];
                if (!string.IsNullOrWhiteSpace(headerValue))
                {
                    throw FxTrace.Exception.AsError(new CommunicationException(
                        SR.GetString(SR.WebSocketTransportError),
                        new WebSocketException(SR.GetString(
                            SR.WebSocketUpgradeFailedInvalidProtocolError, headerValue))));
                }
            }
        }

        // Copies the WebSocket protocol version reported by the connection factory into the
        // request's version header. Factory failures and invalid versions surface as
        // InvalidOperationException.
        void UseWebSocketVersionFromFactory(HttpWebRequest request)
        {
            Fx.Assert(this.connectionFactory != null, "Invalid call: UseWebSocketVersionFromFactory.");

            if (TD.WebSocketUseVersionFromClientWebSocketFactoryIsEnabled())
            {
                TD.WebSocketUseVersionFromClientWebSocketFactory(this.EventTraceActivity, this.connectionFactory.GetType().FullName);
            }

            // Obtain the WebSocketVersion from the factory.
            string webSocketVersion;
            try
            {
                webSocketVersion = this.connectionFactory.WebSocketVersion;
            }
            catch (Exception e)
            {
                if (Fx.IsFatal(e))
                {
                    throw;
                }

                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_GetWebSocketVersionFailed, this.connectionFactory.GetType().Name), e));
            }

            // The WebSocketVersion is a required http header, to initiate a web-socket connection.
            if (string.IsNullOrWhiteSpace(webSocketVersion))
            {
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_InvalidWebSocketVersion, this.connectionFactory.GetType().Name)));
            }

            try
            {
                request.Headers[WebSocketHelper.SecWebSocketVersion] = webSocketVersion;
            }
            catch (ArgumentException e)
            {
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_InvalidWebSocketVersion, this.connectionFactory.GetType().Name), e));
            }
        }

        // Asks the connection factory for a WebSocket over the response stream and verifies the
        // result: non-null, in the Open state and with the requested sub-protocol. A rejected
        // non-null socket is disposed before the InvalidOperationException is thrown.
        WebSocket CreateWebSocketWithFactory()
        {
            Fx.Assert(this.connectionFactory != null, "Invalid call: CreateWebSocketWithFactory.");

            if (TD.WebSocketCreateClientWebSocketWithFactoryIsEnabled())
            {
                TD.WebSocketCreateClientWebSocketWithFactory(this.EventTraceActivity, this.connectionFactory.GetType().FullName);
            }

            // Create the client WebSocket with the factory.
            WebSocket ws;
            try
            {
                ws = this.connectionFactory.CreateWebSocket(this.connection, this.WebSocketSettings.Clone());
            }
            catch (Exception e)
            {
                if (Fx.IsFatal(e))
                {
                    throw;
                }

                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_CreateWebSocketFailed, this.connectionFactory.GetType().Name), e));
            }

            // The returned WebSocket should be valid (non-null), in an opened state and with the same SubProtocol that we requested.
            if (ws == null)
            {
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_InvalidWebSocket, this.connectionFactory.GetType().Name)));
            }
            else if (ws.State != WebSocketState.Open)
            {
                ws.Dispose();
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_InvalidWebSocket, this.connectionFactory.GetType().Name)));
            }
            else
            {
                string requested = this.WebSocketSettings.SubProtocol;
                string obtained = ws.SubProtocol;

                // No requested sub-protocol means the obtained one must be null/blank;
                // otherwise the two must match ignoring case.
                if (!(requested == null ? string.IsNullOrWhiteSpace(obtained) : requested.Equals(obtained, StringComparison.OrdinalIgnoreCase)))
                {
                    ws.Dispose();
                    throw FxTrace.Exception.AsError(new InvalidOperationException(SR.GetString(SR.ClientWebSocketFactory_InvalidSubProtocol, this.connectionFactory.GetType().Name, obtained, requested)));
                }
            }

            return ws;
        }
    }
}