34 bool four_step_real =
false>
36 const device in_T*
in;
37 threadgroup float2*
buf;
52 const device in_T* in_,
53 threadgroup float2* buf_,
56 const int batch_size_,
83 return float2(
elem, 0);
103 short max_index =
grid.y *
n - 2;
106 constexpr int read_width = 2;
108 short index = read_width * tg_idx + read_width *
threads_per_tg * e;
116 short index = tg_idx +
126 short max_index =
grid.y *
n - 2;
128 constexpr int read_width = 2;
130 short index = read_width * tg_idx + read_width *
threads_per_tg * e;
138 short index = tg_idx +
146 METAL_FUNC
void load_padded(
int length,
const device float2* w_k)
const {
147 int batch_idx =
elem.x *
grid.y * length +
elem.y * length;
148 int fft_idx =
elem.z;
151 threadgroup float2* seq_buf =
buf +
elem.y *
n;
154 if (index < length) {
158 seq_buf[index] = 0.0;
163 METAL_FUNC
void write_padded(
int length,
const device float2* w_k)
const {
164 int batch_idx =
elem.x *
grid.y * length +
elem.y * length;
165 int fft_idx =
elem.z;
167 float2 inv_factor = {1.0f /
n, -1.0f /
n};
169 threadgroup float2* seq_buf =
buf +
elem.y *
n;
172 if (index < length) {
173 float2
elem = seq_buf[index + length - 1] * inv_factor;
188 int coalesce_width =
grid.y;
190 int outer_batch_size = stride / coalesce_width;
192 int strided_batch_idx = (
elem.x % outer_batch_size) * coalesce_width +
193 overall_n * (
elem.x / outer_batch_size);
196 tg_idx % coalesce_width;
214 int ij = (combined_idx / stride) * (combined_idx % stride);
231 bool default_inv =
inv;
265 threadgroup float2* seq_buf =
buf +
elem.y *
n;
273 short fft_idx =
elem.z;
277 seq_buf[index].x =
in[batch_idx + index];
278 seq_buf[index].y =
in[batch_idx + index + next_in];
284 short n_over_2 = (
n / 2) + 1;
286 int batch_idx =
elem.x *
grid.y * n_over_2 * 2 +
elem.y * n_over_2 * 2;
287 threadgroup float2* seq_buf =
buf +
elem.y *
n;
293 float2 conj = {1, -1};
294 float2 minus_j = {0, -1};
297 short fft_idx =
elem.z;
300 int index =
metal::min(fft_idx + e * m, n_over_2 - 1);
304 out[batch_idx + index] = {seq_buf[index].x, 0};
305 out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
307 float2 x_k = seq_buf[index];
308 float2 x_n_minus_k = seq_buf[
n - index] * conj;
309 out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
310 out[batch_idx + index + next_out] =
319 const device float2* w_k)
const {
320 int batch_idx =
elem.x *
grid.y * length * 2 +
elem.y * length * 2;
321 threadgroup float2* seq_buf =
buf +
elem.y *
n;
329 short fft_idx =
elem.z;
333 if (index < length) {
335 float2(
in[batch_idx + index],
in[batch_idx + index + next_in]);
346 const device float2* w_k)
const {
347 int length_over_2 = (length / 2) + 1;
349 elem.x *
grid.y * length_over_2 * 2 +
elem.y * length_over_2 * 2;
350 threadgroup float2* seq_buf =
buf +
elem.y *
n + length - 1;
357 float2 conj = {1, -1};
358 float2 inv_factor = {1.0f /
n, -1.0f /
n};
359 float2 minus_j = {0, -1};
362 short fft_idx =
elem.z;
365 int index =
metal::min(fft_idx + e * m, length_over_2 - 1);
370 out[batch_idx + index] = float2(
elem.x, 0);
371 out[batch_idx + index + next_out] = float2(
elem.y, 0);
373 float2 x_k =
complex_mul(w_k[index], seq_buf[index] * inv_factor);
375 w_k[length - index], seq_buf[length - index] * inv_factor);
378 out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
379 out[batch_idx + index + next_out] =
399 short n_over_2 = (
n / 2) + 1;
400 int batch_idx =
elem.x *
grid.y * n_over_2 * 2 +
elem.y * n_over_2 * 2;
401 threadgroup float2* seq_buf =
buf +
elem.y *
n;
409 short fft_idx =
elem.z;
411 float2 conj = {1, -1};
412 float2 plus_j = {0, 1};
415 int index =
metal::min(fft_idx + t * m, n_over_2 - 1);
416 float2 x =
in[batch_idx + index];
417 float2 y =
in[batch_idx + index + next_in];
419 bool first_val = index == 0;
421 bool last_val =
n % 2 == 0 && index == n_over_2 - 1;
422 if (first_val || last_val) {
427 seq_buf[index].y = -seq_buf[index].y;
428 if (index > 0 && !last_val) {
429 seq_buf[
n - index] = (x * conj) +
complex_mul(y * conj, plus_j);
430 seq_buf[
n - index].y = -seq_buf[
n - index].y;
438 threadgroup float2* seq_buf =
buf +
elem.y *
n;
445 short fft_idx =
elem.z;
449 out[batch_idx + index] = seq_buf[index].x /
n;
450 out[batch_idx + index + next_out] = seq_buf[index].y / -
n;
457 const device float2* w_k)
const {
458 int n_over_2 = (
n / 2) + 1;
459 int length_over_2 = (length / 2) + 1;
462 elem.x *
grid.y * length_over_2 * 2 +
elem.y * length_over_2 * 2;
463 threadgroup float2* seq_buf =
buf +
elem.y *
n;
472 short fft_idx =
elem.z;
474 float2 conj = {1, -1};
475 float2 plus_j = {0, 1};
478 int index =
metal::min(fft_idx + t * m, n_over_2 - 1);
479 float2 x =
in[batch_idx + index];
480 float2 y =
in[batch_idx + index + next_in];
481 if (index < length_over_2) {
482 bool last_val = length % 2 == 0 && index == length_over_2 - 1;
488 seq_buf[index] =
complex_mul(elem1 * conj, w_k[index]);
489 if (index > 0 && !last_val) {
490 float2 elem2 = (x * conj) +
complex_mul(y * conj, plus_j);
491 seq_buf[length - index] =
495 short pad_index =
metal::min(length + (index - length_over_2) * 2,
n - 2);
496 seq_buf[pad_index] = 0;
497 seq_buf[pad_index + 1] = 0;
505 const device float2* w_k)
const {
506 int batch_idx =
elem.x *
grid.y * length * 2 +
elem.y * length * 2;
507 threadgroup float2* seq_buf =
buf +
elem.y *
n + length - 1;
514 short fft_idx =
elem.z;
516 float2 inv_factor = {1.0f /
n, -1.0f /
n};
518 int index = fft_idx + e * m;
519 if (index < length) {
520 float2 output =
complex_mul(seq_buf[index] * inv_factor, w_k[index]);
521 out[batch_idx + index] = output.x / length;
522 out[batch_idx + index + next_out] = output.y / -length;
537 bool default_inv =
inv;
548 int overall_n_over_2 = overall_n / 2 + 1;
549 int coalesce_width =
grid.y;
551 int outer_batch_size = stride / coalesce_width;
553 int strided_batch_idx = (
elem.x % outer_batch_size) * coalesce_width +
554 overall_n_over_2 * (
elem.x / outer_batch_size);
557 tg_idx % coalesce_width;
566 if (tg_idx == 0 &&
elem.x % outer_batch_size == 0) {
567 out[strided_batch_idx + overall_n / 2] =
buf[
n / 2];
577 int overall_n_over_2 = overall_n / 2 + 1;
578 auto conj = float2(1, -1);
584 int overall_batch = device_idx / overall_n;
585 int overall_index = device_idx % overall_n;
586 if (overall_index < overall_n_over_2) {
587 device_idx -= overall_batch * (overall_n - overall_n_over_2);
590 int conj_idx = overall_n - overall_index;
591 device_idx = overall_batch * overall_n_over_2 + conj_idx;
605 bool default_inv =
inv;
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
METAL_FUNC bool out_of_bounds() const
Definition readwrite.h:94
METAL_FUNC void load() const
Definition readwrite.h:100
METAL_FUNC float2 pre_out(float2 elem, int length) const
Definition readwrite.h:90
METAL_FUNC ReadWriter(const device in_T *in_, threadgroup float2 *buf_, device out_T *out_, const short n_, const int batch_size_, const short elems_per_thread_, const uint3 elem_, const uint3 grid_, const bool inv_)
Definition readwrite.h:51
threadgroup float2 * buf
Definition readwrite.h:37
uint3 elem
Definition readwrite.h:42
int elems_per_thread
Definition readwrite.h:41
int strided_device_idx
Definition readwrite.h:48
int threads_per_tg
Definition readwrite.h:44
int n
Definition readwrite.h:39
int batch_size
Definition readwrite.h:40
METAL_FUNC float2 post_in(float elem) const
Definition readwrite.h:82
bool inv
Definition readwrite.h:45
METAL_FUNC void write_strided(int stride, int overall_n)
Definition readwrite.h:210
METAL_FUNC void compute_strided_indices(int stride, int overall_n)
Definition readwrite.h:180
METAL_FUNC float2 pre_out(float2 elem) const
Definition readwrite.h:86
METAL_FUNC void write_padded(int length, const device float2 *w_k) const
Definition readwrite.h:163
METAL_FUNC void load_strided(int stride, int overall_n)
Definition readwrite.h:202
METAL_FUNC float2 post_in(float2 elem) const
Definition readwrite.h:77
const device in_T * in
Definition readwrite.h:36
device out_T * out
Definition readwrite.h:38
METAL_FUNC void write() const
Definition readwrite.h:123
uint3 grid
Definition readwrite.h:43
int strided_shared_idx
Definition readwrite.h:49
METAL_FUNC void load_padded(int length, const device float2 *w_k) const
Definition readwrite.h:146