MLX
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
Go to the source code of this file.
Macros
#define MAX_RADIX 13
#define MAX_OUTPUT_SIZE 18
#define RADIX_STEP(radix, radix_func, num_steps)
Typedefs
typedef void(*RadixFunc)(thread float2 *, thread float2 *)
Functions
template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly(int i, int p, thread float2 *x, thread short *indices, thread float2 *y)
template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps(int i, thread int *p, int m, int n, int num_steps, thread float2 *inputs, thread short *indices, thread float2 *values, threadgroup float2 *buf)
template<bool rader = false>
METAL_FUNC void perform_fft(int fft_idx, thread int *p, int m, int n, threadgroup float2 *buf)
template<int tg_mem_size, typename in_T, typename out_T>
void fft(const device in_T *in, device out_T *out, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
template<int tg_mem_size, typename in_T, typename out_T>
void rader_fft(const device in_T *in, device out_T *out, const device float2 *raders_b_q, const device short *raders_g_q, const device short *raders_g_minus_q, constant const int &n, constant const int &batch_size, constant const int &rader_n, uint3 elem, uint3 grid)
template<int tg_mem_size, typename in_T, typename out_T>
void bluestein_fft(const device in_T *in, device out_T *out, const device float2 *w_q, const device float2 *w_k, constant const int &length, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
template<int tg_mem_size, typename in_T, typename out_T, int step, bool real = false>
void four_step_fft(const device in_T *in, device out_T *out, constant const int &n1, constant const int &n2, constant const int &batch_size, uint3 elem, uint3 grid)
Variables
STEEL_CONST bool inv_
STEEL_CONST bool is_power_of_2_
STEEL_CONST int elems_per_thread_
STEEL_CONST int rader_m_
STEEL_CONST int radix_13_steps_
STEEL_CONST int radix_11_steps_
STEEL_CONST int radix_8_steps_
STEEL_CONST int radix_7_steps_
STEEL_CONST int radix_6_steps_
STEEL_CONST int radix_5_steps_
STEEL_CONST int radix_4_steps_
STEEL_CONST int radix_3_steps_
STEEL_CONST int radix_2_steps_
STEEL_CONST int rader_13_steps_
STEEL_CONST int rader_11_steps_
STEEL_CONST int rader_8_steps_
STEEL_CONST int rader_7_steps_
STEEL_CONST int rader_6_steps_
STEEL_CONST int rader_5_steps_
STEEL_CONST int rader_4_steps_
STEEL_CONST int rader_3_steps_
STEEL_CONST int rader_2_steps_
#define MAX_OUTPUT_SIZE 18
#define MAX_RADIX 13
#define RADIX_STEP(radix, radix_func, num_steps)
typedef void(*RadixFunc)(thread float2 *, thread float2 *)
template<int tg_mem_size, typename in_T, typename out_T>
void bluestein_fft(
    const device in_T *in,
    device out_T *out,
    const device float2 *w_q,
    const device float2 *w_k,
    constant const int &length,
    constant const int &n,
    constant const int &batch_size,
    uint3 elem,
    uint3 grid)
template<int tg_mem_size, typename in_T, typename out_T>
void fft(
    const device in_T *in,
    device out_T *out,
    constant const int &n,
    constant const int &batch_size,
    uint3 elem,
    uint3 grid)
template<int tg_mem_size, typename in_T, typename out_T, int step, bool real = false>
void four_step_fft(
    const device in_T *in,
    device out_T *out,
    constant const int &n1,
    constant const int &n2,
    constant const int &batch_size,
    uint3 elem,
    uint3 grid)
template<bool rader = false>
METAL_FUNC void perform_fft(
    int fft_idx,
    thread int *p,
    int m,
    int n,
    threadgroup float2 *buf)
template<int tg_mem_size, typename in_T, typename out_T>
void rader_fft(
    const device in_T *in,
    device out_T *out,
    const device float2 *raders_b_q,
    const device short *raders_g_q,
    const device short *raders_g_minus_q,
    constant const int &n,
    constant const int &batch_size,
    constant const int &rader_n,
    uint3 elem,
    uint3 grid)
template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly(
    int i,
    int p,
    thread float2 *x,
    thread short *indices,
    thread float2 *y)
template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps(
    int i,
    thread int *p,
    int m,
    int n,
    int num_steps,
    thread float2 *inputs,
    thread short *indices,
    thread float2 *values,
    threadgroup float2 *buf)
STEEL_CONST int elems_per_thread_ |
STEEL_CONST bool inv_ |
STEEL_CONST bool is_power_of_2_ |
STEEL_CONST int rader_11_steps_ |
STEEL_CONST int rader_13_steps_ |
STEEL_CONST int rader_2_steps_ |
STEEL_CONST int rader_3_steps_ |
STEEL_CONST int rader_4_steps_ |
STEEL_CONST int rader_5_steps_ |
STEEL_CONST int rader_6_steps_ |
STEEL_CONST int rader_7_steps_ |
STEEL_CONST int rader_8_steps_ |
STEEL_CONST int rader_m_ |
STEEL_CONST int radix_11_steps_ |
STEEL_CONST int radix_13_steps_ |
STEEL_CONST int radix_2_steps_ |
STEEL_CONST int radix_3_steps_ |
STEEL_CONST int radix_4_steps_ |
STEEL_CONST int radix_5_steps_ |
STEEL_CONST int radix_6_steps_ |
STEEL_CONST int radix_7_steps_ |
STEEL_CONST int radix_8_steps_ |