// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

/**
 * ThreadGroupPrefixSum(Value, GroupThreadIndex [, GroupSumOut])
 *
 * Calculates the exclusive prefix sum (and optionally the group sum) of a given value across
 * the threads of a thread group.
 *
 * EXAMPLE USES:
 *		float GroupSum;
 *		float PrefixSum = ThreadGroupPrefixSum(ValueFloat, GroupThreadIndex, GroupSum);
 *
 *		int PrefixSum = ThreadGroupPrefixSum(ValueInt, GroupThreadIndex);
 *
 * PARAMETERS:
 *	- Value:            this thread's contribution to the sum.
 *	- GroupThreadIndex: flat index of the calling thread within its group; must be in
 *	                    [0, NUM_THREADS_PER_GROUP).
 *	- GroupSum:         (optional) receives the sum of Value over every thread in the group.
 *	                    The incoming value is ignored.
 *
 * RETURNS:
 *	The sum of Value for all threads whose group thread index is LESS than the current one
 *	(thread 0 therefore receives 0).
 *
 * NOTES:
 *	- Only scalar types are currently supported (float, int, uint).
 *	- NUM_THREADS_PER_GROUP does not have to be a power of two.
 *	- All threads in the group must be active when calling this method; it cannot be used in a
 *	  branch (or after an early return) that doesn't include the entire group.
 *	- The function synchronizes the group once more after reading its results, so it may be called
 *	  several times in a row (with the same or different overloads) without the caller having to
 *	  insert a barrier in between.
 */

#ifndef NUM_THREADS_PER_GROUP
#error NUM_THREADS_PER_GROUP must be defined, and must be equal to the thread group size of the caller of ThreadGroupPrefixSum
#endif

// Re-use the same groupshared memory, in case the caller utilizes multiple overloads.
// Two rows are used as ping-pong buffers: each pass reads one row and writes the other.
// Values of every supported type are stored as raw uints (see the cast arguments of the macro below).
groupshared uint ThreadGroupPrefixSumWorkspace[2][NUM_THREADS_PER_GROUP];

/**
 * Generates both ThreadGroupPrefixSum overloads for one scalar type.
 *
 *	ValType      - scalar type of the values being summed
 *	CastToUint   - expression/function converting ValType to the uint stored in the workspace
 *	CastFromUint - inverse of CastToUint
 *
 * Implementation outline (the macro body is kept free of comments on purpose, so that line
 * continuations can never be swallowed by one):
 *	1. Every thread publishes its value into row 0, followed by a group sync.
 *	2. For stride i = 1, 2, 4, ... while i < NUM_THREADS_PER_GROUP: each thread writes into the other
 *	   row either its current partial sum (index < i) or its partial sum plus the one i slots below.
 *	   The rows are then swapped and the group is synchronized. After the last pass each slot holds
 *	   the inclusive sum up to and including that thread.
 *	3. The group sum is the last slot; the exclusive prefix sum is the slot one below the caller's.
 *	   The read index is clamped so thread 0 never addresses slot -1, even if the compiler evaluates
 *	   both arms of the select.
 *	4. A final group sync makes sure every thread has finished reading before any thread can start
 *	   overwriting the workspace in a subsequent call.
 */
#define DECLARE_THREAD_GROUP_PREFIX_SUM(ValType, CastToUint, CastFromUint) \
	ValType ThreadGroupPrefixSum(ValType Value, uint GroupThreadIndex, inout ValType GroupSum) \
	{ \
		uint Curr = 0; \
		ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex] = CastToUint(Value); \
		GroupMemoryBarrierWithGroupSync(); \
		for (uint i = 1U; i < uint(NUM_THREADS_PER_GROUP); i *= 2U) \
		{ \
			const uint Next = 1U - Curr; \
			if (GroupThreadIndex < i) \
			{ \
				ThreadGroupPrefixSumWorkspace[Next][GroupThreadIndex] = \
					ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex]; \
			} \
			else \
			{ \
				ThreadGroupPrefixSumWorkspace[Next][GroupThreadIndex] = CastToUint( \
					CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex]) + \
					CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex - i]) \
				); \
			} \
			Curr = Next; \
			GroupMemoryBarrierWithGroupSync(); \
		} \
		const ValType InclusiveTotal = CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][uint(NUM_THREADS_PER_GROUP) - 1U]); \
		const uint PrevIndex = max(GroupThreadIndex, 1U) - 1U; \
		const ValType PrevInclusive = CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][PrevIndex]); \
		GroupMemoryBarrierWithGroupSync(); \
		GroupSum = InclusiveTotal; \
		return (GroupThreadIndex == 0U) ? CastFromUint(0) : PrevInclusive; \
	} \
	ValType ThreadGroupPrefixSum(ValType Value, uint GroupThreadIndex) \
	{ \
		ValType GroupSum = CastFromUint(0); \
		return ThreadGroupPrefixSum(Value, GroupThreadIndex, GroupSum); \
	}

DECLARE_THREAD_GROUP_PREFIX_SUM(float, asuint, asfloat)
DECLARE_THREAD_GROUP_PREFIX_SUM(int, uint, int)
DECLARE_THREAD_GROUP_PREFIX_SUM(uint, uint, uint)

#undef DECLARE_THREAD_GROUP_PREFIX_SUM