// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Creates an async sequence whose <c>GetEnumerator</c> invokes
        /// <paramref name="getEnumerator"/> afresh on every call.
        /// </summary>
        private static IAsyncEnumerable<T> Create<T>(Func<IAsyncEnumerator<T>> getEnumerator)
        {
            return new AnonymousAsyncEnumerable<T>(getEnumerator);
        }

        /// <summary>
        /// Delegate-backed <see cref="IAsyncEnumerable{T}"/>: each enumeration calls the stored factory.
        /// </summary>
        private class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
        {
            private readonly Func<IAsyncEnumerator<T>> _getEnumerator;

            public AnonymousAsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
            {
                _getEnumerator = getEnumerator;
            }

            public IAsyncEnumerator<T> GetEnumerator()
            {
                return _getEnumerator();
            }
        }

        /// <summary>
        /// Creates an enumerator from the delegates that implement its members.
        /// </summary>
        /// <param name="moveNext">Invoked by <c>MoveNext</c> while the enumerator is not disposed.</param>
        /// <param name="current">Invoked on every read of <c>Current</c>.</param>
        /// <param name="dispose">Invoked at most once, by the first <c>Dispose</c> call.</param>
        private static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            return new AnonymousAsyncEnumerator<T>(moveNext, current, dispose);
        }

        /// <summary>
        /// Creates an enumerator whose <paramref name="moveNext"/> is handed a fresh
        /// <see cref="TaskCompletionSource{TResult}"/> on every <c>MoveNext</c> call.
        /// </summary>
        /// <remarks>
        /// Cancelling the token passed to <c>MoveNext</c> disposes the whole enumerator and moves the
        /// pending completion source to the Canceled state. The cancellation registration is handed to
        /// the <c>Finally</c> extension (defined elsewhere) to be released; presumably this happens
        /// once the task returned by <paramref name="moveNext"/> completes.
        /// </remarks>
        private static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            // Declared before assignment so the cancellation callback below can refer to the
            // enumerator under construction (the callback only runs after construction completes).
            var self = default(IAsyncEnumerator<T>);
            self = new AnonymousAsyncEnumerator<T>(
                ct =>
                {
                    var tcs = new TaskCompletionSource<bool>();

                    // Runs on whichever thread cancels ct, or synchronously inside Register
                    // if ct is already cancelled.
                    var stop = new Action(() =>
                    {
                        self.Dispose();
                        tcs.TrySetCanceled();
                    });

                    var ctr = ct.Register(stop);

                    return moveNext(ct, tcs).Finally(ctr.Dispose);
                },
                current,
                dispose
            );
            return self;
        }

        /// <summary>
        /// Delegate-backed <see cref="IAsyncEnumerator{T}"/>. After disposal, <c>MoveNext</c> reports
        /// <c>false</c> without invoking the move-next delegate.
        /// </summary>
        private class AnonymousAsyncEnumerator<T> : IAsyncEnumerator<T>
        {
            private readonly Func<CancellationToken, Task<bool>> _moveNext;
            private readonly Func<T> _current;
            private readonly Action _dispose;

            // 0 = live, 1 = disposed. Kept as an int so Dispose can flip it atomically: the
            // cancellation callback installed by Create may call Dispose on another thread while
            // the consumer is disposing too, and the dispose action must run only once.
            private int _disposed;

            public AnonymousAsyncEnumerator(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
            {
                _moveNext = moveNext;
                _current = current;
                _dispose = dispose;
            }

            public Task<bool> MoveNext(CancellationToken cancellationToken)
            {
                if (_disposed != 0)
                    return TaskExt.Return(false, CancellationToken.None);

                return _moveNext(cancellationToken);
            }

            public T Current
            {
                get { return _current(); }
            }

            public void Dispose()
            {
                // Only the caller that transitions the flag from 0 to 1 runs the dispose action.
                if (Interlocked.Exchange(ref _disposed, 1) == 0)
                    _dispose();
            }
        }

        /// <summary>
        /// Returns an async sequence containing <paramref name="value"/> as its only element.
        /// </summary>
        public static IAsyncEnumerable<TValue> Return<TValue>(TValue value)
        {
            return new[] { value }.ToAsyncEnumerable();
        }

        /// <summary>
        /// Returns an async sequence whose every <c>MoveNext</c> yields the task produced by
        /// <c>TaskExt.Throw</c> for <paramref name="exception"/> (presumably a task faulted with it).
        /// Reading <c>Current</c> throws <see cref="InvalidOperationException"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
        public static IAsyncEnumerable<TValue> Throw<TValue>(Exception exception)
        {
            if (exception == null)
                throw new ArgumentNullException("exception");

            return Create(() => Create<TValue>(
                ct => TaskExt.Throw<bool>(exception, ct),
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns an async sequence whose <c>MoveNext</c> task is never completed by the sequence
        /// itself. Only cancelling the token passed to <c>MoveNext</c> ends it, which cancels the
        /// task and disposes the enumerator.
        /// Reading <c>Current</c> throws <see cref="InvalidOperationException"/>.
        /// </summary>
        public static IAsyncEnumerable<TValue> Never<TValue>()
        {
            return Create(() => Create<TValue>(
                (ct, tcs) => tcs.Task,
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns an async sequence whose <c>MoveNext</c> always reports <c>false</c>.
        /// Reading <c>Current</c> throws <see cref="InvalidOperationException"/>.
        /// </summary>
        public static IAsyncEnumerable<TValue> Empty<TValue>()
        {
            return Create(() => Create<TValue>(
                ct => TaskExt.Return(false, ct),
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns the integers produced by <see cref="Enumerable.Range"/> for
        /// (<paramref name="start"/>, <paramref name="count"/>) as an async sequence.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IAsyncEnumerable<int> Range(int start, int count)
        {
            if (count < 0)
                throw new ArgumentOutOfRangeException("count");

            return Enumerable.Range(start, count).ToAsyncEnumerable();
        }

        /// <summary>
        /// Returns an async sequence that yields <paramref name="element"/> exactly
        /// <paramref name="count"/> times.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element, int count)
        {
            if (count < 0)
                throw new ArgumentOutOfRangeException("count");

            return Enumerable.Repeat(element, count).ToAsyncEnumerable();
        }

        /// <summary>
        /// Returns an endless async sequence: every <c>MoveNext</c> reports <c>true</c> and
        /// <c>Current</c> is always <paramref name="element"/>.
        /// </summary>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element)
        {
            return Create(() =>
            {
                return Create(
                    ct => TaskExt.Return(true, ct),
                    () => element,
                    () => { }
                );
            });
        }

        /// <summary>
        /// Returns an async sequence that calls <paramref name="factory"/> on every enumeration
        /// and enumerates the sequence it returns.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
        {
            if (factory == null)
                throw new ArgumentNullException("factory");

            return Create(() => factory().GetEnumerator());
        }

        /// <summary>
        /// Generates an async sequence by iterating a state value. The sequence starts at
        /// <paramref name="initialState"/> and continues while <paramref name="condition"/> holds.
        /// It advances with <paramref name="iterate"/> and projects each state through
        /// <paramref name="resultSelector"/>.
        /// </summary>
        /// <remarks>
        /// Once the condition fails, or a delegate throws, the enumerator stays finished: later
        /// <c>MoveNext</c> calls report <c>false</c> without invoking any delegate. Each enumeration
        /// starts over from <paramref name="initialState"/>.
        /// </remarks>
        /// <exception cref="ArgumentNullException">Any delegate argument is null.</exception>
        public static IAsyncEnumerable<TResult> Generate<TState, TResult>(TState initialState, Func<TState, bool> condition, Func<TState, TState> iterate, Func<TState, TResult> resultSelector)
        {
            if (condition == null)
                throw new ArgumentNullException("condition");
            if (iterate == null)
                throw new ArgumentNullException("iterate");
            if (resultSelector == null)
                throw new ArgumentNullException("resultSelector");

            return Create(() =>
            {
                // Per-enumeration state.
                var i = initialState;
                var started = false;
                var finished = false;
                var current = default(TResult);

                return Create(
                    ct =>
                    {
                        // Without this guard, a MoveNext after the end would call iterate again
                        // and could resume yielding elements after having reported the end.
                        if (finished)
                            return TaskExt.Return(false, ct);

                        var b = false;
                        try
                        {
                            // The first call tests initialState itself; later calls advance first.
                            if (started)
                                i = iterate(i);

                            b = condition(i);

                            if (b)
                                current = resultSelector(i);
                        }
                        catch (Exception ex)
                        {
                            finished = true;
                            return TaskExt.Throw<bool>(ex, ct);
                        }

                        if (!b)
                        {
                            finished = true;
                            return TaskExt.Return(false, ct);
                        }

                        started = true;
                        return TaskExt.Return(true, ct);
                    },
                    () => current,
                    () => { }
                );
            });
        }

        /// <summary>
        /// Returns an async sequence that creates a resource per enumeration and enumerates the
        /// sequence built from it. The resource is disposed when the inner sequence ends, faults,
        /// or the enumerator is disposed.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either factory argument is null.</exception>
        public static IAsyncEnumerable<TSource> Using<TSource, TResource>(Func<TResource> resourceFactory, Func<TResource, IAsyncEnumerable<TSource>> enumerableFactory) where TResource : IDisposable
        {
            if (resourceFactory == null)
                throw new ArgumentNullException("resourceFactory");
            if (enumerableFactory == null)
                throw new ArgumentNullException("enumerableFactory");

            return Create(() =>
            {
                var resource = resourceFactory();
                var e = default(IAsyncEnumerator<TSource>);

                // If obtaining the inner enumerator fails, nothing else will ever own the
                // resource, so release it before propagating the exception.
                try
                {
                    e = enumerableFactory(resource).GetEnumerator();
                }
                catch (Exception)
                {
                    resource.Dispose();
                    throw;
                }

                // d bundles the token source driving the inner MoveNext calls, the resource and the
                // inner enumerator. NOTE(review): d.Dispose may run more than once (at end of sequence
                // and again from the outer Dispose); this assumes CompositeDisposable tolerates that,
                // and that disposing CancellationTokenDisposable cancels its Token — confirm.
                var cts = new CancellationTokenDisposable();
                var d = new CompositeDisposable(cts, resource, e);

                var current = default(TSource);

                return Create(
                    (ct, tcs) =>
                    {
                        // The inner enumerator is driven with cts.Token, not ct. Cancelling ct reaches it
                        // through the outer enumerator's Dispose, whose dispose action is d.Dispose.
                        e.MoveNext(cts.Token).ContinueWith(t =>
                        {
                            // Handle is defined elsewhere; presumably it routes the antecedent's result
                            // to the first callback and its fault to the second.
                            t.Handle(tcs,
                                res =>
                                {
                                    if (res)
                                    {
                                        current = e.Current;
                                        tcs.TrySetResult(true);
                                    }
                                    else
                                    {
                                        // End of the inner sequence: release everything eagerly.
                                        d.Dispose();
                                        tcs.TrySetResult(false);
                                    }
                                },
                                ex =>
                                {
                                    d.Dispose();
                                    tcs.TrySetException(ex);
                                }
                            );
                        });

                        return tcs.Task;
                    },
                    () => current,
                    d.Dispose
                );
            });
        }
    }
}