MLX
 
Loading...
Searching...
No Matches
fft.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3// Metal FFT using Stockham's algorithm
4//
5// References:
6// - VkFFT (https://github.com/DTolm/VkFFT)
7// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
8
9#include <metal_common>
10
14
15using namespace metal;
16
// MAX_RADIX: largest radix codelet with a step constant (radix 13). It sizes
// the per-thread `inputs` scratch in perform_fft and the `temp` scratch in
// rader_fft; `temp` is indexed up to elems_per_thread_, so this assumes
// elems_per_thread_ <= 13. NOTE(review): presumably guaranteed by the host.
17#define MAX_RADIX 13
18// Upper bound on max_radices_per_thread * radix (see radix_n_steps), which sizes
19// the per-thread `indices`/`values` buffers. E.g. elems_per_thread_ = 13 with a
// radix-6 step gives (13 + 6 - 1) / 6 = 3 radix-6 butterflies = 18 float2s.
// NOTE(review): the formula admits larger products (elems_per_thread_ = 13 with
// a radix-11 step would need 2 * 11 = 22); this presumably relies on the host
// never choosing such a combination — confirm.
20#define MAX_OUTPUT_SIZE 18
21
22// Specialize for a particular value of N at runtime
// These are Metal function constants, fixed when the compute pipeline is
// specialized. inv_ is only forwarded to the ReadWriter constructors below
// (NOTE(review): presumably selects the inverse transform — see readwrite.h).
// is_power_of_2_ enables the bitwise index math in radix_butterfly.
// elems_per_thread_ is the number of elements each thread owns.
23STEEL_CONST bool inv_ [[function_constant(0)]];
24STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
25STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
26// rader_m = n / rader_n
27STEEL_CONST int rader_m_ [[function_constant(3)]];
28// Stockham steps
// radix_R_steps_: number of radix-R Stockham passes in the main transform.
29STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
30STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
31STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
32STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
33STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
34STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
35STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
36STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
37STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
38// Rader steps
// rader_R_steps_: number of radix-R passes in the length-(rader_n - 1)
// transforms run inside rader_fft via perform_fft</*rader=*/true>.
39STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
40STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
41STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
42STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
43STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
44STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
45STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
46STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
47STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
48
49// See "radix.h" for radix codelets
// radix_butterfly passes the twiddled inputs as the first argument and
// collects the `radix` outputs from the second (radix_func(x, y)).
50typedef void (*RadixFunc)(thread float2*, thread float2*);
51
52// Perform a single radix n butterfly with appropriate twiddles
//
// The caller (radix_n_steps) has already loaded the `radix` inputs into x.
// This applies the twiddle factors to x in place, runs radix_func(x, y), and
// stores in indices[t] the position in the DFT buffer where y[t] belongs.
53template <int radix, RadixFunc radix_func>
54METAL_FUNC void radix_butterfly(
55 int i,
56 int p,
57 thread float2* x,
58 thread short* indices,
59 thread float2* y) {
60 // i: which butterfly of this pass we're processing (0 <= i < n / radix).
61 // p: the size of the DFTs we're merging at this step.
62 // x: `radix` inputs (twiddled in place); y: `radix` outputs;
// indices: output positions for y. (The thread count m lives in radix_n_steps.)
63 int k, j;
64
// k: position inside a size-p sub-DFT; j: destination of output 0. Outputs go
// to j, j + p, ..., j + (radix - 1) * p (see the final loop).
65 // Use faster bitwise operations when working with powers of two
// When p is a power of two (assumed whenever is_power_of_2_ is set),
// i & (p - 1) == i % p and (i - k) << log2(radix) == (i / p) * p * radix.
66 constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
67 if (radix_p_2 && is_power_of_2_) {
68 constexpr short power = __builtin_ctz(radix);
69 k = i & (p - 1);
70 j = ((i - k) << power) + k;
71 } else {
72 k = i % p;
73 j = (i / p) * radix * p + k;
74 }
75
76 // Apply twiddles
// x[t] *= w^t for t >= 1, where w = get_twiddle(k, radix * p) and the powers
// are built by repeated multiplication. Skipped on the first pass (p == 1),
// where k = i % 1 = 0.
77 if (p > 1) {
78 float2 twiddle_1 = get_twiddle(k, radix * p);
79 float2 twiddle = twiddle_1;
80 x[1] = complex_mul(x[1], twiddle);
81
// NOTE(review): source line 82 is missing from this listing (presumably a loop
// pragma such as STEEL_PRAGMA_UNROLL) — verify against fft.h.
83 for (int t = 2; t < radix; t++) {
84 twiddle = complex_mul(twiddle, twiddle_1);
85 x[t] = complex_mul(x[t], twiddle);
86 }
87 }
88
89 radix_func(x, y);
90
// NOTE(review): source line 91 is missing from this listing (presumably a loop
// pragma, as before line 83) — verify against fft.h.
92 for (int t = 0; t < radix; t++) {
93 indices[t] = j + t * p;
94 }
95}
96
97// Perform all the radix steps required for a
98// particular radix size `radix`.
//
// Runs num_steps consecutive Stockham passes of this radix over the n values in
// buf, cooperatively with the other threads of the same DFT.
//
// i:         this thread's index within the DFT (0 <= i < m)
// p:         in/out; size of the sub-DFTs merged so far, multiplied by `radix`
//            after every pass
// m:         number of threads working on this DFT
// n:         DFT length (each pass performs n / radix butterflies)
// num_steps: number of passes of this radix (0 makes this a no-op)
// inputs, indices, values: per-thread scratch; indices/values must hold
//            max_radices_per_thread * radix entries (MAX_OUTPUT_SIZE in perform_fft)
// buf:       threadgroup memory holding the DFT, transformed in place
99template <int radix, RadixFunc radix_func>
100METAL_FUNC void radix_n_steps(
101 int i,
102 thread int* p,
103 int m,
104 int n,
105 int num_steps,
106 thread float2* inputs,
107 thread short* indices,
108 thread float2* values,
109 threadgroup float2* buf) {
// m_r: number of butterflies per pass.
110 int m_r = n / radix;
111 // When combining different sized radices, we have to do
112 // multiple butterflies in a single thread.
113 // E.g. n = 28 = 4 * 7
114 // 4 threads, 7 elems_per_thread
115 // All threads do 1 radix7 butterfly.
116 // 3 threads do 2 radix4 butterflies.
117 // 1 thread does 1 radix4 butterfly.
118 int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
119
120 int index = 0;
121 int r_index = 0;
122 for (int s = 0; s < num_steps; s++) {
// Phase 1: this thread handles butterflies i, i + m, i + 2m, ... (< m_r),
// reading input r of butterfly b from buf[b + r * m_r].
123 for (int t = 0; t < max_radices_per_thread; t++) {
124 index = i + t * m;
125 if (index < m_r) {
126 for (int r = 0; r < radix; r++) {
127 inputs[r] = buf[index + r * m_r];
128 }
// NOTE(review): source line 129 is missing from this listing. The arguments
// below match radix_butterfly(i, p, x, indices, y), so it is presumably the
// `radix_butterfly<radix, radix_func>(` call opener — verify against fft.h.
130 index, *p, inputs, indices + t * radix, values + t * radix);
131 }
132 }
133
134 // Wait until all threads have read their inputs into thread local mem
// (required because the writes below go back into the same buffer).
135 threadgroup_barrier(mem_flags::mem_threadgroup);
136
// Phase 2: scatter every computed output to the position radix_butterfly chose.
137 for (int t = 0; t < max_radices_per_thread; t++) {
138 index = i + t * m;
139 if (index < m_r) {
140 for (int r = 0; r < radix; r++) {
141 r_index = t * radix + r;
142 buf[indices[r_index]] = values[r_index];
143 }
144 }
145 }
146
147 // Wait until all threads have written back to threadgroup mem
148 threadgroup_barrier(mem_flags::mem_threadgroup);
149 *p *= radix;
150 }
151}
152
// Expands to one radix_n_steps call. Only usable where the locals fft_idx, p,
// m, n, inputs, indices, values and buf are in scope (as in perform_fft); the
// expansion already ends with ';'.
153#define RADIX_STEP(radix, radix_func, num_steps) \
154 radix_n_steps<radix, radix_func>( \
155 fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
156
// Runs the mixed-radix Stockham passes for a length-n transform on buf.
// fft_idx: this thread's index within the DFT; m: threads per DFT;
// p: in/out sub-DFT size (1 for a fresh transform; rader_fft resumes with
// p = rader_n). When `rader` is true the rader_*_steps_ counts are presumably
// used instead of radix_*_steps_. The scratch arrays are sized by MAX_RADIX and
// MAX_OUTPUT_SIZE.
157template <bool rader = false>
158METAL_FUNC void
159perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
160 float2 inputs[MAX_RADIX];
161 short indices[MAX_OUTPUT_SIZE];
162 float2 values[MAX_OUTPUT_SIZE];
163
// NOTE(review): source lines 164-172 are missing from this listing. Nine lines
// fit the nine radices with step constants (2, 3, 4, 5, 6, 7, 8, 11, 13), so
// they are presumably one RADIX_STEP(...) per radix, choosing rader_*_steps_ or
// radix_*_steps_ based on `rader` — verify against fft.h.
173}
174
175// Each FFT is computed entirely in shared GPU memory.
176//
177// N is decomposed into radix-n DFTs:
178// e.g. 128 = 2 * 4 * 4 * 4
//
// Thread layout (from the index math below): elem.z / grid.z give this thread's
// index within a DFT and the number of threads per DFT; elem.y picks the DFT's
// n-element slice of shared_in. NOTE(review): elem is the position in the grid,
// so indexing shared_in with elem.y * n assumes the grid's whole y extent fits
// in one threadgroup's tg_mem_size — confirm against the host dispatch.
179template <int tg_mem_size, typename in_T, typename out_T>
180[[kernel]] void fft(
181 const device in_T* in [[buffer(0)]],
182 device out_T* out [[buffer(1)]],
183 constant const int& n,
184 constant const int& batch_size,
185 uint3 elem [[thread_position_in_grid]],
186 uint3 grid [[threads_per_grid]]) {
187 threadgroup float2 shared_in[tg_mem_size];
188
// NOTE(review): source lines 189 (the opening of the read_writer declaration)
// and 195 (presumably one more constructor argument) are missing from this
// listing; compare the full ReadWriter construction in four_step_fft.
190 in,
191 &shared_in[0],
192 out,
193 n,
194 batch_size,
196 elem,
197 grid,
198 inv_);
199
200 if (read_writer.out_of_bounds()) {
201 return;
202 };
203 read_writer.load();
204
// All of this DFT's inputs must be in shared memory before any butterfly reads them.
205 threadgroup_barrier(mem_flags::mem_threadgroup);
206
207 int p = 1;
208 int fft_idx = elem.z; // Thread index in DFT
209 int m = grid.z; // Threads per DFT
210 int tg_idx = elem.y * n; // Offset of this DFT's slice in shared_in
211 threadgroup float2* buf = &shared_in[tg_idx];
212
213 perform_fft(fft_idx, &p, m, n, buf);
214
215 read_writer.write();
216}
217
// Kernel for FFT lengths n = rader_m * rader_n, where rader_n is a prime
// handled with Rader's algorithm (explained inside the body).
//
// raders_b_q:       the constant b_q from the numpy sketch below, indexed by
//                   position within a length-(rader_n - 1) transform
// raders_g_q:       gather permutation g_q (indexed by q = index / rader_m)
// raders_g_minus_q: scatter permutation g_minus_q
// rader_n:          the Rader prime; the function constant rader_m_ = n / rader_n
218template <int tg_mem_size, typename in_T, typename out_T>
219[[kernel]] void rader_fft(
220 const device in_T* in [[buffer(0)]],
221 device out_T* out [[buffer(1)]],
222 const device float2* raders_b_q [[buffer(2)]],
223 const device short* raders_g_q [[buffer(3)]],
224 const device short* raders_g_minus_q [[buffer(4)]],
225 constant const int& n,
226 constant const int& batch_size,
227 constant const int& rader_n,
228 uint3 elem [[thread_position_in_grid]],
229 uint3 grid [[threads_per_grid]]) {
230 // Use Rader's algorithm to compute fast FFTs
231 // when a prime factor `p` of `n` is greater than 13 but
232 // has `p - 1` Stockham decomposable into prime factors <= 13.
233 //
234 // E.g. n = 102
235 // = 2 * 3 * 17
236 // . = 2 * 3 * RADER(16)
237 // . = 2 * 3 * RADER(4 * 4)
238 //
239 // In numpy:
240 // x_perm = x[g_q]
241 // y = np.fft.fft(x_perm) * b_q
242 // z = np.fft.ifft(y) + x[0]
243 // out = z[g_minus_q]
244 // out[0] = x.sum()  (i.e. x[0] + x[1:].sum(); see x_sum below)
245 //
246 // Where the g_q and g_minus_q are permutations formed
247 // by the multiplicative group of integers modulo N (N = rader_n here)
248 // using a primitive root of N, and b_q is a constant vector (raders_b_q).
249 // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
250 //
251 // Rader's uses fewer operations than Bluestein's and so
252 // is more accurate. It's also faster in most cases.
253 threadgroup float2 shared_in[tg_mem_size];
254
// NOTE(review): source lines 255 (the opening of the read_writer declaration)
// and 261 (presumably one more constructor argument) are missing from this
// listing; compare the full ReadWriter construction in four_step_fft.
256 in,
257 &shared_in[0],
258 out,
259 n,
260 batch_size,
262 elem,
263 grid,
264 inv_);
265
266 if (read_writer.out_of_bounds()) {
267 return;
268 };
269 read_writer.load();
270
271 threadgroup_barrier(mem_flags::mem_threadgroup);
272
273 // The number of the threads we're using for each DFT
274 int m = grid.z;
275
276 int fft_idx = elem.z;
277 int tg_idx = elem.y * n;
278 threadgroup float2* buf = &shared_in[tg_idx];
279
280 // rader_m = n / rader_n;
281 int rader_m = rader_m_;
282
// Layout on entry (as the gather below assumes): element k of Rader sub-DFT s
// is at buf[k * rader_m + s], so buf[0 .. rader_m) holds each sub-DFT's x_0.
283 // We have to load two x_0s for each thread since sometimes
284 // elems_per_thread_ crosses a boundary.
285 // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
286 // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
287 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
// x_0_index: the sub-DFT containing this thread's first element in the
// contiguous per-sub-DFT layout used after the first transform.
288 short x_0_index =
289 metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
290 float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
291
292 // Do the Rader permutation in shared memory
// buf[rader_m + q * rader_m + s] receives element raders_g_q[q] of sub-DFT s.
// Indices are clamped to max_index (the last element of the Rader region), so
// threads past the end redundantly rewrite that element with the same value.
// Everything is read into temp first and written back after a barrier, which
// makes the in-place permutation race-free.
293 float2 temp[MAX_RADIX];
294 int max_index = n - rader_m - 1;
295 for (int e = 0; e < elems_per_thread_; e++) {
296 short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
297 short g_q = raders_g_q[index / rader_m];
298 temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
299 }
300
301 threadgroup_barrier(mem_flags::mem_threadgroup);
302
303 for (int e = 0; e < elems_per_thread_; e++) {
304 short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
305 buf[index + rader_m] = temp[e];
306 }
307
308 threadgroup_barrier(mem_flags::mem_threadgroup);
309
310 // Rader FFT on x[rader_m:]
// Presumably rader_m interleaved transforms of length rader_n - 1; the indexing
// below reads sub-DFT s's results contiguously at buf[rader_m + s * (rader_n - 1)].
311 int p = 1;
312 perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
313
314 // x_1 + ... + x_{rader_n - 1} is computed for us in the first FFT step so
315 // we save it in the first rader_m indices of the array for later.
// (It is output 0 of each transform.) NOTE(review): a distinct sum is saved
// only for sub-DFTs s < m, so this relies on m >= rader_m — presumably ensured
// by the host.
316 int x_sum_index = metal::min(fft_idx, rader_m - 1);
317 buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
318
// Multiply by b_q and conjugate (inv = (1, -1)). Together with the
// conjugate-and-scale by rader_inv_factor after the next transform, this
// evaluates ifft(y) as conj(fft(conj(y))) / (rader_n - 1), assuming perform_fft
// is a forward transform. For index = q * rader_m + s, interleaved_index =
// q + s * (rader_n - 1) is where element q of sub-DFT s sits in the contiguous
// output; the product is written back to the interleaved position index.
319 float2 inv = {1.0f, -1.0f};
320 for (int e = 0; e < elems_per_thread_; e++) {
321 short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
322 short interleaved_index =
323 index / rader_m + (index % rader_m) * (rader_n - 1);
324 temp[e] = complex_mul(
325 buf[rader_m + interleaved_index],
326 raders_b_q[interleaved_index % (rader_n - 1)]);
327 }
328
329 threadgroup_barrier(mem_flags::mem_threadgroup);
330
331 for (int e = 0; e < elems_per_thread_; e++) {
332 short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
333 buf[rader_m + index] = temp[e] * inv;
334 }
335
336 threadgroup_barrier(mem_flags::mem_threadgroup);
337
338 // Rader IFFT on x[rader_m:]
339 p = 1;
340 perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
341
342 float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
343
// Results are again contiguous per sub-DFT: index = s * (rader_n - 1) + q, so
// diff_index = s - x_0_index selects one of the two preloaded x_0 values.
// NOTE(review): n - rader_m - 1 below is the same value as max_index.
344 for (int e = 0; e < elems_per_thread_; e++) {
345 short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
346 short diff_index = index / (rader_n - 1) - x_0_index;
347 temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
348 }
349
350 // Use the sum of elements that was computed in the first FFT
// out[0] of sub-DFT x_0_index = (x_1 + ... + x_{rader_n - 1}) + x_0.
351 float2 x_sum = buf[x_0_index] + x_0[0];
352
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
// Scatter: with index = s * (rader_n - 1) + g_q_index, out_index is
// s * rader_n + raders_g_minus_q[g_q_index], so sub-DFT s's outputs end up
// contiguous at buf[s * rader_n ...]. Slot 0 is filled with x_sum below
// (presumably raders_g_minus_q only holds values in [1, rader_n)).
355 for (int e = 0; e < elems_per_thread_; e++) {
356 short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
357 short g_q_index = index % (rader_n - 1);
358 short g_q = raders_g_minus_q[g_q_index];
359 short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
360 buf[out_index] = temp[e];
361 }
362
363 buf[x_0_index * rader_n] = x_sum;
364
365 threadgroup_barrier(mem_flags::mem_threadgroup);
366
// Finish the remaining radices of n, resuming the Stockham recursion with the
// size-rader_n sub-DFTs already merged (p = rader_n).
367 p = rader_n;
368 perform_fft(fft_idx, &p, m, n, buf);
369
370 read_writer.write();
371}
372
// Kernel for arbitrary lengths via Bluestein's algorithm. `length` is the
// input/output length and `n` the padded transform size used in shared memory.
// w_k is passed to both load_padded and write_padded; w_q multiplies the
// transformed data pointwise.
373template <int tg_mem_size, typename in_T, typename out_T>
374[[kernel]] void bluestein_fft(
375 const device in_T* in [[buffer(0)]],
376 device out_T* out [[buffer(1)]],
377 const device float2* w_q [[buffer(2)]],
378 const device float2* w_k [[buffer(3)]],
379 constant const int& length,
380 constant const int& n,
381 constant const int& batch_size,
382 uint3 elem [[thread_position_in_grid]],
383 uint3 grid [[threads_per_grid]]) {
384 // Computes arbitrary length FFTs with Bluestein's algorithm
385 //
386 // In numpy:
387 // bluestein_n = next_power_of_2(2*n - 1)
388 // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
389 //
390 // Where w_k and w_q are precomputed on CPU in high precision as:
391 // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
392 // w_q = np.fft.fft(1/w_k[-n:])
// NOTE: in this numpy sketch `n` is the input length (the kernel's `length`
// parameter) and bluestein_n corresponds to the kernel's `n` parameter.
393 threadgroup float2 shared_in[tg_mem_size];
394
// NOTE(review): source lines 395 (the opening of the read_writer declaration)
// and 401 (presumably one more constructor argument) are missing from this
// listing; compare the full ReadWriter construction in four_step_fft.
396 in,
397 &shared_in[0],
398 out,
399 n,
400 batch_size,
402 elem,
403 grid,
404 inv_);
405
406 if (read_writer.out_of_bounds()) {
407 return;
408 };
409 read_writer.load_padded(length, w_k);
410
411 threadgroup_barrier(mem_flags::mem_threadgroup);
412
413 int p = 1;
414 int fft_idx = elem.z; // Thread index in DFT
415 int m = grid.z; // Threads per DFT
416 int tg_idx = elem.y * n; // Offset of this DFT's slice in shared_in
417 threadgroup float2* buf = &shared_in[tg_idx];
418
419 // fft
420 perform_fft(fft_idx, &p, m, n, buf);
421
// Pointwise multiply by w_q, then conjugate (inv = (1, -1)). Assuming
// perform_fft is a forward transform, the second pass below yields the inverse
// transform up to a final conjugation/scaling, which presumably happens in
// write_padded (see readwrite.h). Each thread updates only its own positions
// fft_idx + t * m.
422 float2 inv = float2(1.0f, -1.0f);
423 for (int t = 0; t < elems_per_thread_; t++) {
424 int index = fft_idx + t * m;
425 buf[index] = complex_mul(buf[index], w_q[index]) * inv;
426 }
427
428 threadgroup_barrier(mem_flags::mem_threadgroup);
429
430 // ifft
431 p = 1;
432 perform_fft(fft_idx, &p, m, n, buf);
433
434 read_writer.write_padded(length, w_k);
435}
436
// Four-step FFT for power-of-two sizes overall_n = n1 * n2. The `step`
// template argument selects the half that runs: step 0 performs length-n1
// transforms over elements n2 apart, any other step performs length-n2
// transforms over elements n1 apart (via load_strided / write_strided).
// NOTE(review): the twiddle multiplication between the steps and the `real`
// handling presumably live in ReadWriter<in_T, out_T, step, real> — confirm in
// readwrite.h.
437template <
438 int tg_mem_size,
439 typename in_T,
440 typename out_T,
441 int step,
442 bool real = false>
443[[kernel]] void four_step_fft(
444 const device in_T* in [[buffer(0)]],
445 device out_T* out [[buffer(1)]],
446 constant const int& n1,
447 constant const int& n2,
448 constant const int& batch_size,
449 uint3 elem [[thread_position_in_grid]],
450 uint3 grid [[threads_per_grid]]) {
451 // Fast four step FFT implementation for powers of 2.
452 int overall_n = n1 * n2;
453 int n = step == 0 ? n1 : n2;
454 int stride = step == 0 ? n2 : n1;
455
456 // The number of the threads we're using for each DFT
457 int m = grid.z;
458 int fft_idx = elem.z;
459
460 threadgroup float2 shared_in[tg_mem_size];
461 threadgroup float2* buf = &shared_in[elem.y * n];
462
463 using read_writer_t = ReadWriter<in_T, out_T, step, real>;
464 read_writer_t read_writer = read_writer_t(
465 in,
466 &shared_in[0],
467 out,
468 n,
469 batch_size,
// NOTE(review): source line 470 (presumably one more constructor argument) is
// missing from this listing — verify against fft.h.
471 elem,
472 grid,
473 inv_);
474
475 if (read_writer.out_of_bounds()) {
476 return;
477 };
478 read_writer.load_strided(stride, overall_n);
479
480 threadgroup_barrier(mem_flags::mem_threadgroup);
481
482 int p = 1;
483 perform_fft(fft_idx, &p, m, n, buf);
484
485 read_writer.write_strided(stride, overall_n);
486}
static constant constexpr const int rader_6_steps_
Definition fft.h:43
METAL_FUNC void perform_fft(int fft_idx, thread int *p, int m, int n, threadgroup float2 *buf)
Definition fft.h:159
void bluestein_fft(const device in_T *in, device out_T *out, const device float2 *w_q, const device float2 *w_k, constant const int &length, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
Definition fft.h:374
static constant constexpr const int rader_7_steps_
Definition fft.h:42
static constant constexpr const int radix_4_steps_
Definition fft.h:35
static constant constexpr const int rader_11_steps_
Definition fft.h:40
static constant constexpr const int rader_13_steps_
Definition fft.h:39
static constant constexpr const int radix_7_steps_
Definition fft.h:32
METAL_FUNC void radix_butterfly(int i, int p, thread float2 *x, thread short *indices, thread float2 *y)
Definition fft.h:54
#define MAX_OUTPUT_SIZE
Definition fft.h:20
static constant constexpr const bool is_power_of_2_
Definition fft.h:24
static constant constexpr const int rader_2_steps_
Definition fft.h:47
static constant constexpr const int radix_6_steps_
Definition fft.h:33
static constant constexpr const int radix_8_steps_
Definition fft.h:31
void fft(const device in_T *in, device out_T *out, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
Definition fft.h:180
static constant constexpr const int radix_2_steps_
Definition fft.h:37
static constant constexpr const int radix_3_steps_
Definition fft.h:36
void four_step_fft(const device in_T *in, device out_T *out, constant const int &n1, constant const int &n2, constant const int &batch_size, uint3 elem, uint3 grid)
Definition fft.h:443
void(* RadixFunc)(thread float2 *, thread float2 *)
Definition fft.h:50
#define RADIX_STEP(radix, radix_func, num_steps)
Definition fft.h:153
static constant constexpr const bool inv_
Definition fft.h:23
#define MAX_RADIX
Definition fft.h:17
static constant constexpr const int radix_11_steps_
Definition fft.h:30
static constant constexpr const int radix_5_steps_
Definition fft.h:34
METAL_FUNC void radix_n_steps(int i, thread int *p, int m, int n, int num_steps, thread float2 *inputs, thread short *indices, thread float2 *values, threadgroup float2 *buf)
Definition fft.h:100
static constant constexpr const int radix_13_steps_
Definition fft.h:29
static constant constexpr const int rader_m_
Definition fft.h:27
static constant constexpr const int rader_8_steps_
Definition fft.h:41
static constant constexpr const int rader_4_steps_
Definition fft.h:45
void rader_fft(const device in_T *in, device out_T *out, const device float2 *raders_b_q, const device short *raders_g_q, const device short *raders_g_minus_q, constant const int &n, constant const int &batch_size, constant const int &rader_n, uint3 elem, uint3 grid)
Definition fft.h:219
static constant constexpr const int elems_per_thread_
Definition fft.h:25
static constant constexpr const int rader_3_steps_
Definition fft.h:46
static constant constexpr const int rader_5_steps_
Definition fft.h:44
METAL_FUNC void radix_func(thread float *x)
Definition hadamard.h:11
Definition bf16_math.h:226
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
auto real(Simd< T, 1 > in) -> Simd< decltype(std::real(in.value)), 1 >
Definition base_simd.h:104
METAL_FUNC void radix5(thread float2 *x, thread float2 *y)
Definition radix.h:69
METAL_FUNC void radix4(thread float2 *x, thread float2 *y)
Definition radix.h:56
METAL_FUNC void radix11(thread float2 *x, thread float2 *y)
Definition radix.h:201
METAL_FUNC void radix3(thread float2 *x, thread float2 *y)
Definition radix.h:41
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
METAL_FUNC void radix8(thread float2 *x, thread float2 *y)
Definition radix.h:151
METAL_FUNC void radix7(thread float2 *x, thread float2 *y)
Definition radix.h:122
METAL_FUNC void radix2(thread float2 *x, thread float2 *y)
Definition radix.h:36
METAL_FUNC void radix13(thread float2 *x, thread float2 *y)
Definition radix.h:290
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
METAL_FUNC void radix6(thread float2 *x, thread float2 *y)
Definition radix.h:96
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition readwrite.h:35