// <copyright>
// Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>

namespace System.ServiceModel.Channels
{
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;
    using System.Net;
    using System.Net.Http;
    using System.Net.WebSockets;
    using System.Runtime;
    using System.Threading;

    class DefaultWebSocketConnectionHandler : WebSocketConnectionHandler
    {
        string currentVersion;
        string subProtocol;
        MessageEncoder encoder;
        string transferMode;
        bool needToCheckContentType;
        bool needToCheckTransferMode;
        Func<string, bool> checkVersionFunc;
        Func<string, bool> checkContentTypeFunc;
        Func<string, bool> checkTransferModeFunc;

        public DefaultWebSocketConnectionHandler(string subProtocol, string currentVersion, MessageVersion messageVersion, MessageEncoderFactory encoderFactory, TransferMode transferMode)
        {
            this.subProtocol = subProtocol;
            this.currentVersion = currentVersion;
            this.checkVersionFunc = new Func<string, bool>(this.CheckVersion);

            if (messageVersion != MessageVersion.None)
            {
                this.needToCheckContentType = true;
                this.encoder = encoderFactory.CreateSessionEncoder();
                this.checkContentTypeFunc = new Func<string, bool>(this.CheckContentType);

                if (encoderFactory is BinaryMessageEncoderFactory)
                {
                    this.needToCheckTransferMode = true;
                    this.transferMode = transferMode.ToString();
                    this.checkTransferModeFunc = new Func<string, bool>(this.CheckTransferMode);
                }
            }
        }

        protected internal override HttpResponseMessage AcceptWebSocket(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            if (!CheckHttpHeader(request, WebSocketHelper.SecWebSocketVersion, this.checkVersionFunc))
            {
                return GetUpgradeRequiredResponseMessageWithVersion(request, this.currentVersion);
            }

            if (this.needToCheckContentType)
            {
                if (!CheckHttpHeader(request, WebSocketTransportSettings.SoapContentTypeHeader, this.checkContentTypeFunc))
                {
                    return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request); 
                }

                if (this.needToCheckTransferMode && !CheckHttpHeader(request, WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.checkTransferModeFunc))
                {
                    return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
                }
            }

            HttpResponseMessage response = GetWebSocketAcceptedResponseMessage(request);

            SubprotocolParseResult subprotocolParseResult = ParseSubprotocolValues(request);
            if (subprotocolParseResult.HeaderFound)
            {
                if (!subprotocolParseResult.HeaderValid)
                {
                    return GetBadRequestResponseMessage(request);
                }

                string negotiatedProtocol = null;

                // match client protocols vs server protocol
                foreach (string protocol in subprotocolParseResult.ParsedSubprotocols)
                {
                    if (string.Compare(protocol, this.subProtocol, StringComparison.OrdinalIgnoreCase) == 0)
                    {
                        negotiatedProtocol = protocol;
                        break;
                    }
                }

                if (negotiatedProtocol == null)
                {
                    FxTrace.Exception.AsWarning(new WebException(
                        SR.GetString(SR.WebSocketInvalidProtocolNotInClientList, this.subProtocol, string.Join(", ", subprotocolParseResult.ParsedSubprotocols))));

                    return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
                }

                // set response header
                response.Headers.Remove(WebSocketHelper.SecWebSocketProtocol);
                if (negotiatedProtocol != string.Empty)
                {
                    response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, negotiatedProtocol);
                }
            }
            else
            {
                if (!string.IsNullOrEmpty(this.subProtocol))
                {
                    FxTrace.Exception.AsWarning(new WebException(
                        SR.GetString(SR.WebSocketInvalidProtocolNoHeader, this.subProtocol, WebSocketHelper.SecWebSocketProtocol)));

                    return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
                }
            }

            return response;
        }

        static SubprotocolParseResult ParseSubprotocolValues(HttpRequestMessage request)
        {
            Fx.Assert(request != null, "request should not be null");
            IEnumerable<string> clientProtocols = null;

            if (request.Headers.TryGetValues(WebSocketHelper.SecWebSocketProtocol, out clientProtocols))
            {
                List<string> tokenList = new List<string>();

                // We may have multiple subprotocol header in the response. We will build up a list with all the subprotocol values.
                // There might be duplicated ones inside the list, but it doesn't matter since we will always take the first matching value.
                foreach (string headerValue in clientProtocols)
                {
                    List<string> protocolList;
                    if (WebSocketHelper.TryParseSubProtocol(headerValue, out protocolList))
                    {
                        tokenList.AddRange(protocolList);
                    }
                    else
                    {
                        return SubprotocolParseResult.HeaderInvalid;
                    }
                }

                // If this method returns true, we should ensure that clientProtocols always contains at least one entry
                if (tokenList.Count == 0)
                {
                    tokenList.Add(string.Empty);
                }

                return new SubprotocolParseResult(true, true, tokenList);
            }

            return SubprotocolParseResult.HeaderNotFound;
        }

        static HttpResponseMessage GetUpgradeRequiredResponseMessageWithSubProtocol(HttpRequestMessage request, string subprotocol)
        {
            HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
            if (!string.IsNullOrEmpty(subprotocol))
            {
                response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, subprotocol);
            }

            return response;
        }

        static HttpResponseMessage GetUpgradeRequiredResponseMessageWithVersion(HttpRequestMessage request, string version)
        {
            HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
            response.Headers.Add(WebSocketHelper.SecWebSocketVersion, version);

            return response;
        }

        static bool CheckHttpHeader(HttpRequestMessage request, string header, Func<string, bool> validator)
        {
            Fx.Assert(request != null, "request should not be null.");
            Fx.Assert(header != null, "header should not be null.");
            Fx.Assert(validator != null, "validator should not be null.");

            IEnumerable<string> headerValues;
            if (!request.Headers.TryGetValues(header, out headerValues))
            {
                return false;
            }

            bool isValid = false;
            foreach (string headerValue in headerValues)
            {
                if (headerValue != null)
                {
                    isValid = validator(headerValue.Trim());
                    if (!isValid)
                    {
                        return false;
                    }
                }
            }

            return true;
        }

        bool CheckVersion(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return headerValue == this.currentVersion;
        }

        bool CheckContentType(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return this.encoder.IsContentTypeSupported(headerValue);
        }

        bool CheckTransferMode(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return headerValue.Equals(this.transferMode, StringComparison.OrdinalIgnoreCase);
        }
        
        HttpResponseMessage GetBadRequestResponseMessageWithContentTypeAndTransfermode(HttpRequestMessage request)
        {
            Fx.Assert(this.needToCheckContentType, "needToCheckContentType should be true.");
            HttpResponseMessage response = GetBadRequestResponseMessage(request);
            response.Headers.Add(WebSocketTransportSettings.SoapContentTypeHeader, this.encoder.ContentType);
            if (this.needToCheckTransferMode)
            {
                response.Headers.Add(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.transferMode.ToString());
            }

            return response;
        }

        struct SubprotocolParseResult
        {
            public static readonly SubprotocolParseResult HeaderInvalid = new SubprotocolParseResult(true, false, null);
            public static readonly SubprotocolParseResult HeaderNotFound = new SubprotocolParseResult(false, false, null);

            bool headerFound;
            bool headerValid;
            IEnumerable<string> parsedSubprotocols;

            public SubprotocolParseResult(bool headerFound, bool headerValid, IEnumerable<string> parsedSubprotocols)
            {
                this.headerFound = headerFound;
                this.headerValid = headerValid;
                this.parsedSubprotocols = parsedSubprotocols;
            }

            public bool HeaderFound
            {
                get { return this.headerFound; }
            }

            public bool HeaderValid
            {
                get { return this.headerValid; }
            }

            public IEnumerable<string> ParsedSubprotocols
            {
                get { return this.parsedSubprotocols; }
            }
        }
    }
}