You've already forked linux-packaging-mono
							
							
		
			
				
	
	
		
			268 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
			
		
		
	
	
			268 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
| // <copyright>
 | |
| // Copyright (c) Microsoft Corporation.  All rights reserved.
 | |
| // </copyright>
 | |
| 
 | |
| namespace System.ServiceModel.Channels
 | |
| {
 | |
|     using System;
 | |
|     using System.Collections;
 | |
|     using System.Collections.Generic;
 | |
|     using System.Linq;
 | |
|     using System.Net;
 | |
|     using System.Net.Http;
 | |
|     using System.Net.WebSockets;
 | |
|     using System.Runtime;
 | |
|     using System.Threading;
 | |
| 
 | |
|     class DefaultWebSocketConnectionHandler : WebSocketConnectionHandler
 | |
|     {
 | |
|         string currentVersion;
 | |
|         string subProtocol;
 | |
|         MessageEncoder encoder;
 | |
|         string transferMode;
 | |
|         bool needToCheckContentType;
 | |
|         bool needToCheckTransferMode;
 | |
|         Func<string, bool> checkVersionFunc;
 | |
|         Func<string, bool> checkContentTypeFunc;
 | |
|         Func<string, bool> checkTransferModeFunc;
 | |
| 
 | |
|         public DefaultWebSocketConnectionHandler(string subProtocol, string currentVersion, MessageVersion messageVersion, MessageEncoderFactory encoderFactory, TransferMode transferMode)
 | |
|         {
 | |
|             this.subProtocol = subProtocol;
 | |
|             this.currentVersion = currentVersion;
 | |
|             this.checkVersionFunc = new Func<string, bool>(this.CheckVersion);
 | |
| 
 | |
|             if (messageVersion != MessageVersion.None)
 | |
|             {
 | |
|                 this.needToCheckContentType = true;
 | |
|                 this.encoder = encoderFactory.CreateSessionEncoder();
 | |
|                 this.checkContentTypeFunc = new Func<string, bool>(this.CheckContentType);
 | |
| 
 | |
|                 if (encoderFactory is BinaryMessageEncoderFactory)
 | |
|                 {
 | |
|                     this.needToCheckTransferMode = true;
 | |
|                     this.transferMode = transferMode.ToString();
 | |
|                     this.checkTransferModeFunc = new Func<string, bool>(this.CheckTransferMode);
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         protected internal override HttpResponseMessage AcceptWebSocket(HttpRequestMessage request, CancellationToken cancellationToken)
 | |
|         {
 | |
|             if (!CheckHttpHeader(request, WebSocketHelper.SecWebSocketVersion, this.checkVersionFunc))
 | |
|             {
 | |
|                 return GetUpgradeRequiredResponseMessageWithVersion(request, this.currentVersion);
 | |
|             }
 | |
| 
 | |
|             if (this.needToCheckContentType)
 | |
|             {
 | |
|                 if (!CheckHttpHeader(request, WebSocketTransportSettings.SoapContentTypeHeader, this.checkContentTypeFunc))
 | |
|                 {
 | |
|                     return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request); 
 | |
|                 }
 | |
| 
 | |
|                 if (this.needToCheckTransferMode && !CheckHttpHeader(request, WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.checkTransferModeFunc))
 | |
|                 {
 | |
|                     return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             HttpResponseMessage response = GetWebSocketAcceptedResponseMessage(request);
 | |
| 
 | |
|             SubprotocolParseResult subprotocolParseResult = ParseSubprotocolValues(request);
 | |
|             if (subprotocolParseResult.HeaderFound)
 | |
|             {
 | |
|                 if (!subprotocolParseResult.HeaderValid)
 | |
|                 {
 | |
|                     return GetBadRequestResponseMessage(request);
 | |
|                 }
 | |
| 
 | |
|                 string negotiatedProtocol = null;
 | |
| 
 | |
|                 // match client protocols vs server protocol
 | |
|                 foreach (string protocol in subprotocolParseResult.ParsedSubprotocols)
 | |
|                 {
 | |
|                     if (string.Compare(protocol, this.subProtocol, StringComparison.OrdinalIgnoreCase) == 0)
 | |
|                     {
 | |
|                         negotiatedProtocol = protocol;
 | |
|                         break;
 | |
|                     }
 | |
|                 }
 | |
| 
 | |
|                 if (negotiatedProtocol == null)
 | |
|                 {
 | |
|                     FxTrace.Exception.AsWarning(new WebException(
 | |
|                         SR.GetString(SR.WebSocketInvalidProtocolNotInClientList, this.subProtocol, string.Join(", ", subprotocolParseResult.ParsedSubprotocols))));
 | |
| 
 | |
|                     return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
 | |
|                 }
 | |
| 
 | |
|                 // set response header
 | |
|                 response.Headers.Remove(WebSocketHelper.SecWebSocketProtocol);
 | |
|                 if (negotiatedProtocol != string.Empty)
 | |
|                 {
 | |
|                     response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, negotiatedProtocol);
 | |
|                 }
 | |
|             }
 | |
|             else
 | |
|             {
 | |
|                 if (!string.IsNullOrEmpty(this.subProtocol))
 | |
|                 {
 | |
|                     FxTrace.Exception.AsWarning(new WebException(
 | |
|                         SR.GetString(SR.WebSocketInvalidProtocolNoHeader, this.subProtocol, WebSocketHelper.SecWebSocketProtocol)));
 | |
| 
 | |
|                     return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             return response;
 | |
|         }
 | |
| 
 | |
|         static SubprotocolParseResult ParseSubprotocolValues(HttpRequestMessage request)
 | |
|         {
 | |
|             Fx.Assert(request != null, "request should not be null");
 | |
|             IEnumerable<string> clientProtocols = null;
 | |
| 
 | |
|             if (request.Headers.TryGetValues(WebSocketHelper.SecWebSocketProtocol, out clientProtocols))
 | |
|             {
 | |
|                 List<string> tokenList = new List<string>();
 | |
| 
 | |
|                 // We may have multiple subprotocol header in the response. We will build up a list with all the subprotocol values.
 | |
|                 // There might be duplicated ones inside the list, but it doesn't matter since we will always take the first matching value.
 | |
|                 foreach (string headerValue in clientProtocols)
 | |
|                 {
 | |
|                     List<string> protocolList;
 | |
|                     if (WebSocketHelper.TryParseSubProtocol(headerValue, out protocolList))
 | |
|                     {
 | |
|                         tokenList.AddRange(protocolList);
 | |
|                     }
 | |
|                     else
 | |
|                     {
 | |
|                         return SubprotocolParseResult.HeaderInvalid;
 | |
|                     }
 | |
|                 }
 | |
| 
 | |
|                 // If this method returns true, we should ensure that clientProtocols always contains at least one entry
 | |
|                 if (tokenList.Count == 0)
 | |
|                 {
 | |
|                     tokenList.Add(string.Empty);
 | |
|                 }
 | |
| 
 | |
|                 return new SubprotocolParseResult(true, true, tokenList);
 | |
|             }
 | |
| 
 | |
|             return SubprotocolParseResult.HeaderNotFound;
 | |
|         }
 | |
| 
 | |
|         static HttpResponseMessage GetUpgradeRequiredResponseMessageWithSubProtocol(HttpRequestMessage request, string subprotocol)
 | |
|         {
 | |
|             HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
 | |
|             if (!string.IsNullOrEmpty(subprotocol))
 | |
|             {
 | |
|                 response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, subprotocol);
 | |
|             }
 | |
| 
 | |
|             return response;
 | |
|         }
 | |
| 
 | |
|         static HttpResponseMessage GetUpgradeRequiredResponseMessageWithVersion(HttpRequestMessage request, string version)
 | |
|         {
 | |
|             HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
 | |
|             response.Headers.Add(WebSocketHelper.SecWebSocketVersion, version);
 | |
| 
 | |
|             return response;
 | |
|         }
 | |
| 
 | |
|         static bool CheckHttpHeader(HttpRequestMessage request, string header, Func<string, bool> validator)
 | |
|         {
 | |
|             Fx.Assert(request != null, "request should not be null.");
 | |
|             Fx.Assert(header != null, "header should not be null.");
 | |
|             Fx.Assert(validator != null, "validator should not be null.");
 | |
| 
 | |
|             IEnumerable<string> headerValues;
 | |
|             if (!request.Headers.TryGetValues(header, out headerValues))
 | |
|             {
 | |
|                 return false;
 | |
|             }
 | |
| 
 | |
|             bool isValid = false;
 | |
|             foreach (string headerValue in headerValues)
 | |
|             {
 | |
|                 if (headerValue != null)
 | |
|                 {
 | |
|                     isValid = validator(headerValue.Trim());
 | |
|                     if (!isValid)
 | |
|                     {
 | |
|                         return false;
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             return true;
 | |
|         }
 | |
| 
 | |
|         bool CheckVersion(string headerValue)
 | |
|         {
 | |
|             Fx.Assert(headerValue != null, "headerValue should not be null.");
 | |
|             return headerValue == this.currentVersion;
 | |
|         }
 | |
| 
 | |
|         bool CheckContentType(string headerValue)
 | |
|         {
 | |
|             Fx.Assert(headerValue != null, "headerValue should not be null.");
 | |
|             return this.encoder.IsContentTypeSupported(headerValue);
 | |
|         }
 | |
| 
 | |
|         bool CheckTransferMode(string headerValue)
 | |
|         {
 | |
|             Fx.Assert(headerValue != null, "headerValue should not be null.");
 | |
|             return headerValue.Equals(this.transferMode, StringComparison.OrdinalIgnoreCase);
 | |
|         }
 | |
|         
 | |
|         HttpResponseMessage GetBadRequestResponseMessageWithContentTypeAndTransfermode(HttpRequestMessage request)
 | |
|         {
 | |
|             Fx.Assert(this.needToCheckContentType, "needToCheckContentType should be true.");
 | |
|             HttpResponseMessage response = GetBadRequestResponseMessage(request);
 | |
|             response.Headers.Add(WebSocketTransportSettings.SoapContentTypeHeader, this.encoder.ContentType);
 | |
|             if (this.needToCheckTransferMode)
 | |
|             {
 | |
|                 response.Headers.Add(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.transferMode.ToString());
 | |
|             }
 | |
| 
 | |
|             return response;
 | |
|         }
 | |
| 
 | |
|         struct SubprotocolParseResult
 | |
|         {
 | |
|             public static readonly SubprotocolParseResult HeaderInvalid = new SubprotocolParseResult(true, false, null);
 | |
|             public static readonly SubprotocolParseResult HeaderNotFound = new SubprotocolParseResult(false, false, null);
 | |
| 
 | |
|             bool headerFound;
 | |
|             bool headerValid;
 | |
|             IEnumerable<string> parsedSubprotocols;
 | |
| 
 | |
|             public SubprotocolParseResult(bool headerFound, bool headerValid, IEnumerable<string> parsedSubprotocols)
 | |
|             {
 | |
|                 this.headerFound = headerFound;
 | |
|                 this.headerValid = headerValid;
 | |
|                 this.parsedSubprotocols = parsedSubprotocols;
 | |
|             }
 | |
| 
 | |
|             public bool HeaderFound
 | |
|             {
 | |
|                 get { return this.headerFound; }
 | |
|             }
 | |
| 
 | |
|             public bool HeaderValid
 | |
|             {
 | |
|                 get { return this.headerValid; }
 | |
|             }
 | |
| 
 | |
|             public IEnumerable<string> ParsedSubprotocols
 | |
|             {
 | |
|                 get { return this.parsedSubprotocols; }
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 |