20#define MAX_OUTPUT_SIZE 18
50typedef void (*
RadixFunc)(thread float2*, thread float2*);
53template <
int radix, RadixFunc radix_func>
58 thread
short* indices,
66 constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
68 constexpr short power = __builtin_ctz(radix);
70 j = ((i - k) << power) + k;
73 j = (i / p) * radix * p + k;
79 float2 twiddle = twiddle_1;
83 for (
int t = 2; t < radix; t++) {
92 for (
int t = 0; t < radix; t++) {
93 indices[t] = j + t * p;
99template <
int radix, RadixFunc radix_func>
106 thread float2* inputs,
107 thread
short* indices,
108 thread float2* values,
109 threadgroup float2*
buf) {
122 for (
int s = 0; s < num_steps; s++) {
123 for (
int t = 0; t < max_radices_per_thread; t++) {
126 for (
int r = 0; r < radix; r++) {
127 inputs[r] =
buf[index + r * m_r];
129 radix_butterfly<radix, radix_func>(
130 index, *p, inputs, indices + t * radix, values + t * radix);
135 threadgroup_barrier(mem_flags::mem_threadgroup);
137 for (
int t = 0; t < max_radices_per_thread; t++) {
140 for (
int r = 0; r < radix; r++) {
141 r_index = t * radix + r;
142 buf[indices[r_index]] = values[r_index];
148 threadgroup_barrier(mem_flags::mem_threadgroup);
153#define RADIX_STEP(radix, radix_func, num_steps) \
154 radix_n_steps<radix, radix_func>( \
155 fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
157template <
bool rader = false>
179template <
int tg_mem_size,
typename in_T,
typename out_T>
181 const device in_T* in [[buffer(0)]],
182 device out_T* out [[buffer(1)]],
183 constant
const int& n,
184 constant
const int& batch_size,
185 uint3 elem [[thread_position_in_grid]],
186 uint3 grid [[threads_per_grid]]) {
187 threadgroup float2 shared_in[tg_mem_size];
200 if (read_writer.out_of_bounds()) {
205 threadgroup_barrier(mem_flags::mem_threadgroup);
208 int fft_idx = elem.z;
210 int tg_idx = elem.y * n;
211 threadgroup float2*
buf = &shared_in[tg_idx];
218template <
int tg_mem_size,
typename in_T,
typename out_T>
220 const device in_T* in [[buffer(0)]],
221 device out_T* out [[buffer(1)]],
222 const device float2* raders_b_q [[buffer(2)]],
223 const device
short* raders_g_q [[buffer(3)]],
224 const device
short* raders_g_minus_q [[buffer(4)]],
225 constant
const int& n,
226 constant
const int& batch_size,
227 constant
const int& rader_n,
228 uint3 elem [[thread_position_in_grid]],
229 uint3 grid [[threads_per_grid]]) {
253 threadgroup float2 shared_in[tg_mem_size];
266 if (read_writer.out_of_bounds()) {
271 threadgroup_barrier(mem_flags::mem_threadgroup);
276 int fft_idx = elem.z;
277 int tg_idx = elem.y * n;
278 threadgroup float2*
buf = &shared_in[tg_idx];
290 float2 x_0[2] = {
buf[x_0_index],
buf[x_0_index + 1]};
294 int max_index = n - rader_m - 1;
297 short g_q = raders_g_q[index / rader_m];
298 temp[e] =
buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
301 threadgroup_barrier(mem_flags::mem_threadgroup);
305 buf[index + rader_m] = temp[e];
308 threadgroup_barrier(mem_flags::mem_threadgroup);
316 int x_sum_index =
metal::min(fft_idx, rader_m - 1);
317 buf[x_sum_index] =
buf[rader_m + x_sum_index * (rader_n - 1)];
319 float2 inv = {1.0f, -1.0f};
322 short interleaved_index =
323 index / rader_m + (index % rader_m) * (rader_n - 1);
325 buf[rader_m + interleaved_index],
326 raders_b_q[interleaved_index % (rader_n - 1)]);
329 threadgroup_barrier(mem_flags::mem_threadgroup);
333 buf[rader_m + index] = temp[e] * inv;
336 threadgroup_barrier(mem_flags::mem_threadgroup);
342 float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
346 short diff_index = index / (rader_n - 1) - x_0_index;
347 temp[e] =
buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
351 float2 x_sum =
buf[x_0_index] + x_0[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
357 short g_q_index = index % (rader_n - 1);
358 short g_q = raders_g_minus_q[g_q_index];
359 short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
360 buf[out_index] = temp[e];
363 buf[x_0_index * rader_n] = x_sum;
365 threadgroup_barrier(mem_flags::mem_threadgroup);
373template <
int tg_mem_size,
typename in_T,
typename out_T>
375 const device in_T* in [[buffer(0)]],
376 device out_T* out [[buffer(1)]],
377 const device float2* w_q [[buffer(2)]],
378 const device float2* w_k [[buffer(3)]],
379 constant
const int& length,
380 constant
const int& n,
381 constant
const int& batch_size,
382 uint3 elem [[thread_position_in_grid]],
383 uint3 grid [[threads_per_grid]]) {
393 threadgroup float2 shared_in[tg_mem_size];
406 if (read_writer.out_of_bounds()) {
409 read_writer.load_padded(length, w_k);
411 threadgroup_barrier(mem_flags::mem_threadgroup);
414 int fft_idx = elem.z;
416 int tg_idx = elem.y * n;
417 threadgroup float2*
buf = &shared_in[tg_idx];
422 float2 inv = float2(1.0f, -1.0f);
424 int index = fft_idx + t * m;
428 threadgroup_barrier(mem_flags::mem_threadgroup);
434 read_writer.write_padded(length, w_k);
444 const device in_T* in [[buffer(0)]],
445 device out_T* out [[buffer(1)]],
446 constant
const int& n1,
447 constant
const int& n2,
448 constant
const int& batch_size,
449 uint3 elem [[thread_position_in_grid]],
450 uint3 grid [[threads_per_grid]]) {
452 int overall_n = n1 * n2;
453 int n = step == 0 ? n1 : n2;
454 int stride = step == 0 ? n2 : n1;
458 int fft_idx = elem.z;
460 threadgroup float2 shared_in[tg_mem_size];
461 threadgroup float2*
buf = &shared_in[elem.y * n];
464 read_writer_t read_writer = read_writer_t(
475 if (read_writer.out_of_bounds()) {
478 read_writer.load_strided(stride, overall_n);
480 threadgroup_barrier(mem_flags::mem_threadgroup);
485 read_writer.write_strided(stride, overall_n);
METAL_FUNC void radix5(thread float2 *x, thread float2 *y)
Definition radix.h:69
METAL_FUNC void radix4(thread float2 *x, thread float2 *y)
Definition radix.h:56
METAL_FUNC void radix11(thread float2 *x, thread float2 *y)
Definition radix.h:201
METAL_FUNC void radix3(thread float2 *x, thread float2 *y)
Definition radix.h:41
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
METAL_FUNC void radix8(thread float2 *x, thread float2 *y)
Definition radix.h:151
METAL_FUNC void radix7(thread float2 *x, thread float2 *y)
Definition radix.h:122
METAL_FUNC void radix2(thread float2 *x, thread float2 *y)
Definition radix.h:36
METAL_FUNC void radix13(thread float2 *x, thread float2 *y)
Definition radix.h:290
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
METAL_FUNC void radix6(thread float2 *x, thread float2 *y)
Definition radix.h:96
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition readwrite.h:35