e79aa3c0ed
Former-commit-id: a2155e9bd80020e49e72e86c44da02a8ac0e57a4
268 lines
10 KiB
C#
268 lines
10 KiB
C#
// <copyright>
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// </copyright>
|
|
|
|
namespace System.ServiceModel.Channels
|
|
{
|
|
using System;
|
|
using System.Collections;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Net.Http;
|
|
using System.Net.WebSockets;
|
|
using System.Runtime;
|
|
using System.Threading;
|
|
|
|
class DefaultWebSocketConnectionHandler : WebSocketConnectionHandler
|
|
{
|
|
// WebSocket version this handler accepts; the request's version header must equal it exactly.
readonly string currentVersion;

// Sub-protocol configured for this handler; matched case-insensitively against the client's offers.
readonly string subProtocol;

// Session encoder used to validate the SOAP content-type header.
// Only created when the message version is not MessageVersion.None; null otherwise.
readonly MessageEncoder encoder;

// String form of the configured TransferMode.
// Only assigned when the encoder factory is a BinaryMessageEncoderFactory; null otherwise.
readonly string transferMode;

// True when the SOAP content-type header has to be validated.
readonly bool needToCheckContentType;

// True when the binary-encoder transfer-mode header has to be validated as well.
readonly bool needToCheckTransferMode;

// Validators handed to CheckHttpHeader. The content-type and transfer-mode validators
// are only created when the corresponding needToCheck* flag is set.
readonly Func<string, bool> checkVersionFunc;
readonly Func<string, bool> checkContentTypeFunc;
readonly Func<string, bool> checkTransferModeFunc;
|
|
|
|
/// <summary>
/// Creates a handler that enforces the given sub-protocol and WebSocket version and works out
/// which additional request headers (content type, transfer mode) must be validated.
/// </summary>
/// <param name="subProtocol">Sub-protocol the handler negotiates; may be null or empty.</param>
/// <param name="currentVersion">WebSocket version the request must announce.</param>
/// <param name="messageVersion">When not <c>MessageVersion.None</c>, content-type validation is enabled.</param>
/// <param name="encoderFactory">Factory used to create the session encoder that validates content types.</param>
/// <param name="transferMode">Transfer mode the client must announce when the binary encoder is in use.</param>
public DefaultWebSocketConnectionHandler(string subProtocol, string currentVersion, MessageVersion messageVersion, MessageEncoderFactory encoderFactory, TransferMode transferMode)
{
    this.subProtocol = subProtocol;
    this.currentVersion = currentVersion;
    this.checkVersionFunc = this.CheckVersion;

    // Without a message version there is no content type or transfer mode to verify.
    bool hasMessageVersion = messageVersion != MessageVersion.None;
    if (!hasMessageVersion)
    {
        return;
    }

    this.needToCheckContentType = true;
    this.encoder = encoderFactory.CreateSessionEncoder();
    this.checkContentTypeFunc = this.CheckContentType;

    // The transfer-mode header is only checked for the binary encoder.
    if (encoderFactory is BinaryMessageEncoderFactory)
    {
        this.needToCheckTransferMode = true;
        this.transferMode = transferMode.ToString();
        this.checkTransferModeFunc = this.CheckTransferMode;
    }
}
|
|
|
|
/// <summary>
/// Validates a WebSocket upgrade request and produces the handshake response.
/// </summary>
/// <remarks>
/// Checks run in this order: WebSocket version header, SOAP content-type header (when required),
/// binary transfer-mode header (when required), then sub-protocol negotiation. The first failing
/// check determines the rejection response that is returned.
/// </remarks>
/// <param name="request">The incoming upgrade request.</param>
/// <param name="cancellationToken">Not used by this implementation.</param>
/// <returns>The accepted handshake response, or a rejection response describing what the server expects.</returns>
protected internal override HttpResponseMessage AcceptWebSocket(HttpRequestMessage request, CancellationToken cancellationToken)
{
    if (!CheckHttpHeader(request, WebSocketHelper.SecWebSocketVersion, this.checkVersionFunc))
    {
        return GetUpgradeRequiredResponseMessageWithVersion(request, this.currentVersion);
    }

    if (this.needToCheckContentType)
    {
        if (!CheckHttpHeader(request, WebSocketTransportSettings.SoapContentTypeHeader, this.checkContentTypeFunc))
        {
            return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
        }

        if (this.needToCheckTransferMode && !CheckHttpHeader(request, WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.checkTransferModeFunc))
        {
            return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
        }
    }

    HttpResponseMessage response = GetWebSocketAcceptedResponseMessage(request);

    HttpResponseMessage rejection = this.NegotiateSubProtocol(request, response);
    if (rejection != null)
    {
        // The accepted response will never be sent; release it instead of leaking it.
        response.Dispose();
        return rejection;
    }

    return response;
}

// Matches the client's offered sub-protocols against the configured one and, on success,
// records the negotiated value on the accepted response. Returns null on success, or the
// rejection response to send instead.
HttpResponseMessage NegotiateSubProtocol(HttpRequestMessage request, HttpResponseMessage response)
{
    SubprotocolParseResult subprotocolParseResult = ParseSubprotocolValues(request);

    if (!subprotocolParseResult.HeaderFound)
    {
        // No offer from the client is only acceptable when the server has no sub-protocol configured.
        if (string.IsNullOrEmpty(this.subProtocol))
        {
            return null;
        }

        FxTrace.Exception.AsWarning(new WebException(
            SR.GetString(SR.WebSocketInvalidProtocolNoHeader, this.subProtocol, WebSocketHelper.SecWebSocketProtocol)));

        return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
    }

    if (!subprotocolParseResult.HeaderValid)
    {
        return GetBadRequestResponseMessage(request);
    }

    string negotiatedProtocol = null;

    // Match client protocols against the server protocol; the first match wins and the
    // client's spelling of the protocol is the one echoed back.
    foreach (string protocol in subprotocolParseResult.ParsedSubprotocols)
    {
        if (string.Equals(protocol, this.subProtocol, StringComparison.OrdinalIgnoreCase))
        {
            negotiatedProtocol = protocol;
            break;
        }
    }

    if (negotiatedProtocol == null)
    {
        FxTrace.Exception.AsWarning(new WebException(
            SR.GetString(SR.WebSocketInvalidProtocolNotInClientList, this.subProtocol, string.Join(", ", subprotocolParseResult.ParsedSubprotocols))));

        return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
    }

    // Replace any sub-protocol header already on the response with the negotiated value.
    // An empty negotiated protocol means the header is omitted altogether.
    response.Headers.Remove(WebSocketHelper.SecWebSocketProtocol);
    if (negotiatedProtocol != string.Empty)
    {
        response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, negotiatedProtocol);
    }

    return null;
}
|
|
|
|
/// <summary>
/// Collects every sub-protocol token offered in the request's sub-protocol header(s).
/// </summary>
/// <param name="request">The incoming upgrade request; must not be null.</param>
/// <returns>
/// <c>HeaderNotFound</c> when the request carries no sub-protocol header, <c>HeaderInvalid</c> when any
/// header value fails to parse, otherwise a valid result holding the parsed tokens (a single
/// empty string when the header produced no tokens).
/// </returns>
static SubprotocolParseResult ParseSubprotocolValues(HttpRequestMessage request)
{
    Fx.Assert(request != null, "request should not be null");

    IEnumerable<string> offeredHeaderValues = null;
    if (!request.Headers.TryGetValues(WebSocketHelper.SecWebSocketProtocol, out offeredHeaderValues))
    {
        return SubprotocolParseResult.HeaderNotFound;
    }

    // The request may carry several sub-protocol headers, so all values are flattened into one list.
    // Duplicates are harmless because the caller always takes the first matching entry.
    List<string> offeredProtocols = new List<string>();
    foreach (string rawValue in offeredHeaderValues)
    {
        List<string> parsedValue;
        if (!WebSocketHelper.TryParseSubProtocol(rawValue, out parsedValue))
        {
            return SubprotocolParseResult.HeaderInvalid;
        }

        offeredProtocols.AddRange(parsedValue);
    }

    // A valid result always exposes at least one entry: a header that was present but
    // produced no tokens is represented by a single empty string.
    if (offeredProtocols.Count == 0)
    {
        offeredProtocols.Add(string.Empty);
    }

    return new SubprotocolParseResult(true, true, offeredProtocols);
}
|
|
|
|
/// <summary>
/// Builds the upgrade-required response for <paramref name="request"/> and, when a sub-protocol
/// is supplied, advertises it in the sub-protocol header.
/// </summary>
/// <param name="request">The rejected upgrade request.</param>
/// <param name="subprotocol">Sub-protocol to advertise; no header is added when null or empty.</param>
/// <returns>The response produced by <c>GetUpgradeRequiredResponseMessage</c>, possibly with the extra header.</returns>
static HttpResponseMessage GetUpgradeRequiredResponseMessageWithSubProtocol(HttpRequestMessage request, string subprotocol)
{
    HttpResponseMessage upgradeRequired = GetUpgradeRequiredResponseMessage(request);

    bool hasSubprotocol = !string.IsNullOrEmpty(subprotocol);
    if (hasSubprotocol)
    {
        upgradeRequired.Headers.Add(WebSocketHelper.SecWebSocketProtocol, subprotocol);
    }

    return upgradeRequired;
}
|
|
|
|
/// <summary>
/// Builds the upgrade-required response for <paramref name="request"/> and advertises the
/// WebSocket version the server expects in the version header.
/// </summary>
/// <param name="request">The rejected upgrade request.</param>
/// <param name="version">Version value written to the version header.</param>
/// <returns>The response produced by <c>GetUpgradeRequiredResponseMessage</c> with the version header added.</returns>
static HttpResponseMessage GetUpgradeRequiredResponseMessageWithVersion(HttpRequestMessage request, string version)
{
    HttpResponseMessage upgradeRequired = GetUpgradeRequiredResponseMessage(request);
    upgradeRequired.Headers.Add(WebSocketHelper.SecWebSocketVersion, version);
    return upgradeRequired;
}
|
|
|
|
/// <summary>
/// Checks that <paramref name="header"/> is present on the request and that every value it
/// carries satisfies <paramref name="validator"/>.
/// </summary>
/// <param name="request">The incoming request; must not be null.</param>
/// <param name="header">Name of the header to look up; must not be null.</param>
/// <param name="validator">Predicate applied to each trimmed, non-null header value; must not be null.</param>
/// <returns>
/// <c>true</c> only when the header exists, at least one value was validated, and no value was
/// rejected; <c>false</c> otherwise.
/// </returns>
static bool CheckHttpHeader(HttpRequestMessage request, string header, Func<string, bool> validator)
{
    Fx.Assert(request != null, "request should not be null.");
    Fx.Assert(header != null, "header should not be null.");
    Fx.Assert(validator != null, "validator should not be null.");

    IEnumerable<string> headerValues;
    if (!request.Headers.TryGetValues(header, out headerValues))
    {
        return false;
    }

    bool isValid = false;
    foreach (string headerValue in headerValues)
    {
        // Null entries carry nothing to validate and are skipped.
        if (headerValue == null)
        {
            continue;
        }

        isValid = validator(headerValue.Trim());
        if (!isValid)
        {
            return false;
        }
    }

    // Report success only if the validator actually accepted a value. A header that yields
    // no usable values must not pass the check by default.
    return isValid;
}
|
|
|
|
/// <summary>
/// Validator for the WebSocket version header: the value must match the configured version
/// exactly (ordinal, case-sensitive).
/// </summary>
/// <param name="headerValue">Trimmed header value; must not be null.</param>
/// <returns><c>true</c> when the value equals the configured version.</returns>
bool CheckVersion(string headerValue)
{
    Fx.Assert(headerValue != null, "headerValue should not be null.");

    return string.Equals(headerValue, this.currentVersion, StringComparison.Ordinal);
}
|
|
|
|
/// <summary>
/// Validator for the SOAP content-type header: defers to the session encoder's
/// <c>IsContentTypeSupported</c>.
/// </summary>
/// <param name="headerValue">Trimmed header value; must not be null.</param>
/// <returns>Whatever the encoder reports for the supplied content type.</returns>
bool CheckContentType(string headerValue)
{
    Fx.Assert(headerValue != null, "headerValue should not be null.");

    bool isSupported = this.encoder.IsContentTypeSupported(headerValue);
    return isSupported;
}
|
|
|
|
/// <summary>
/// Validator for the binary-encoder transfer-mode header: the value must match the configured
/// transfer mode, ignoring case.
/// </summary>
/// <param name="headerValue">Trimmed header value; must not be null.</param>
/// <returns><c>true</c> when the value names the configured transfer mode.</returns>
bool CheckTransferMode(string headerValue)
{
    Fx.Assert(headerValue != null, "headerValue should not be null.");

    string expectedMode = this.transferMode;
    return headerValue.Equals(expectedMode, StringComparison.OrdinalIgnoreCase);
}
|
|
|
|
/// <summary>
/// Builds the bad-request response for <paramref name="request"/> and tells the client what the
/// server expects: the encoder's content type and, when transfer-mode checking is enabled,
/// the configured transfer mode.
/// </summary>
/// <param name="request">The rejected upgrade request.</param>
/// <returns>The response produced by <c>GetBadRequestResponseMessage</c> with the expectation headers added.</returns>
HttpResponseMessage GetBadRequestResponseMessageWithContentTypeAndTransfermode(HttpRequestMessage request)
{
    // Only reachable when content-type checking is enabled, so the encoder has been created.
    Fx.Assert(this.needToCheckContentType, "needToCheckContentType should be true.");

    HttpResponseMessage badRequest = GetBadRequestResponseMessage(request);
    badRequest.Headers.Add(WebSocketTransportSettings.SoapContentTypeHeader, this.encoder.ContentType);

    if (this.needToCheckTransferMode)
    {
        // transferMode is already stored as a string.
        badRequest.Headers.Add(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.transferMode);
    }

    return badRequest;
}
|
|
|
|
/// <summary>
/// Immutable outcome of parsing a request's sub-protocol header(s): whether a header was
/// present, whether it parsed, and the tokens it contained.
/// </summary>
struct SubprotocolParseResult
{
    /// <summary>A header was present but could not be parsed; no tokens are available.</summary>
    public static readonly SubprotocolParseResult HeaderInvalid = new SubprotocolParseResult(true, false, null);

    /// <summary>The request carried no sub-protocol header; no tokens are available.</summary>
    public static readonly SubprotocolParseResult HeaderNotFound = new SubprotocolParseResult(false, false, null);

    // Read-only so a result cannot be altered after construction.
    readonly bool headerFound;
    readonly bool headerValid;
    readonly IEnumerable<string> parsedSubprotocols;

    /// <summary>Creates a parse result.</summary>
    /// <param name="headerFound">Whether a sub-protocol header was present.</param>
    /// <param name="headerValid">Whether the header value(s) parsed successfully.</param>
    /// <param name="parsedSubprotocols">The parsed tokens; may be null when there is nothing to report.</param>
    public SubprotocolParseResult(bool headerFound, bool headerValid, IEnumerable<string> parsedSubprotocols)
    {
        this.headerFound = headerFound;
        this.headerValid = headerValid;
        this.parsedSubprotocols = parsedSubprotocols;
    }

    /// <summary>Gets whether a sub-protocol header was present.</summary>
    public bool HeaderFound
    {
        get { return this.headerFound; }
    }

    /// <summary>Gets whether the header value(s) parsed successfully.</summary>
    public bool HeaderValid
    {
        get { return this.headerValid; }
    }

    /// <summary>Gets the parsed tokens; null for <see cref="HeaderInvalid"/> and <see cref="HeaderNotFound"/>.</summary>
    public IEnumerable<string> ParsedSubprotocols
    {
        get { return this.parsedSubprotocols; }
    }
}
|
|
}
|
|
}
|