// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Enumerates <paramref name="source"/> and, if it fails with an exception of type
        /// <typeparamref name="TException"/>, continues with the sequence returned by <paramref name="handler"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <typeparam name="TException">Exception type to intercept.</typeparam>
        /// <param name="source">Sequence to enumerate first.</param>
        /// <param name="handler">Produces the continuation sequence for a caught exception.</param>
        /// <returns>A sequence that switches to the handler's sequence on the first matching failure.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="handler"/> is null.</exception>
        /// <remarks>
        /// Only the first failure is intercepted: once the handler's enumerator has been swapped in,
        /// its errors are forwarded without being offered to <paramref name="handler"/> again.
        /// </remarks>
        public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
            where TException : Exception
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (handler == null)
                throw new ArgumentNullException("handler");

            return Create(() =>
            {
                var e = source.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                // 'a' tracks whichever enumerator is current so the swap below is reflected in 'd'.
                var a = new AssignableDisposable { Disposable = e };
                var d = new CompositeDisposable(cts, a);
                // Set once the handler's enumerator has replaced the source enumerator.
                var done = false;

                var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
                f = (tcs, ct) =>
                {
                    if (!done)
                    {
                        e.MoveNext(ct).ContinueWith(t =>
                        {
                            // NOTE(review): Handle is defined elsewhere; presumably it routes the task's
                            // result to the first callback and its fault to the second - confirm.
                            t.Handle(tcs,
                                res =>
                                {
                                    tcs.TrySetResult(res);
                                },
                                ex =>
                                {
                                    var err = default(IAsyncEnumerator<TSource>);

                                    try
                                    {
                                        // AggregateException.Handle rethrows the inner exceptions for which
                                        // the predicate returns false; the catch below forwards those.
                                        ex.Flatten().Handle(ex_ =>
                                        {
                                            var exx = ex_ as TException;
                                            if (exx != null)
                                            {
                                                err = handler(exx).GetEnumerator();
                                                return true;
                                            }

                                            return false;
                                        });
                                    }
                                    catch (Exception ex2)
                                    {
                                        tcs.TrySetException(ex2);
                                        return;
                                    }

                                    if (err != null)
                                    {
                                        e = err;
                                        a.Disposable = e;

                                        done = true;
                                        f(tcs, ct);
                                    }
                                    else
                                    {
                                        // Nothing was offered to the handler (the aggregate carried no inner
                                        // exceptions). Fault the task so the consumer is not left waiting on
                                        // a TaskCompletionSource that would otherwise never complete.
                                        tcs.TrySetException(ex);
                                    }
                                }
                            );
                        });
                    }
                    else
                    {
                        e.MoveNext(ct).ContinueWith(t =>
                        {
                            t.Handle(tcs, res =>
                            {
                                tcs.TrySetResult(res);
                            });
                        });
                    }
                };

                return Create(
                    (ct, tcs) =>
                    {
                        f(tcs, cts.Token);
                        return tcs.Task.UsingEnumerator(a);
                    },
                    () => e.Current,
                    d.Dispose
                );
            });
        }

        /// <summary>
        /// Enumerates the given sequences in order, moving on to the next one whenever the current one fails.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Sequences to try in order.</param>
        /// <returns>A sequence that ends when a source completes, or fails with the last error once all sources have failed.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            if (sources == null)
                throw new ArgumentNullException("sources");

            return sources.Catch_();
        }

        /// <summary>
        /// Enumerates the given sequences in order, moving on to the next one whenever the current one fails.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Sequences to try in order.</param>
        /// <returns>A sequence that ends when a source completes, or fails with the last error once all sources have failed.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(params IAsyncEnumerable<TSource>[] sources)
        {
            if (sources == null)
                throw new ArgumentNullException("sources");

            return sources.Catch_();
        }

        /// <summary>
        /// Enumerates <paramref name="first"/> and, if it fails, continues with <paramref name="second"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="first">Sequence to enumerate first.</param>
        /// <param name="second">Sequence to enumerate if <paramref name="first"/> fails.</param>
        /// <returns>The combined sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException("first");
            if (second == null)
                throw new ArgumentNullException("second");

            return new[] { first, second }.Catch_();
        }

        /// <summary>
        /// Shared implementation of the multi-source Catch overloads.
        /// </summary>
        private static IAsyncEnumerable<TSource> Catch_<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            return Create(() =>
            {
                var se = sources.GetEnumerator();
                // Enumerator of the source currently being consumed; null means "advance to the next source".
                var e = default(IAsyncEnumerator<TSource>);

                var cts = new CancellationTokenDisposable();
                var a = new AssignableDisposable();
                var d = new CompositeDisposable(cts, se, a);

                // Failure of the most recently abandoned source; surfaced if no source is left to try.
                var error = default(Exception);

                var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
                f = (tcs, ct) =>
                {
                    if (e == null)
                    {
                        var b = false;
                        try
                        {
                            b = se.MoveNext();
                            if (b)
                                e = se.Current.GetEnumerator();
                        }
                        catch (Exception ex)
                        {
                            tcs.TrySetException(ex);
                            return;
                        }

                        if (!b)
                        {
                            if (error != null)
                            {
                                tcs.TrySetException(error);
                                return;
                            }

                            tcs.TrySetResult(false);
                            return;
                        }

                        // A new source was obtained, so the previous failure no longer applies.
                        error = null;
                        a.Disposable = e;
                    }

                    e.MoveNext(ct).ContinueWith(t =>
                    {
                        t.Handle(tcs,
                            res =>
                            {
                                tcs.TrySetResult(res);
                            },
                            ex =>
                            {
                                // Abandon the failed source, remember why, and try the next one.
                                e.Dispose();
                                e = null;

                                error = ex;
                                f(tcs, ct);
                            }
                        );
                    });
                };

                return Create(
                    (ct, tcs) =>
                    {
                        f(tcs, cts.Token);
                        return tcs.Task.UsingEnumerator(a);
                    },
                    () => e.Current,
                    d.Dispose
                );
            });
        }

        /// <summary>
        /// Returns a sequence that forwards <paramref name="source"/> and has <paramref name="finallyAction"/>
        /// attached to the enumerator's cleanup.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to forward.</param>
        /// <param name="finallyAction">Action wrapped in a disposable that is disposed together with the enumerator.</param>
        /// <returns>The wrapped sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="finallyAction"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Finally<TSource>(this IAsyncEnumerable<TSource> source, Action finallyAction)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (finallyAction == null)
                throw new ArgumentNullException("finallyAction");

            return Create(() =>
            {
                var e = source.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                // NOTE(review): Disposable is defined elsewhere; presumably it invokes the action on Dispose - confirm.
                var r = new Disposable(finallyAction);
                var d = new CompositeDisposable(cts, e, r);

                var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
                f = (tcs, ct) =>
                {
                    e.MoveNext(ct).ContinueWith(t =>
                    {
                        t.Handle(tcs, res =>
                        {
                            tcs.TrySetResult(res);
                        });
                    });
                };

                return Create(
                    (ct, tcs) =>
                    {
                        f(tcs, cts.Token);
                        // NOTE(review): presumably disposes 'r' when the sequence ends or faults - confirm
                        // against UsingEnumeratorSync.
                        return tcs.Task.UsingEnumeratorSync(r);
                    },
                    () => e.Current,
                    d.Dispose
                );
            });
        }

        /// <summary>
        /// Enumerates <paramref name="first"/> and then <paramref name="second"/>, moving on whether
        /// <paramref name="first"/> completed or failed.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="first">Sequence to enumerate first.</param>
        /// <param name="second">Sequence to enumerate afterwards.</param>
        /// <returns>The combined sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> OnErrorResumeNext<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException("first");
            if (second == null)
                throw new ArgumentNullException("second");

            return OnErrorResumeNext_(new[] { first, second });
        }

        /// <summary>
        /// Enumerates the given sequences one after another, moving on whether the current one completed or failed.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Sequences to enumerate in order.</param>
        /// <returns>The combined sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> OnErrorResumeNext<TSource>(params IAsyncEnumerable<TSource>[] sources)
        {
            if (sources == null)
                throw new ArgumentNullException("sources");

            return OnErrorResumeNext_(sources);
        }

        /// <summary>
        /// Enumerates the given sequences one after another, moving on whether the current one completed or failed.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Sequences to enumerate in order.</param>
        /// <returns>The combined sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> OnErrorResumeNext<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            if (sources == null)
                throw new ArgumentNullException("sources");

            return OnErrorResumeNext_(sources);
        }

        /// <summary>
        /// Shared implementation of the OnErrorResumeNext overloads.
        /// </summary>
        private static IAsyncEnumerable<TSource> OnErrorResumeNext_<TSource>(IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            return Create(() =>
            {
                var se = sources.GetEnumerator();
                // Enumerator of the source currently being consumed; null means "advance to the next source".
                var e = default(IAsyncEnumerator<TSource>);

                var cts = new CancellationTokenDisposable();
                var a = new AssignableDisposable();
                var d = new CompositeDisposable(cts, se, a);

                var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
                f = (tcs, ct) =>
                {
                    if (e == null)
                    {
                        var b = false;
                        try
                        {
                            b = se.MoveNext();
                            if (b)
                                e = se.Current.GetEnumerator();
                        }
                        catch (Exception ex)
                        {
                            tcs.TrySetException(ex);
                            return;
                        }

                        if (!b)
                        {
                            // All sources consumed: the combined sequence ends normally,
                            // even if the last source failed.
                            tcs.TrySetResult(false);
                            return;
                        }

                        a.Disposable = e;
                    }

                    e.MoveNext(ct).ContinueWith(t =>
                    {
                        t.Handle(tcs,
                            res =>
                            {
                                if (res)
                                {
                                    tcs.TrySetResult(true);
                                }
                                else
                                {
                                    // Current source completed; continue with the next one.
                                    e.Dispose();
                                    e = null;

                                    f(tcs, ct);
                                }
                            },
                            ex =>
                            {
                                // Current source failed; the error is dropped and the next source is tried.
                                e.Dispose();
                                e = null;

                                f(tcs, ct);
                            }
                        );
                    });
                };

                return Create(
                    (ct, tcs) =>
                    {
                        f(tcs, cts.Token);
                        return tcs.Task.UsingEnumerator(a);
                    },
                    () => e.Current,
                    d.Dispose
                );
            });
        }

        /// <summary>
        /// Re-enumerates <paramref name="source"/> each time it fails, without limit, until it completes.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to retry.</param>
        /// <returns>The retrying sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Retry<TSource>(this IAsyncEnumerable<TSource> source)
        {
            if (source == null)
                throw new ArgumentNullException("source");

            return new[] { source }.Repeat().Catch();
        }

        /// <summary>
        /// Enumerates <paramref name="source"/> up to <paramref name="retryCount"/> times in total,
        /// starting over each time it fails.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to retry.</param>
        /// <param name="retryCount">
        /// Number of times <paramref name="source"/> is offered for enumeration; with 0 the source is never
        /// enumerated and the result is empty.
        /// </param>
        /// <returns>The retrying sequence; it fails with the last error if every attempt fails.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="retryCount"/> is negative.</exception>
        public static IAsyncEnumerable<TSource> Retry<TSource>(this IAsyncEnumerable<TSource> source, int retryCount)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (retryCount < 0)
                throw new ArgumentOutOfRangeException("retryCount");

            return new[] { source }.Repeat(retryCount).Catch();
        }

        /// <summary>
        /// Lazily yields the items of <paramref name="source"/> over and over, without end.
        /// </summary>
        private static IEnumerable<TSource> Repeat<TSource>(this IEnumerable<TSource> source)
        {
            while (true)
                foreach (var item in source)
                    yield return item;
        }

        /// <summary>
        /// Lazily yields the items of <paramref name="source"/> <paramref name="count"/> times in a row.
        /// </summary>
        private static IEnumerable<TSource> Repeat<TSource>(this IEnumerable<TSource> source, int count)
        {
            for (var i = 0; i < count; i++)
                foreach (var item in source)
                    yield return item;
        }
    }
}