//------------------------------------------------------------------------------
// <copyright file="AspNetWebSocketManager.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//------------------------------------------------------------------------------

namespace System.Web.WebSockets {
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading.Tasks;
    using System.Web.Util;

    // Keeps track of AspNetWebSocket instances so that they can be aborted en masse,
    // such as in the case of an AppDomain shutdown.

    internal sealed class AspNetWebSocketManager {

        public static readonly AspNetWebSocketManager Current = new AspNetWebSocketManager(PerfCounters.Instance);

        // Set by AbortAllAndWait(); once true, every socket passed to Add() is aborted immediately.
        // Read and written only while holding the lock on _activeSockets.
        private bool _aborted;

        // Sockets currently being tracked. Every access in this class happens while holding a lock
        // on this collection, which also guards _aborted.
        internal readonly HashSet<IAsyncAbortableWebSocket> _activeSockets = new HashSet<IAsyncAbortableWebSocket>(); // internal only for unit testing purposes

        private readonly IPerfCounters _perfCounters;

        internal AspNetWebSocketManager(IPerfCounters perfCounters) {
            _perfCounters = perfCounters;
        }

        // Number of sockets currently being tracked.
        public int ActiveSocketCount {
            get {
                // We acquire a full lock when reading the count, similar to how the collections
                // in the System.Collections.Concurrent namespace operate.
                lock (_activeSockets) {
                    return _activeSockets.Count;
                }
            }
        }

        // Marks the manager as aborted, calls AbortAsync() on each socket tracked at that moment,
        // then blocks until every returned task has completed. Task.WaitAll surfaces any faulted
        // abort as an AggregateException. Sockets added after this call are aborted by Add().
        public void AbortAllAndWait() {
            // Snapshot the set so we don't iterate the live collection while sockets are being
            // aborted; keep the lock for as short a duration as possible.
            IAsyncAbortableWebSocket[] sockets;
            lock (_activeSockets) {
                _aborted = true;
                sockets = _activeSockets.ToArray();
            }

            Task[] abortTasks = Array.ConvertAll(sockets, socket => socket.AbortAsync());
            Task.WaitAll(abortTasks);
        }

        // Begins tracking a socket and updates the active-socket perf counter. If AbortAllAndWait()
        // has already been called, the socket is also aborted (fire-and-forget).
        public void Add(IAsyncAbortableWebSocket webSocket) {
            bool shouldAbort;

            lock (_activeSockets) {
                _activeSockets.Add(webSocket);
                shouldAbort = _aborted;
                PublishActiveSocketCountUnderLock();
            }

            // Abort outside the lock so socket code never runs while we hold it.
            if (shouldAbort) {
                webSocket.AbortAsync(); // don't care about the result of the abort at the present time
            }
        }

        // Stops tracking a socket and updates the active-socket perf counter.
        public void Remove(IAsyncAbortableWebSocket webSocket) {
            lock (_activeSockets) {
                _activeSockets.Remove(webSocket);
                PublishActiveSocketCountUnderLock();
            }
        }

        // Writes the current tracked-socket count to the REQUESTS_EXECUTING_WEBSOCKETS counter.
        // Must be called while holding the lock on _activeSockets. If the write happened after the
        // lock was released, concurrent Add/Remove calls could publish their counts out of order.
        // The counter would then show a stale value, e.g. 1 after the last socket was removed.
        private void PublishActiveSocketCountUnderLock() {
            _perfCounters.SetCounter(AppPerfCounter.REQUESTS_EXECUTING_WEBSOCKETS, _activeSockets.Count);
        }

    }
}