#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"

Go to the source code of this file.
Macros | |
| #define | MAX_RADIX 13 | 
| #define | MAX_OUTPUT_SIZE 18 | 
| #define | RADIX_STEP(radix, radix_func, num_steps) | 
Typedefs | |
| typedef void(* | RadixFunc) (thread float2 *, thread float2 *) | 
Functions | |
| template<int radix, RadixFunc radix_func> | |
| METAL_FUNC void | radix_butterfly (int i, int p, thread float2 *x, thread short *indices, thread float2 *y) | 
| template<int radix, RadixFunc radix_func> | |
| METAL_FUNC void | radix_n_steps (int i, thread int *p, int m, int n, int num_steps, thread float2 *inputs, thread short *indices, thread float2 *values, threadgroup float2 *buf) | 
| template<bool rader = false> | |
| METAL_FUNC void | perform_fft (int fft_idx, thread int *p, int m, int n, threadgroup float2 *buf) | 
| template<int tg_mem_size, typename in_T, typename out_T> | |
| void | fft (const device in_T *in, device out_T *out, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid) | 
| template<int tg_mem_size, typename in_T, typename out_T> | |
| void | rader_fft (const device in_T *in, device out_T *out, const device float2 *raders_b_q, const device short *raders_g_q, const device short *raders_g_minus_q, constant const int &n, constant const int &batch_size, constant const int &rader_n, uint3 elem, uint3 grid) | 
| template<int tg_mem_size, typename in_T, typename out_T> | |
| void | bluestein_fft (const device in_T *in, device out_T *out, const device float2 *w_q, const device float2 *w_k, constant const int &length, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid) | 
| template<int tg_mem_size, typename in_T, typename out_T, int step, bool real = false> | |
| void | four_step_fft (const device in_T *in, device out_T *out, constant const int &n1, constant const int &n2, constant const int &batch_size, uint3 elem, uint3 grid) | 
Variables | |
| static constant constexpr const bool | inv_ | 
| static constant constexpr const bool | is_power_of_2_ | 
| static constant constexpr const int | elems_per_thread_ | 
| static constant constexpr const int | rader_m_ | 
| static constant constexpr const int | radix_13_steps_ | 
| static constant constexpr const int | radix_11_steps_ | 
| static constant constexpr const int | radix_8_steps_ | 
| static constant constexpr const int | radix_7_steps_ | 
| static constant constexpr const int | radix_6_steps_ | 
| static constant constexpr const int | radix_5_steps_ | 
| static constant constexpr const int | radix_4_steps_ | 
| static constant constexpr const int | radix_3_steps_ | 
| static constant constexpr const int | radix_2_steps_ | 
| static constant constexpr const int | rader_13_steps_ | 
| static constant constexpr const int | rader_11_steps_ | 
| static constant constexpr const int | rader_8_steps_ | 
| static constant constexpr const int | rader_7_steps_ | 
| static constant constexpr const int | rader_6_steps_ | 
| static constant constexpr const int | rader_5_steps_ | 
| static constant constexpr const int | rader_4_steps_ | 
| static constant constexpr const int | rader_3_steps_ | 
| static constant constexpr const int | rader_2_steps_ | 
| #define MAX_OUTPUT_SIZE 18 | 
| #define MAX_RADIX 13 | 
| #define RADIX_STEP | ( | radix, | |
| radix_func, | |||
| num_steps ) | 
| typedef void(* RadixFunc) (thread float2 *, thread float2 *) | 
| void bluestein_fft | ( | const device in_T * | in, | 
| device out_T * | out, | ||
| const device float2 * | w_q, | ||
| const device float2 * | w_k, | ||
| constant const int & | length, | ||
| constant const int & | n, | ||
| constant const int & | batch_size, | ||
| uint3 | elem, | ||
| uint3 | grid ) | 
| void fft | ( | const device in_T * | in, | 
| device out_T * | out, | ||
| constant const int & | n, | ||
| constant const int & | batch_size, | ||
| uint3 | elem, | ||
| uint3 | grid ) | 
| void four_step_fft | ( | const device in_T * | in, | 
| device out_T * | out, | ||
| constant const int & | n1, | ||
| constant const int & | n2, | ||
| constant const int & | batch_size, | ||
| uint3 | elem, | ||
| uint3 | grid ) | 
| METAL_FUNC void perform_fft | ( | int | fft_idx, | 
| thread int * | p, | ||
| int | m, | ||
| int | n, | ||
| threadgroup float2 * | buf ) | 
| void rader_fft | ( | const device in_T * | in, | 
| device out_T * | out, | ||
| const device float2 * | raders_b_q, | ||
| const device short * | raders_g_q, | ||
| const device short * | raders_g_minus_q, | ||
| constant const int & | n, | ||
| constant const int & | batch_size, | ||
| constant const int & | rader_n, | ||
| uint3 | elem, | ||
| uint3 | grid ) | 
| METAL_FUNC void radix_butterfly | ( | int | i, | 
| int | p, | ||
| thread float2 * | x, | ||
| thread short * | indices, | ||
| thread float2 * | y ) | 
| METAL_FUNC void radix_n_steps | ( | int | i, | 
| thread int * | p, | ||
| int | m, | ||
| int | n, | ||
| int | num_steps, | ||
| thread float2 * | inputs, | ||
| thread short * | indices, | ||
| thread float2 * | values, | ||
| threadgroup float2 * | buf ) | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr | 
      
  | 
  static constexpr |