wiprog

C# とか,数学とか,社会とか.

C# で同時実行数制御つき ForEachAsync

2020/03/26 Qiita から移行. 現在下記の実装には例外ハンドリングの不具合が確認されていますので使用は推奨しません

同時実行数を抑えながら非同期を走らせる方法として、下記の記事で ForEachAsync 拡張メソッドが紹介されています。 ForEachAsync - 非同期の列挙の方法 Part2

で、これを参考に別の実装をしてみました。 (使い方等は上記のブログ記事を御覧ください)

public static Task ForEachAsync<T>(this IEnumerable<T> source, Func<T, Task> asyncAction, int concurrency,
    CancellationToken cancellationToken = default)
{
    source.ThrowIfArgumentNull(nameof(source));
    asyncAction.ThrowIfArgumentNull(nameof(asyncAction));
    concurrency.ThrowIfArgumentOutOfRange(1, int.MaxValue, nameof(concurrency));

    async Task ForEachInner()
    {
        int throwedCount = 0;

        void OnFault(Exception e)
        {
            Interlocked.Add(ref throwedCount, 1);
            throw e;
        }

        using (var tasks = new TaskSet(concurrency, OnFault))
        {
            foreach (var x in source)
            {
                if (throwedCount > 0) break;
                cancellationToken.ThrowIfCancellationRequested();

                await tasks.AddAsync(x, asyncAction).ConfigureAwait(false);
            }

            await tasks.WhenAll().ConfigureAwait(false);
        }
    }

    return ForEachInner();
}

private sealed class TaskSet : IDisposable
{
    private readonly Task[] _tasks;
    private readonly ConcurrentStack<int> _unusedIndexes;
    private readonly Action<Exception> _faultedAction;
    private readonly SemaphoreSlim _semaphore;

    public TaskSet(int concurrency, Action<Exception> faulted)
    {
        _tasks = new Task[concurrency];
        _unusedIndexes = new ConcurrentStack<int>(Enumerable.Range(0, concurrency));
        _faultedAction = faulted;
        _semaphore = new SemaphoreSlim(concurrency, concurrency);
    }

    public async Task AddAsync<T>(T arg, Func<T, Task> asyncAction)
    {
        await _semaphore.WaitAsync().ConfigureAwait(false);
        if (!_unusedIndexes.TryPop(out int index)) throw new Exception();

        var task = asyncAction(arg).ContinueWith(t =>
        {
            _unusedIndexes.Push(index);
            _semaphore.Release();

            if (t.IsFaulted)
            {
                _faultedAction(t.Exception);
            }
        });

        _tasks[index] = task;
    }

    public Task WhenAll() => Task.WhenAll(_tasks.Where(t => t != null));

    void IDisposable.Dispose() => _semaphore.Dispose();
}

この実装と冒頭のブログとの違いは、生成した Task の管理に TaskSet クラスを導入したことです。
これによって、source の長さ分まで大きくなり続ける List<Task> を使うことなく、長さが concurrency で固定の配列 Task[] を使うことができ、 List.Add 時の拡張等のコストが削減できます。

コードは GitHub にあります。 https://github.com/wipiano/cisis/blob/master/Cisis/Linq/ForEachAsync.cs https://github.com/wipiano/cisis/blob/master/Cisis.Test/Linq/ForEachAsyncTest.cs