// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Wraps an enumerator factory delegate in an async enumerable.
        /// </summary>
        /// <param name="getEnumerator">Delegate invoked each time the resulting sequence is asked for an enumerator.</param>
        /// <returns>An async enumerable backed by <paramref name="getEnumerator"/>.</returns>
        static IAsyncEnumerable<T> Create<T>(Func<IAsyncEnumerator<T>> getEnumerator)
        {
            IAsyncEnumerable<T> sequence = new AnonymousAsyncEnumerable<T>(getEnumerator);
            return sequence;
        }

        /// <summary>
        /// Async enumerable whose GetEnumerator implementation is supplied by the caller as a delegate.
        /// </summary>
        class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
        {
            // Factory delegate; invoked once per GetEnumerator call.
            private readonly Func<IAsyncEnumerator<T>> _getEnumerator;

            /// <summary>
            /// Stores the factory delegate. No validation is performed on the argument.
            /// </summary>
            public AnonymousAsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
            {
                _getEnumerator = getEnumerator;
            }

            /// <summary>
            /// Returns whatever the stored factory delegate produces.
            /// </summary>
            public IAsyncEnumerator<T> GetEnumerator()
            {
                var enumerator = _getEnumerator();
                return enumerator;
            }
        }

        /// <summary>
        /// Builds an async enumerator from three delegates.
        /// </summary>
        /// <param name="moveNext">Delegate that advances the enumerator for a given cancellation token.</param>
        /// <param name="current">Delegate that returns the current element.</param>
        /// <param name="dispose">Delegate run when the enumerator is disposed.</param>
        /// <returns>An enumerator that forwards to the supplied delegates.</returns>
        static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            IAsyncEnumerator<T> enumerator = new AnonymousAsyncEnumerator<T>(moveNext, current, dispose);
            return enumerator;
        }

        /// <summary>
        /// Builds an async enumerator whose MoveNext is completed through a TaskCompletionSource
        /// and which reacts to cancellation of the token passed to MoveNext.
        /// </summary>
        /// <param name="moveNext">
        /// Delegate that starts one advance step. It receives the caller's token and a fresh
        /// completion source, and returns the task handed back to the MoveNext caller.
        /// </param>
        /// <param name="current">Delegate that returns the current element.</param>
        /// <param name="dispose">Delegate run when the enumerator is disposed.</param>
        /// <returns>An enumerator that forwards to the supplied delegates.</returns>
        /// <remarks>
        /// On every MoveNext call a callback is registered on the token; when the token is
        /// cancelled the callback disposes the enumerator and tries to move the completion
        /// source to the cancelled state. The registration is released once the step is over.
        /// </remarks>
        static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            // Declared before construction so the cancellation callback can refer to the enumerator itself.
            var self = default(IAsyncEnumerator<T>);
            self = new AnonymousAsyncEnumerator<T>(
                ct =>
                {
                    var tcs = new TaskCompletionSource<bool>();

                    var stop = new Action(() =>
                    {
                        self.Dispose();
                        tcs.TrySetCanceled();
                    });

                    var ctr = ct.Register(stop);

                    Task<bool> pending;
                    try
                    {
                        pending = moveNext(ct, tcs);
                    }
                    catch
                    {
                        // moveNext failed synchronously: without this the callback would stay
                        // registered on the caller's token and fire on a later cancellation.
                        ctr.Dispose();
                        throw;
                    }

                    // NOTE(review): Finally is a helper declared elsewhere; presumably it runs the
                    // action once the task completes, releasing the registration - confirm.
                    return pending.Finally(ctr.Dispose);
                },
                current,
                dispose
            );
            return self;
        }

        /// <summary>
        /// Async enumerator whose members forward to caller-supplied delegates.
        /// </summary>
        class AnonymousAsyncEnumerator<T> : IAsyncEnumerator<T>
        {
            private readonly Func<CancellationToken, Task<bool>> _advance;
            private readonly Func<T> _getCurrent;
            private readonly Action _onDispose;

            // Set on the first Dispose call; later MoveNext calls no longer reach the advance delegate.
            private bool _isDisposed;

            /// <summary>
            /// Stores the three delegates. No validation is performed on the arguments.
            /// </summary>
            public AnonymousAsyncEnumerator(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
            {
                _advance = moveNext;
                _getCurrent = current;
                _onDispose = dispose;
            }

            /// <summary>
            /// Forwards to the advance delegate unless the enumerator has been disposed, in which
            /// case the task from TaskExt.Return(false, CancellationToken.None) is returned instead.
            /// </summary>
            public Task<bool> MoveNext(CancellationToken cancellationToken)
            {
                if (!_isDisposed)
                {
                    return _advance(cancellationToken);
                }

                return TaskExt.Return(false, CancellationToken.None);
            }

            /// <summary>
            /// Gets the value produced by the current-element delegate.
            /// </summary>
            public T Current
            {
                get { return _getCurrent(); }
            }

            /// <summary>
            /// Runs the dispose delegate on the first call; subsequent calls do nothing.
            /// </summary>
            public void Dispose()
            {
                if (_isDisposed)
                {
                    return;
                }

                // Flag is set before the delegate runs, matching the order callers already observe.
                _isDisposed = true;
                _onDispose();
            }
        }

        /// <summary>
        /// Produces an async sequence built from a one-element array holding <paramref name="value"/>.
        /// </summary>
        /// <param name="value">The single element of the sequence.</param>
        /// <returns>The result of converting the one-element array with ToAsyncEnumerable.</returns>
        public static IAsyncEnumerable<TValue> Return<TValue>(TValue value)
        {
            TValue[] singleton = { value };
            return singleton.ToAsyncEnumerable();
        }

        /// <summary>
        /// Produces an async sequence whose MoveNext returns the task created by
        /// TaskExt.Throw for the given exception.
        /// </summary>
        /// <param name="exception">The exception handed to TaskExt.Throw on every MoveNext call.</param>
        /// <returns>An async sequence; reading Current on its enumerators throws InvalidOperationException.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
        public static IAsyncEnumerable<TValue> Throw<TValue>(Exception exception)
        {
            if (exception == null)
            {
                throw new ArgumentNullException("exception");
            }

            Func<IAsyncEnumerator<TValue>> enumeratorFactory = () =>
            {
                Func<CancellationToken, Task<bool>> advance = token => TaskExt.Throw<bool>(exception, token);
                Func<TValue> noCurrent = () => { throw new InvalidOperationException(); };
                Action nothingToDispose = () => { };

                return Create<TValue>(advance, noCurrent, nothingToDispose);
            };

            return Create(enumeratorFactory);
        }

        /// <summary>
        /// Produces an async sequence whose MoveNext returns the completion source's task
        /// without ever completing it itself.
        /// </summary>
        /// <returns>An async sequence; reading Current on its enumerators throws InvalidOperationException.</returns>
        public static IAsyncEnumerable<TValue> Never<TValue>()
        {
            Func<IAsyncEnumerator<TValue>> enumeratorFactory = () =>
            {
                // The step never sets a result on the source; it only hands back its task.
                Func<CancellationToken, TaskCompletionSource<bool>, Task<bool>> advance = (token, source) => source.Task;
                Func<TValue> noCurrent = () => { throw new InvalidOperationException(); };
                Action nothingToDispose = () => { };

                return Create<TValue>(advance, noCurrent, nothingToDispose);
            };

            return Create(enumeratorFactory);
        }

        /// <summary>
        /// Produces an async sequence whose MoveNext returns the task from TaskExt.Return(false, ct).
        /// </summary>
        /// <returns>An async sequence; reading Current on its enumerators throws InvalidOperationException.</returns>
        public static IAsyncEnumerable<TValue> Empty<TValue>()
        {
            Func<IAsyncEnumerator<TValue>> enumeratorFactory = () =>
            {
                Func<CancellationToken, Task<bool>> advance = token => TaskExt.Return(false, token);
                Func<TValue> noCurrent = () => { throw new InvalidOperationException(); };
                Action nothingToDispose = () => { };

                return Create<TValue>(advance, noCurrent, nothingToDispose);
            };

            return Create(enumeratorFactory);
        }

        /// <summary>
        /// Produces an async sequence by converting Enumerable.Range(start, count) with ToAsyncEnumerable.
        /// </summary>
        /// <param name="start">First value passed to Enumerable.Range.</param>
        /// <param name="count">Number of values passed to Enumerable.Range.</param>
        /// <returns>The converted async sequence.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IAsyncEnumerable<int> Range(int start, int count)
        {
            if (count < 0)
            {
                throw new ArgumentOutOfRangeException("count");
            }

            var numbers = Enumerable.Range(start, count);
            return numbers.ToAsyncEnumerable();
        }

        /// <summary>
        /// Produces an async sequence by converting Enumerable.Repeat(element, count) with ToAsyncEnumerable.
        /// </summary>
        /// <param name="element">The value to repeat.</param>
        /// <param name="count">Number of repetitions passed to Enumerable.Repeat.</param>
        /// <returns>The converted async sequence.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element, int count)
        {
            if (count < 0)
            {
                throw new ArgumentOutOfRangeException("count");
            }

            var copies = Enumerable.Repeat(element, count);
            return copies.ToAsyncEnumerable();
        }

        /// <summary>
        /// Produces an async sequence whose MoveNext returns the task from TaskExt.Return(true, ct)
        /// and whose Current is always <paramref name="element"/>.
        /// </summary>
        /// <param name="element">The value exposed through Current.</param>
        /// <returns>An async sequence that has no terminating condition of its own.</returns>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element)
        {
            Func<IAsyncEnumerator<TResult>> enumeratorFactory = () =>
            {
                Func<CancellationToken, Task<bool>> advance = token => TaskExt.Return(true, token);
                Func<TResult> sameElement = () => element;
                Action nothingToDispose = () => { };

                return Create(advance, sameElement, nothingToDispose);
            };

            return Create(enumeratorFactory);
        }

        /// <summary>
        /// Produces an async sequence that calls <paramref name="factory"/> each time an enumerator
        /// is requested and enumerates the sequence it returns.
        /// </summary>
        /// <param name="factory">Delegate that produces the sequence to enumerate.</param>
        /// <returns>An async sequence that postpones calling the factory until GetEnumerator.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
        {
            if (factory == null)
            {
                throw new ArgumentNullException("factory");
            }

            return Create(() =>
            {
                var deferred = factory();
                return deferred.GetEnumerator();
            });
        }

        /// <summary>
        /// Produces an async sequence by repeatedly stepping a state value, in the manner of a for loop.
        /// </summary>
        /// <param name="initialState">State used for the first element.</param>
        /// <param name="condition">Evaluated on the state before each element; false ends the step with a false result.</param>
        /// <param name="iterate">Computes the next state; applied before every step once a first element has been produced.</param>
        /// <param name="resultSelector">Maps the state to the element exposed through Current.</param>
        /// <returns>An async sequence; every enumerator starts over from <paramref name="initialState"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="condition"/>, <paramref name="iterate"/> or <paramref name="resultSelector"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TResult> Generate<TState, TResult>(TState initialState, Func<TState, bool> condition, Func<TState, TState> iterate, Func<TState, TResult> resultSelector)
        {
            if (condition == null)
            {
                throw new ArgumentNullException("condition");
            }

            if (iterate == null)
            {
                throw new ArgumentNullException("iterate");
            }

            if (resultSelector == null)
            {
                throw new ArgumentNullException("resultSelector");
            }

            return Create(() =>
            {
                // Per-enumerator state captured by the delegates below.
                var state = initialState;
                var producedFirst = false;
                var latest = default(TResult);

                Func<CancellationToken, Task<bool>> advance = token =>
                {
                    var hasNext = false;

                    try
                    {
                        // The initial state is used as-is for the first element.
                        if (producedFirst)
                        {
                            state = iterate(state);
                        }

                        hasNext = condition(state);

                        if (hasNext)
                        {
                            latest = resultSelector(state);
                        }
                    }
                    catch (Exception ex)
                    {
                        // Exceptions from the user delegates are handed to TaskExt.Throw rather than thrown from MoveNext.
                        return TaskExt.Throw<bool>(ex, token);
                    }

                    if (hasNext)
                    {
                        producedFirst = true;
                        return TaskExt.Return(true, token);
                    }

                    return TaskExt.Return(false, token);
                };

                Func<TResult> readLatest = () => latest;
                Action nothingToDispose = () => { };

                return Create(advance, readLatest, nothingToDispose);
            });
        }

        /// <summary>
        /// Produces an async sequence that obtains a resource, enumerates a sequence built from it,
        /// and disposes the resource when enumeration ends, fails, or the enumerator is disposed.
        /// </summary>
        /// <param name="resourceFactory">Delegate that creates the resource; called once per GetEnumerator call.</param>
        /// <param name="enumerableFactory">Delegate that builds the sequence to enumerate from the resource.</param>
        /// <returns>An async sequence that forwards the elements of the sequence built by <paramref name="enumerableFactory"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="resourceFactory"/> or <paramref name="enumerableFactory"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TSource> Using<TSource, TResource>(Func<TResource> resourceFactory, Func<TResource, IAsyncEnumerable<TSource>> enumerableFactory) where TResource : IDisposable
        {
            if (resourceFactory == null)
                throw new ArgumentNullException("resourceFactory");
            if (enumerableFactory == null)
                throw new ArgumentNullException("enumerableFactory");

            return Create(() =>
            {
                // Runs on every GetEnumerator call, so each enumeration acquires its own resource.
                var resource = resourceFactory();
                var e = default(IAsyncEnumerator<TSource>);

                try
                {
                    e = enumerableFactory(resource).GetEnumerator();
                }
                catch (Exception)
                {
                    // Building the inner enumerator failed; release the resource before propagating.
                    resource.Dispose();
                    throw;
                }

                // NOTE(review): CancellationTokenDisposable is declared elsewhere; presumably
                // disposing it cancels cts.Token - confirm.
                var cts = new CancellationTokenDisposable();
                // Single handle that disposes the token source, the resource and the inner enumerator.
                var d = new CompositeDisposable(cts, resource, e);

                // Last element read from the inner enumerator; exposed through Current.
                var current = default(TSource);

                return Create(
                    (ct, tcs) =>
                    {
                        // The inner enumerator observes cts.Token, not ct. The Create overload used
                        // here registers a callback on ct that disposes this enumerator, which in
                        // turn runs d.Dispose.
                        e.MoveNext(cts.Token).ContinueWith(t =>
                        {
                            // NOTE(review): Handle is a helper declared elsewhere; presumably it calls
                            // the first callback with the result on success and the second with the
                            // exception on failure, and routes cancellation to tcs - confirm.
                            t.Handle(tcs,
                                res =>
                                {
                                    if (res)
                                    {
                                        current = e.Current;
                                        tcs.TrySetResult(true);
                                    }
                                    else
                                    {
                                        // Inner sequence is exhausted: release everything before reporting the end.
                                        d.Dispose();
                                        tcs.TrySetResult(false);
                                    }
                                },
                                ex =>
                                {
                                    // Inner MoveNext failed: release everything, then surface the error.
                                    d.Dispose();
                                    tcs.TrySetException(ex);
                                }
                            );
                        });

                        return tcs.Task;
                    },
                    () => current,
                    d.Dispose
                );
            });
        }
    }
}