20#define MAX_OUTPUT_SIZE 18 
   50typedef void (*
RadixFunc)(thread float2*, thread float2*);
 
   53template <
int radix, RadixFunc radix_func>
 
   58    thread 
short* indices,
 
   66  constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
 
   68    constexpr short power = __builtin_ctz(radix);
 
   70    j = ((i - k) << power) + k;
 
   73    j = (i / p) * radix * p + k;
 
   79    float2 twiddle = twiddle_1;
 
   83    for (
int t = 2; t < radix; t++) {
 
   92  for (
int t = 0; t < radix; t++) {
 
   93    indices[t] = j + t * p;
 
 
   99template <
int radix, RadixFunc radix_func>
 
  106    thread float2* inputs,
 
  107    thread 
short* indices,
 
  108    thread float2* values,
 
  109    threadgroup float2* 
buf) {
 
  122  for (
int s = 0; s < num_steps; s++) {
 
  123    for (
int t = 0; t < max_radices_per_thread; t++) {
 
  126        for (
int r = 0; r < radix; r++) {
 
  127          inputs[r] = 
buf[index + r * m_r];
 
  130            index, *p, inputs, indices + t * radix, values + t * radix);
 
  135    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  137    for (
int t = 0; t < max_radices_per_thread; t++) {
 
  140        for (
int r = 0; r < radix; r++) {
 
  141          r_index = t * radix + r;
 
  142          buf[indices[r_index]] = values[r_index];
 
  148    threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
  153#define RADIX_STEP(radix, radix_func, num_steps) \ 
  154  radix_n_steps<radix, radix_func>(              \ 
  155      fft_idx, p, m, n, num_steps, inputs, indices, values, buf); 
 
  157template <
bool rader = false>
 
  179template <
int tg_mem_size, 
typename in_T, 
typename out_T>
 
  181    const device in_T* in [[buffer(0)]],
 
  182    device out_T* out [[buffer(1)]],
 
  183    constant 
const int& n,
 
  184    constant 
const int& batch_size,
 
  185    uint3 elem [[thread_position_in_grid]],
 
  186    uint3 grid [[threads_per_grid]]) {
 
  187  threadgroup float2 shared_in[tg_mem_size];
 
  200  if (read_writer.out_of_bounds()) {
 
  205  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  208  int fft_idx = elem.z; 
 
  210  int tg_idx = elem.y * n; 
 
  211  threadgroup float2* 
buf = &shared_in[tg_idx];
 
 
  218template <
int tg_mem_size, 
typename in_T, 
typename out_T>
 
  220    const device in_T* in [[buffer(0)]],
 
  221    device out_T* out [[buffer(1)]],
 
  222    const device float2* raders_b_q [[buffer(2)]],
 
  223    const device 
short* raders_g_q [[buffer(3)]],
 
  224    const device 
short* raders_g_minus_q [[buffer(4)]],
 
  225    constant 
const int& n,
 
  226    constant 
const int& batch_size,
 
  227    constant 
const int& rader_n,
 
  228    uint3 elem [[thread_position_in_grid]],
 
  229    uint3 grid [[threads_per_grid]]) {
 
  253  threadgroup float2 shared_in[tg_mem_size];
 
  266  if (read_writer.out_of_bounds()) {
 
  271  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  276  int fft_idx = elem.z;
 
  277  int tg_idx = elem.y * n;
 
  278  threadgroup float2* 
buf = &shared_in[tg_idx];
 
  290  float2 x_0[2] = {
buf[x_0_index], 
buf[x_0_index + 1]};
 
  294  int max_index = n - rader_m - 1;
 
  297    short g_q = raders_g_q[index / rader_m];
 
  298    temp[e] = 
buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
 
  301  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  305    buf[index + rader_m] = temp[e];
 
  308  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  316  int x_sum_index = 
metal::min(fft_idx, rader_m - 1);
 
  317  buf[x_sum_index] = 
buf[rader_m + x_sum_index * (rader_n - 1)];
 
  319  float2 inv = {1.0f, -1.0f};
 
  322    short interleaved_index =
 
  323        index / rader_m + (index % rader_m) * (rader_n - 1);
 
  325        buf[rader_m + interleaved_index],
 
  326        raders_b_q[interleaved_index % (rader_n - 1)]);
 
  329  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  333    buf[rader_m + index] = temp[e] * inv;
 
  336  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  342  float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
 
  346    short diff_index = index / (rader_n - 1) - x_0_index;
 
  347    temp[e] = 
buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
 
  351  float2 x_sum = 
buf[x_0_index] + x_0[0];
 
  353  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  357    short g_q_index = index % (rader_n - 1);
 
  358    short g_q = raders_g_minus_q[g_q_index];
 
  359    short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
 
  360    buf[out_index] = temp[e];
 
  363  buf[x_0_index * rader_n] = x_sum;
 
  365  threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
  373template <
int tg_mem_size, 
typename in_T, 
typename out_T>
 
  375    const device in_T* in [[buffer(0)]],
 
  376    device out_T* out [[buffer(1)]],
 
  377    const device float2* w_q [[buffer(2)]],
 
  378    const device float2* w_k [[buffer(3)]],
 
  379    constant 
const int& length,
 
  380    constant 
const int& n,
 
  381    constant 
const int& batch_size,
 
  382    uint3 elem [[thread_position_in_grid]],
 
  383    uint3 grid [[threads_per_grid]]) {
 
  393  threadgroup float2 shared_in[tg_mem_size];
 
  406  if (read_writer.out_of_bounds()) {
 
  409  read_writer.load_padded(length, w_k);
 
  411  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  414  int fft_idx = elem.z; 
 
  416  int tg_idx = elem.y * n; 
 
  417  threadgroup float2* 
buf = &shared_in[tg_idx];
 
  422  float2 inv = float2(1.0f, -1.0f);
 
  424    int index = fft_idx + t * m;
 
  428  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  434  read_writer.write_padded(length, w_k);
 
 
  444    const device in_T* in [[buffer(0)]],
 
  445    device out_T* out [[buffer(1)]],
 
  446    constant 
const int& n1,
 
  447    constant 
const int& n2,
 
  448    constant 
const int& batch_size,
 
  449    uint3 elem [[thread_position_in_grid]],
 
  450    uint3 grid [[threads_per_grid]]) {
 
  452  int overall_n = n1 * n2;
 
  453  int n = step == 0 ? n1 : n2;
 
  454  int stride = step == 0 ? n2 : n1;
 
  458  int fft_idx = elem.z;
 
  460  threadgroup float2 shared_in[tg_mem_size];
 
  461  threadgroup float2* 
buf = &shared_in[elem.y * n];
 
  464  read_writer_t read_writer = read_writer_t(
 
  475  if (read_writer.out_of_bounds()) {
 
  478  read_writer.load_strided(stride, overall_n);
 
  480  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  485  read_writer.write_strided(stride, overall_n);
 
 
array real(const array &a, StreamOrDevice s={})
 
METAL_FUNC void radix5(thread float2 *x, thread float2 *y)
Definition radix.h:69
 
METAL_FUNC void radix4(thread float2 *x, thread float2 *y)
Definition radix.h:56
 
METAL_FUNC void radix11(thread float2 *x, thread float2 *y)
Definition radix.h:201
 
METAL_FUNC void radix3(thread float2 *x, thread float2 *y)
Definition radix.h:41
 
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
 
METAL_FUNC void radix8(thread float2 *x, thread float2 *y)
Definition radix.h:151
 
METAL_FUNC void radix7(thread float2 *x, thread float2 *y)
Definition radix.h:122
 
METAL_FUNC void radix2(thread float2 *x, thread float2 *y)
Definition radix.h:36
 
METAL_FUNC void radix13(thread float2 *x, thread float2 *y)
Definition radix.h:290
 
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
 
METAL_FUNC void radix6(thread float2 *x, thread float2 *y)
Definition radix.h:96
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
Definition readwrite.h:35