34    bool four_step_real = 
false>
 
   36  const device in_T* 
in;
 
   37  threadgroup float2* 
buf;
 
   52      const device in_T* in_,
 
   53      threadgroup float2* buf_,
 
   56      const int batch_size_,
 
 
   83    return float2(
elem, 0);
 
 
  103    short max_index = 
grid.y * 
n - 2;
 
  106    constexpr int read_width = 2;
 
  108      short index = read_width * tg_idx + read_width * 
threads_per_tg * e;
 
  116      short index = tg_idx +
 
 
  126    short max_index = 
grid.y * 
n - 2;
 
  128    constexpr int read_width = 2;
 
  130      short index = read_width * tg_idx + read_width * 
threads_per_tg * e;
 
  138      short index = tg_idx +
 
 
  146  METAL_FUNC 
void load_padded(
int length, 
const device float2* w_k)
 const {
 
  147    int batch_idx = 
elem.x * 
grid.y * length + 
elem.y * length;
 
  148    int fft_idx = 
elem.z;
 
  151    threadgroup float2* seq_buf = 
buf + 
elem.y * 
n;
 
  154      if (index < length) {
 
  158        seq_buf[index] = 0.0;
 
 
  163  METAL_FUNC 
void write_padded(
int length, 
const device float2* w_k)
 const {
 
  164    int batch_idx = 
elem.x * 
grid.y * length + 
elem.y * length;
 
  165    int fft_idx = 
elem.z;
 
  167    float2 inv_factor = {1.0f / 
n, -1.0f / 
n};
 
  169    threadgroup float2* seq_buf = 
buf + 
elem.y * 
n;
 
  172      if (index < length) {
 
  173        float2 
elem = seq_buf[index + length - 1] * inv_factor;
 
 
  188    int coalesce_width = 
grid.y;
 
  190    int outer_batch_size = stride / coalesce_width;
 
  192    int strided_batch_idx = (
elem.x % outer_batch_size) * coalesce_width +
 
  193        overall_n * (
elem.x / outer_batch_size);
 
  196        tg_idx % coalesce_width;
 
 
  214      int ij = (combined_idx / stride) * (combined_idx % stride);
 
 
 
  231  bool default_inv = inv;
 
 
  241  compute_strided_indices(stride, overall_n);
 
  242  for (
int e = 0; e < elems_per_thread; e++) {
 
  243    float2 output = 
buf[strided_shared_idx + e];
 
  244    out[strided_device_idx + e * stride] = pre_out(output, overall_n);
 
 
  257  int grid_index = elem.x * grid.y + elem.y;
 
  259  return grid_index * 2 >= batch_size;
 
 
  264  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
 
  265  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  268  int grid_index = elem.x * grid.y + elem.y;
 
  270      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
 
  273  short fft_idx = elem.z;
 
  275  for (
int e = 0; e < elems_per_thread; e++) {
 
  276    int index = 
metal::min(fft_idx + e * m, n - 1);
 
  277    seq_buf[index].x = in[batch_idx + index];
 
  278    seq_buf[index].y = in[batch_idx + index + next_in];
 
 
  284  short n_over_2 = (n / 2) + 1;
 
  286  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
 
  287  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  289  int grid_index = elem.x * grid.y + elem.y;
 
  291      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
 
  293  float2 conj = {1, -1};
 
  294  float2 minus_j = {0, -1};
 
  297  short fft_idx = elem.z;
 
  299  for (
int e = 0; e < elems_per_thread / 2 + 1; e++) {
 
  300    int index = 
metal::min(fft_idx + e * m, n_over_2 - 1);
 
  304      out[batch_idx + index] = {seq_buf[index].x, 0};
 
  305      out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
 
  307      float2 x_k = seq_buf[index];
 
  308      float2 x_n_minus_k = seq_buf[n - index] * conj;
 
  309      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
 
  310      out[batch_idx + index + next_out] =
 
 
  319    const device float2* w_k)
 const {
 
  320  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
 
  321  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  324  int grid_index = elem.x * grid.y + elem.y;
 
  326      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
 
  329  short fft_idx = elem.z;
 
  331  for (
int e = 0; e < elems_per_thread; e++) {
 
  332    int index = 
metal::min(fft_idx + e * m, n - 1);
 
  333    if (index < length) {
 
  335          float2(in[batch_idx + index], in[batch_idx + index + next_in]);
 
 
  346    const device float2* w_k)
 const {
 
  347  int length_over_2 = (length / 2) + 1;
 
  349      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
 
  350  threadgroup float2* seq_buf = 
buf + elem.y * n + length - 1;
 
  352  int grid_index = elem.x * grid.y + elem.y;
 
  353  short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
 
  357  float2 conj = {1, -1};
 
  358  float2 inv_factor = {1.0f / n, -1.0f / n};
 
  359  float2 minus_j = {0, -1};
 
  362  short fft_idx = elem.z;
 
  364  for (
int e = 0; e < elems_per_thread / 2 + 1; e++) {
 
  365    int index = 
metal::min(fft_idx + e * m, length_over_2 - 1);
 
  369      float2 elem = 
complex_mul(w_k[index], seq_buf[index] * inv_factor);
 
  370      out[batch_idx + index] = float2(elem.x, 0);
 
  371      out[batch_idx + index + next_out] = float2(elem.y, 0);
 
  373      float2 x_k = 
complex_mul(w_k[index], seq_buf[index] * inv_factor);
 
  375          w_k[length - index], seq_buf[length - index] * inv_factor);
 
  378      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
 
  379      out[batch_idx + index + next_out] =
 
 
  392  int grid_index = elem.x * grid.y + elem.y;
 
  394  return grid_index * 2 >= batch_size;
 
 
  399  short n_over_2 = (n / 2) + 1;
 
  400  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
 
  401  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  404  int grid_index = elem.x * grid.y + elem.y;
 
  406      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
 
  409  short fft_idx = elem.z;
 
  411  float2 conj = {1, -1};
 
  412  float2 plus_j = {0, 1};
 
  414  for (
int t = 0; t < elems_per_thread / 2 + 1; t++) {
 
  415    int index = 
metal::min(fft_idx + t * m, n_over_2 - 1);
 
  416    float2 x = in[batch_idx + index];
 
  417    float2 y = in[batch_idx + index + next_in];
 
  419    bool first_val = index == 0;
 
  421    bool last_val = n % 2 == 0 && index == n_over_2 - 1;
 
  422    if (first_val || last_val) {
 
  427    seq_buf[index].y = -seq_buf[index].y;
 
  428    if (index > 0 && !last_val) {
 
  429      seq_buf[n - index] = (x * conj) + 
complex_mul(y * conj, plus_j);
 
  430      seq_buf[n - index].y = -seq_buf[n - index].y;
 
 
  437  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
 
  438  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  440  int grid_index = elem.x * grid.y + elem.y;
 
  442      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
 
  445  short fft_idx = elem.z;
 
  447  for (
int e = 0; e < elems_per_thread; e++) {
 
  448    int index = 
metal::min(fft_idx + e * m, n - 1);
 
  449    out[batch_idx + index] = seq_buf[index].x / n;
 
  450    out[batch_idx + index + next_out] = seq_buf[index].y / -n;
 
 
  457    const device float2* w_k)
 const {
 
  458  int n_over_2 = (n / 2) + 1;
 
  459  int length_over_2 = (length / 2) + 1;
 
  462      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
 
  463  threadgroup float2* seq_buf = 
buf + elem.y * n;
 
  466  int grid_index = elem.x * grid.y + elem.y;
 
  467  short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
 
  472  short fft_idx = elem.z;
 
  474  float2 conj = {1, -1};
 
  475  float2 plus_j = {0, 1};
 
  477  for (
int t = 0; t < elems_per_thread / 2 + 1; t++) {
 
  478    int index = 
metal::min(fft_idx + t * m, n_over_2 - 1);
 
  479    float2 x = in[batch_idx + index];
 
  480    float2 y = in[batch_idx + index + next_in];
 
  481    if (index < length_over_2) {
 
  482      bool last_val = length % 2 == 0 && index == length_over_2 - 1;
 
  488      seq_buf[index] = 
complex_mul(elem1 * conj, w_k[index]);
 
  489      if (index > 0 && !last_val) {
 
  490        float2 elem2 = (x * conj) + 
complex_mul(y * conj, plus_j);
 
  491        seq_buf[length - index] =
 
  495      short pad_index = 
metal::min(length + (index - length_over_2) * 2, n - 2);
 
  496      seq_buf[pad_index] = 0;
 
  497      seq_buf[pad_index + 1] = 0;
 
 
  505    const device float2* w_k)
 const {
 
  506  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
 
  507  threadgroup float2* seq_buf = 
buf + elem.y * n + length - 1;
 
  509  int grid_index = elem.x * grid.y + elem.y;
 
  511      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
 
  514  short fft_idx = elem.z;
 
  516  float2 inv_factor = {1.0f / n, -1.0f / n};
 
  517  for (
int e = 0; e < elems_per_thread; e++) {
 
  518    int index = fft_idx + e * m;
 
  519    if (index < length) {
 
  520      float2 output = 
complex_mul(seq_buf[index] * inv_factor, w_k[index]);
 
  521      out[batch_idx + index] = output.x / length;
 
  522      out[batch_idx + index + next_out] = output.y / -length;
 
 
  537  bool default_inv = inv;
 
 
  548  int overall_n_over_2 = overall_n / 2 + 1;
 
  549  int coalesce_width = grid.y;
 
  550  int tg_idx = elem.y * grid.z + elem.z;
 
  551  int outer_batch_size = stride / coalesce_width;
 
  553  int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
 
  554      overall_n_over_2 * (elem.x / outer_batch_size);
 
  555  strided_device_idx = strided_batch_idx +
 
  556      tg_idx / coalesce_width * elems_per_thread / 2 * stride +
 
  557      tg_idx % coalesce_width;
 
  558  strided_shared_idx = (tg_idx % coalesce_width) * n +
 
  559      tg_idx / coalesce_width * elems_per_thread / 2;
 
  560  for (
int e = 0; e < elems_per_thread / 2; e++) {
 
  561    float2 output = 
buf[strided_shared_idx + e];
 
  562    out[strided_device_idx + e * stride] = output;
 
  566  if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
 
  567    out[strided_batch_idx + overall_n / 2] = 
buf[n / 2];
 
 
  577  int overall_n_over_2 = overall_n / 2 + 1;
 
  578  auto conj = float2(1, -1);
 
  580  compute_strided_indices(stride, overall_n);
 
  582  for (
int e = 0; e < elems_per_thread; e++) {
 
  583    int device_idx = strided_device_idx + e * stride;
 
  584    int overall_batch = device_idx / overall_n;
 
  585    int overall_index = device_idx % overall_n;
 
  586    if (overall_index < overall_n_over_2) {
 
  587      device_idx -= overall_batch * (overall_n - overall_n_over_2);
 
  588      buf[strided_shared_idx + e] = in[device_idx] * conj;
 
  590      int conj_idx = overall_n - overall_index;
 
  591      device_idx = overall_batch * overall_n_over_2 + conj_idx;
 
  592      buf[strided_shared_idx + e] = in[device_idx];
 
 
  605  bool default_inv = inv;
 
 
  616  compute_strided_indices(stride, overall_n);
 
  618  for (
int e = 0; e < elems_per_thread; e++) {
 
  619    out[strided_device_idx + e * stride] =
 
  620        pre_out(
buf[strided_shared_idx + e], overall_n).x;
 
 
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
 
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
 
Definition readwrite.h:35
 
METAL_FUNC bool out_of_bounds() const
Definition readwrite.h:94
 
METAL_FUNC void load() const
Definition readwrite.h:100
 
METAL_FUNC float2 pre_out(float2 elem, int length) const
Definition readwrite.h:90
 
METAL_FUNC ReadWriter(const device in_T *in_, threadgroup float2 *buf_, device out_T *out_, const short n_, const int batch_size_, const short elems_per_thread_, const uint3 elem_, const uint3 grid_, const bool inv_)
Definition readwrite.h:51
 
threadgroup float2 * buf
Definition readwrite.h:37
 
uint3 elem
Definition readwrite.h:42
 
int elems_per_thread
Definition readwrite.h:41
 
int strided_device_idx
Definition readwrite.h:48
 
int threads_per_tg
Definition readwrite.h:44
 
int n
Definition readwrite.h:39
 
int batch_size
Definition readwrite.h:40
 
METAL_FUNC float2 post_in(float elem) const
Definition readwrite.h:82
 
bool inv
Definition readwrite.h:45
 
METAL_FUNC void write_strided(int stride, int overall_n)
Definition readwrite.h:210
 
METAL_FUNC void compute_strided_indices(int stride, int overall_n)
Definition readwrite.h:180
 
METAL_FUNC float2 pre_out(float2 elem) const
Definition readwrite.h:86
 
METAL_FUNC void write_padded(int length, const device float2 *w_k) const
Definition readwrite.h:163
 
METAL_FUNC void load_strided(int stride, int overall_n)
Definition readwrite.h:202
 
METAL_FUNC float2 post_in(float2 elem) const
Definition readwrite.h:77
 
const device in_T * in
Definition readwrite.h:36
 
device out_T * out
Definition readwrite.h:38
 
METAL_FUNC void write() const
Definition readwrite.h:123
 
uint3 grid
Definition readwrite.h:43
 
int strided_shared_idx
Definition readwrite.h:49
 
METAL_FUNC void load_padded(int length, const device float2 *w_k) const
Definition readwrite.h:146