@@ -23,102 +23,76 @@ namespace fft
2323{
2424
2525// ---------------------------------- Utils -----------------------------------------------
26- template<typename SharedMemoryAdaptor, typename Scalar>
27- struct exchangeValues
26+
27+ // No need to expose these
28+ namespace impl
2829{
29- static void __call (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
30+ template<typename SharedMemoryAdaptor, typename Scalar>
31+ struct exchangeValues
3032 {
31- const bool topHalf = bool (threadID & stride);
32- // Pack into float vector because ternary operator does not support structs
33- vector <Scalar, 2 > exchanged = topHalf ? vector <Scalar, 2 >(lo.real (), lo.imag ()) : vector <Scalar, 2 >(hi.real (), hi.imag ());
34- shuffleXor<SharedMemoryAdaptor, vector <Scalar, 2 > >(exchanged, stride, sharedmemAdaptor);
35- if (topHalf)
36- {
37- lo.real (exchanged.x);
38- lo.imag (exchanged.y);
39- }
40- else
33+ static void __call (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
4134 {
42- hi.real (exchanged.x);
43- hi.imag (exchanged.y);
35+ const bool topHalf = bool (threadID & stride);
36+ // Pack into float vector because ternary operator does not support structs
37+ vector <Scalar, 2 > exchanged = topHalf ? vector <Scalar, 2 >(lo.real (), lo.imag ()) : vector <Scalar, 2 >(hi.real (), hi.imag ());
38+ shuffleXor<SharedMemoryAdaptor, vector <Scalar, 2 > >(exchanged, stride, sharedmemAdaptor);
39+ if (topHalf)
40+ {
41+ lo.real (exchanged.x);
42+ lo.imag (exchanged.y);
43+ }
44+ else
45+ {
46+ hi.real (exchanged.x);
47+ hi.imag (exchanged.y);
48+ }
4449 }
45- }
46- };
47-
48- // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
49- template <typename scalar_t, uint32_t WorkgroupSize>
50- NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) * WorkgroupSize;
51-
50+ };
5251
53- template<uint32_t N, uint32_t H>
54- enable_if_t<H <= N, uint32_t> bitShiftRightHigher (uint32_t i)
55- {
56- // Highest H bits are numbered N-1 through N - H
57- // N - H is then the middle bit
58- // Lowest bits numbered from 0 through N - H - 1
59- uint32_t low = i & ((1 << (N - H)) - 1 );
60- uint32_t mid = i & (1 << (N - H));
61- uint32_t high = i & ~((1 << (N - H + 1 )) - 1 );
62-
63- high >>= 1 ;
64- mid <<= H - 1 ;
65-
66- return mid | high | low;
67- }
52+ template<uint16_t N, uint16_t H>
53+ enable_if_t<(H <= N) && (N < 32 ), uint32_t> circularBitShiftRightHigher (uint32_t i)
54+ {
55+ // Highest H bits are numbered N-1 through N - H
56+ // N - H is then the middle bit
57+ // Lowest bits numbered from 0 through N - H - 1
58+ NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1 ;
59+ NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
60+ NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~((1 << (N - H + 1 )) - 1 );
6861
69- template<uint32_t N, uint32_t H>
70- enable_if_t<H <= N, uint32_t> bitShiftLeftHigher (uint32_t i)
71- {
72- // Highest H bits are numbered N-1 through N - H
73- // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
74- // Lowest bits numbered from 0 through N - H - 1
75- uint32_t low = i & ((1 << (N - H)) - 1 );
76- uint32_t mid = i & (~((1 << (N - H)) - 1 ) | ~(1 << (N - 1 )));
77- uint32_t high = i & (1 << (N - 1 ));
62+ uint32_t low = i & lowMask;
63+ uint32_t mid = i & midMask;
64+ uint32_t high = i & highMask;
7865
79- mid << = 1 ;
80- high >> = H - 1 ;
66+ high >> = 1 ;
67+ mid << = H - 1 ;
8168
82- return mid | high | low;
83- }
69+ return mid | high | low;
70+ }
8471
85- // This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
86- // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
87- template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
88- uint32_t getFrequencyIndex (uint32_t outputIdx)
89- {
90- NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
91- NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
72+ template<uint16_t N, uint16_t H>
73+ enable_if_t<(H <= N) && (N < 32 ), uint32_t> circularBitShiftLeftHigher (uint32_t i)
74+ {
75+ // Highest H bits are numbered N-1 through N - H
76+ // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
77+ // Lowest bits numbered from 0 through N - H - 1
78+ NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1 ;
79+ NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~((1 << (N - H)) - 1 ) | ~(1 << (N - 1 ));
80+ NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1 );
9281
93- return bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFT_SIZE_LOG_2));
94- }
82+ uint32_t low = i & lowMask;
83+ uint32_t mid = i & midMask;
84+ uint32_t high = i & highMask;
9585
96- // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
97- // It is essentially the inverse of `getFrequencyIndex`
98- template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
99- uint32_t getOutputIndex (uint32_t freqIdx)
100- {
101- NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
102- NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
86+ mid <<= 1 ;
87+ high >>= H - 1 ;
10388
104- return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(freqIdx)) >> (32 - FFT_SIZE_LOG_2);
105- }
106-
107- // Mirrors an index about the Nyquist frequency
108- template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
109- uint32_t mirror (uint32_t idx)
110- {
111- NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE = WorkgroupSize * uint32_t (ElementsPerInvocation);
112- return (FFT_SIZE - idx) & (FFT_SIZE - 1 );
113- }
89+ return mid | high | low;
90+ }
91+ } //namespace impl
11492
115- // When packing real FFTs a common operation is to get `DFT[T]` and `DFT[-T]` to unpack the result of a packed real FFT.
116- // Given an index `idx` into the Nabla-ordered DFT such that `output[idx] = DFT[T]`, this function is such that `output[getNegativeIndex(idx)] = DFT[-T]`
117- template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
118- uint32_t getNegativeIndex (uint32_t idx)
119- {
120- return getOutputIndex<ElementsPerInvocation, WorkgroupSize>(mirror <ElementsPerInvocation, WorkgroupSize>(getFrequencyIndex<ElementsPerInvocation, WorkgroupSize>(idx)));
121- }
93+ // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
94+ template <typename scalar_t, uint16_t WorkgroupSize>
95+ NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) * WorkgroupSize;
12296
12397// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
12498template<typename Scalar>
@@ -129,11 +103,45 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
129103 lo = x;
130104}
131105
106+ template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
107+ struct FFTIndexingUtils
108+ {
109+ // This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
110+ // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
111+ static uint32_t getDFTIndex (uint32_t outputIdx)
112+ {
113+ return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1 >(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
114+ }
115+
116+ // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
117+ // It is essentially the inverse of `getDFTIndex`
118+ static uint32_t getNablaIndex (uint32_t freqIdx)
119+ {
120+ return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1 >(freqIdx)) >> (32 - FFTSizeLog2);
121+ }
122+
123+ // Mirrors an index about the Nyquist frequency in the DFT order
124+ static uint32_t getDFTMirrorIndex (uint32_t idx)
125+ {
126+ return (FFTSize - idx) & (FFTSize - 1 );
127+ }
128+
129+ // Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
130+ static uint32_t getNablaMirrorIndex (uint32_t idx)
131+ {
132+ return getNablaIndex (getDFTMirrorIndex (getDFTIndex (idx)));
133+ }
134+
135+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
136+ NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
137+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t (WorkgroupSize) * uint32_t (ElementsPerInvocation);
138+ };
139+
132140} //namespace fft
133141
134142// ----------------------------------- End Utils -----------------------------------------------
135143
136- template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void >
144+ template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void >
137145struct FFT;
138146
139147// For the FFT methods below, we assume:
@@ -153,13 +161,13 @@ struct FFT;
153161// * void workgroupExecutionAndMemoryBarrier();
154162
155163// 2 items per invocation forward specialization
156- template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
164+ template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
157165struct FFT<2 ,false , WorkgroupSize, Scalar, device_capabilities>
158166{
159167 template<typename SharedMemoryAdaptor>
160168 static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
161169 {
162- fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call (lo, hi, threadID, stride, sharedmemAdaptor);
170+ fft::impl:: exchangeValues<SharedMemoryAdaptor, Scalar>::__call (lo, hi, threadID, stride, sharedmemAdaptor);
163171
164172 // Get twiddle with k = threadID mod stride, halfN = stride
165173 hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<false , Scalar>(threadID & (stride - 1 ), stride), lo, hi);
@@ -199,7 +207,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
199207 }
200208
201209 // special last workgroup-shuffle
202- fft::exchangeValues<adaptor_t, Scalar>::__call (lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor);
210+ fft::impl:: exchangeValues<adaptor_t, Scalar>::__call (lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor);
203211
204212 // Remember to update the accessor's state
205213 sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -217,7 +225,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
217225
218226
219227// 2 items per invocation inverse specialization
220- template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
228+ template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
221229struct FFT<2 ,true , WorkgroupSize, Scalar, device_capabilities>
222230{
223231 template<typename SharedMemoryAdaptor>
@@ -226,7 +234,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
226234 // Get twiddle with k = threadID mod stride, halfN = stride
227235 hlsl::fft::DIT<Scalar>::radix2 (hlsl::fft::twiddle<true , Scalar>(threadID & (stride - 1 ), stride), lo, hi);
228236
229- fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call (lo, hi, threadID, stride, sharedmemAdaptor);
237+ fft::impl:: exchangeValues<SharedMemoryAdaptor, Scalar>::__call (lo, hi, threadID, stride, sharedmemAdaptor);
230238 }
231239
232240
@@ -255,7 +263,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
255263 sharedmemAdaptor.accessor = sharedmemAccessor;
256264
257265 // special first workgroup-shuffle
258- fft::exchangeValues<adaptor_t, Scalar>::__call (lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor);
266+ fft::impl:: exchangeValues<adaptor_t, Scalar>::__call (lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor);
259267
260268 // The bigger steps
261269 [unroll]
@@ -283,7 +291,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
283291};
284292
285293// Forward FFT
286- template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
294+ template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
287295struct FFT<K, false , WorkgroupSize, Scalar, device_capabilities>
288296{
289297 template<typename Accessor, typename SharedMemoryAccessor>
@@ -326,7 +334,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
326334};
327335
328336// Inverse FFT
329- template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
337+ template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
330338struct FFT<K, true , WorkgroupSize, Scalar, device_capabilities>
331339{
332340 template<typename Accessor, typename SharedMemoryAccessor>
0 commit comments