MLX
Macros | Typedefs | Functions | Variables
fft.h File Reference
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"

Go to the source code of this file.

Macros

#define MAX_RADIX   13
 
#define MAX_OUTPUT_SIZE   18
 
#define RADIX_STEP(radix, radix_func, num_steps)
 

Typedefs

typedef void(* RadixFunc) (thread float2 *, thread float2 *)
 

Functions

template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly (int i, int p, thread float2 *x, thread short *indices, thread float2 *y)
 
template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps (int i, thread int *p, int m, int n, int num_steps, thread float2 *inputs, thread short *indices, thread float2 *values, threadgroup float2 *buf)
 
template<bool rader = false>
METAL_FUNC void perform_fft (int fft_idx, thread int *p, int m, int n, threadgroup float2 *buf)
 
template<int tg_mem_size, typename in_T , typename out_T >
void fft (const device in_T *in, device out_T *out, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
 
template<int tg_mem_size, typename in_T , typename out_T >
void rader_fft (const device in_T *in, device out_T *out, const device float2 *raders_b_q, const device short *raders_g_q, const device short *raders_g_minus_q, constant const int &n, constant const int &batch_size, constant const int &rader_n, uint3 elem, uint3 grid)
 
template<int tg_mem_size, typename in_T , typename out_T >
void bluestein_fft (const device in_T *in, device out_T *out, const device float2 *w_q, const device float2 *w_k, constant const int &length, constant const int &n, constant const int &batch_size, uint3 elem, uint3 grid)
 
template<int tg_mem_size, typename in_T , typename out_T , int step, bool real = false>
void four_step_fft (const device in_T *in, device out_T *out, constant const int &n1, constant const int &n2, constant const int &batch_size, uint3 elem, uint3 grid)
 

Variables

STEEL_CONST bool inv_
 
STEEL_CONST bool is_power_of_2_
 
STEEL_CONST int elems_per_thread_
 
STEEL_CONST int rader_m_
 
STEEL_CONST int radix_13_steps_
 
STEEL_CONST int radix_11_steps_
 
STEEL_CONST int radix_8_steps_
 
STEEL_CONST int radix_7_steps_
 
STEEL_CONST int radix_6_steps_
 
STEEL_CONST int radix_5_steps_
 
STEEL_CONST int radix_4_steps_
 
STEEL_CONST int radix_3_steps_
 
STEEL_CONST int radix_2_steps_
 
STEEL_CONST int rader_13_steps_
 
STEEL_CONST int rader_11_steps_
 
STEEL_CONST int rader_8_steps_
 
STEEL_CONST int rader_7_steps_
 
STEEL_CONST int rader_6_steps_
 
STEEL_CONST int rader_5_steps_
 
STEEL_CONST int rader_4_steps_
 
STEEL_CONST int rader_3_steps_
 
STEEL_CONST int rader_2_steps_
 

Macro Definition Documentation

◆ MAX_OUTPUT_SIZE

#define MAX_OUTPUT_SIZE   18

◆ MAX_RADIX

#define MAX_RADIX   13

◆ RADIX_STEP

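#define RADIX_STEP(radix, radix_func, num_steps) \
  radix_n_steps<radix, radix_func>(              \
      fft_idx, p, m, n, num_steps, inputs, indices, values, buf);

Only radix, radix_func and num_steps are macro parameters. The other names in the expansion (fft_idx, p, m, n, inputs, indices, values, buf) are not macro parameters, so they must already be in scope wherever RADIX_STEP is expanded.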
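Note: the documentation generator cross-linked buf in this expansion to the unrelated MTL::Buffer * buf declared at allocator.h:38. In the macro, buf is whatever buf is in scope at the expansion site. That is presumably the threadgroup float2 * buf parameter of perform_fft, whose parameter list also supplies fft_idx, p, m and n. The name must match the threadgroup float2 * buf parameter of radix_n_steps.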

Typedef Documentation

◆ RadixFunc

typedef void(* RadixFunc) (thread float2 *, thread float2 *)

Function Documentation

◆ bluestein_fft()

template<int tg_mem_size, typename in_T , typename out_T >
void bluestein_fft ( const device in_T * in,
device out_T * out,
const device float2 * w_q,
const device float2 * w_k,
constant const int & length,
constant const int & n,
constant const int & batch_size,
uint3 elem,
uint3 grid )

◆ fft()

template<int tg_mem_size, typename in_T , typename out_T >
void fft ( const device in_T * in,
device out_T * out,
constant const int & n,
constant const int & batch_size,
uint3 elem,
uint3 grid )

◆ four_step_fft()

template<int tg_mem_size, typename in_T , typename out_T , int step, bool real = false>
void four_step_fft ( const device in_T * in,
device out_T * out,
constant const int & n1,
constant const int & n2,
constant const int & batch_size,
uint3 elem,
uint3 grid )

◆ perform_fft()

template<bool rader = false>
METAL_FUNC void perform_fft ( int fft_idx,
thread int * p,
int m,
int n,
threadgroup float2 * buf )

◆ rader_fft()

template<int tg_mem_size, typename in_T , typename out_T >
void rader_fft ( const device in_T * in,
device out_T * out,
const device float2 * raders_b_q,
const device short * raders_g_q,
const device short * raders_g_minus_q,
constant const int & n,
constant const int & batch_size,
constant const int & rader_n,
uint3 elem,
uint3 grid )

◆ radix_butterfly()

template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly ( int i,
int p,
thread float2 * x,
thread short * indices,
thread float2 * y )

◆ radix_n_steps()

template<int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps ( int i,
thread int * p,
int m,
int n,
int num_steps,
thread float2 * inputs,
thread short * indices,
thread float2 * values,
threadgroup float2 * buf )

Variable Documentation

◆ elems_per_thread_

STEEL_CONST int elems_per_thread_

◆ inv_

STEEL_CONST bool inv_

◆ is_power_of_2_

STEEL_CONST bool is_power_of_2_

◆ rader_11_steps_

STEEL_CONST int rader_11_steps_

◆ rader_13_steps_

STEEL_CONST int rader_13_steps_

◆ rader_2_steps_

STEEL_CONST int rader_2_steps_

◆ rader_3_steps_

STEEL_CONST int rader_3_steps_

◆ rader_4_steps_

STEEL_CONST int rader_4_steps_

◆ rader_5_steps_

STEEL_CONST int rader_5_steps_

◆ rader_6_steps_

STEEL_CONST int rader_6_steps_

◆ rader_7_steps_

STEEL_CONST int rader_7_steps_

◆ rader_8_steps_

STEEL_CONST int rader_8_steps_

◆ rader_m_

STEEL_CONST int rader_m_

◆ radix_11_steps_

STEEL_CONST int radix_11_steps_

◆ radix_13_steps_

STEEL_CONST int radix_13_steps_

◆ radix_2_steps_

STEEL_CONST int radix_2_steps_

◆ radix_3_steps_

STEEL_CONST int radix_3_steps_

◆ radix_4_steps_

STEEL_CONST int radix_4_steps_

◆ radix_5_steps_

STEEL_CONST int radix_5_steps_

◆ radix_6_steps_

STEEL_CONST int radix_6_steps_

◆ radix_7_steps_

STEEL_CONST int radix_7_steps_

◆ radix_8_steps_

STEEL_CONST int radix_8_steps_