MLX
Public Member Functions | Public Attributes | List of all members
ReadWriter< in_T, out_T, step, four_step_real > Struct Template Reference

#include <readwrite.h>

Public Member Functions

METAL_FUNC ReadWriter (const device in_T *in_, threadgroup float2 *buf_, device out_T *out_, const short n_, const int batch_size_, const short elems_per_thread_, const uint3 elem_, const uint3 grid_, const bool inv_)
 
METAL_FUNC float2 post_in (float2 elem) const
 
METAL_FUNC float2 post_in (float elem) const
 
METAL_FUNC float2 pre_out (float2 elem) const
 
METAL_FUNC float2 pre_out (float2 elem, int length) const
 
METAL_FUNC bool out_of_bounds () const
 
METAL_FUNC void load () const
 
METAL_FUNC void write () const
 
METAL_FUNC void load_padded (int length, const device float2 *w_k) const
 
METAL_FUNC void write_padded (int length, const device float2 *w_k) const
 
METAL_FUNC void compute_strided_indices (int stride, int overall_n)
 
METAL_FUNC void load_strided (int stride, int overall_n)
 
METAL_FUNC void write_strided (int stride, int overall_n)
 
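(The remaining entries repeat member names because they are declared by explicit specializations of ReadWriter — ReadWriter< float, float2 >, ReadWriter< float2, float >, ReadWriter< float2, float2, 1 >, ReadWriter< float2, float2, 1, true >, ReadWriter< float2, float2, 0, true > and ReadWriter< float2, float, 1, true >. See "Member Function Documentation" below for the specialization each overload belongs to.)

METAL_FUNC void load_strided (int stride, int overall_n)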
 
METAL_FUNC void write_strided (int stride, int overall_n)
 
METAL_FUNC bool out_of_bounds () const
 
METAL_FUNC void load () const
 
METAL_FUNC void write () const
 
METAL_FUNC void load_padded (int length, const device float2 *w_k) const
 
METAL_FUNC void write_padded (int length, const device float2 *w_k) const
 
METAL_FUNC bool out_of_bounds () const
 
METAL_FUNC void load () const
 
METAL_FUNC void write () const
 
METAL_FUNC void load_padded (int length, const device float2 *w_k) const
 
METAL_FUNC void write_padded (int length, const device float2 *w_k) const
 
METAL_FUNC void load_strided (int stride, int overall_n)
 
METAL_FUNC void write_strided (int stride, int overall_n)
 
METAL_FUNC void load_strided (int stride, int overall_n)
 
METAL_FUNC void load_strided (int stride, int overall_n)
 
METAL_FUNC void write_strided (int stride, int overall_n)
 

Public Attributes

const device in_T * in
 
threadgroup float2 * buf
 
device out_T * out
 
int n
 
int batch_size
 
int elems_per_thread
 
uint3 elem
 
uint3 grid
 
int threads_per_tg
 
bool inv
 
int strided_device_idx = 0
 
int strided_shared_idx = 0
 

Constructor & Destructor Documentation

◆ ReadWriter()

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC ReadWriter< in_T, out_T, step, four_step_real >::ReadWriter ( const device in_T * in_,
threadgroup float2 * buf_,
device out_T * out_,
const short n_,
const int batch_size_,
const short elems_per_thread_,
const uint3 elem_,
const uint3 grid_,
const bool inv_ )
inline

Member Function Documentation

◆ compute_strided_indices()

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::compute_strided_indices ( int stride,
int overall_n )
inline

◆ load() [1/3]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::load ( ) const
inline

◆ load() [2/3]

METAL_FUNC void ReadWriter< float, float2 >::load ( ) const

◆ load() [3/3]

METAL_FUNC void ReadWriter< float2, float >::load ( ) const

◆ load_padded() [1/3]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::load_padded ( int length,
const device float2 * w_k ) const
inline

◆ load_padded() [2/3]

METAL_FUNC void ReadWriter< float, float2 >::load_padded ( int length,
const device float2 * w_k ) const

◆ load_padded() [3/3]

METAL_FUNC void ReadWriter< float2, float >::load_padded ( int length,
const device float2 * w_k ) const

◆ load_strided() [1/5]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::load_strided ( int stride,
int overall_n )
inline

◆ load_strided() [2/5]

METAL_FUNC void ReadWriter< float2, float2, 1 >::load_strided ( int stride,
int overall_n )

◆ load_strided() [3/5]

METAL_FUNC void ReadWriter< float2, float2, 1, true >::load_strided ( int stride,
int overall_n )

◆ load_strided() [4/5]

METAL_FUNC void ReadWriter< float2, float2, 0, true >::load_strided ( int stride,
int overall_n )

◆ load_strided() [5/5]

METAL_FUNC void ReadWriter< float2, float, 1, true >::load_strided ( int stride,
int overall_n )

◆ out_of_bounds() [1/3]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC bool ReadWriter< in_T, out_T, step, four_step_real >::out_of_bounds ( ) const
inline

◆ out_of_bounds() [2/3]

METAL_FUNC bool ReadWriter< float, float2 >::out_of_bounds ( ) const

◆ out_of_bounds() [3/3]

METAL_FUNC bool ReadWriter< float2, float >::out_of_bounds ( ) const

◆ post_in() [1/2]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC float2 ReadWriter< in_T, out_T, step, four_step_real >::post_in ( float elem) const
inline

◆ post_in() [2/2]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC float2 ReadWriter< in_T, out_T, step, four_step_real >::post_in ( float2 elem) const
inline

◆ pre_out() [1/2]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC float2 ReadWriter< in_T, out_T, step, four_step_real >::pre_out ( float2 elem) const
inline

◆ pre_out() [2/2]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC float2 ReadWriter< in_T, out_T, step, four_step_real >::pre_out ( float2 elem,
int length ) const
inline

◆ write() [1/3]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::write ( ) const
inline

◆ write() [2/3]

METAL_FUNC void ReadWriter< float, float2 >::write ( ) const

◆ write() [3/3]

METAL_FUNC void ReadWriter< float2, float >::write ( ) const

◆ write_padded() [1/3]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::write_padded ( int length,
const device float2 * w_k ) const
inline

◆ write_padded() [2/3]

METAL_FUNC void ReadWriter< float, float2 >::write_padded ( int length,
const device float2 * w_k ) const

◆ write_padded() [3/3]

METAL_FUNC void ReadWriter< float2, float >::write_padded ( int length,
const device float2 * w_k ) const

◆ write_strided() [1/4]

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
METAL_FUNC void ReadWriter< in_T, out_T, step, four_step_real >::write_strided ( int stride,
int overall_n )
inline

◆ write_strided() [2/4]

METAL_FUNC void ReadWriter< float2, float2, 1 >::write_strided ( int stride,
int overall_n )

◆ write_strided() [3/4]

METAL_FUNC void ReadWriter< float2, float2, 1, true >::write_strided ( int stride,
int overall_n )

◆ write_strided() [4/4]

METAL_FUNC void ReadWriter< float2, float, 1, true >::write_strided ( int stride,
int overall_n )

Member Data Documentation

◆ batch_size

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::batch_size

◆ buf

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
threadgroup float2* ReadWriter< in_T, out_T, step, four_step_real >::buf

◆ elem

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
uint3 ReadWriter< in_T, out_T, step, four_step_real >::elem

◆ elems_per_thread

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::elems_per_thread

◆ grid

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
uint3 ReadWriter< in_T, out_T, step, four_step_real >::grid

◆ in

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
const device in_T* ReadWriter< in_T, out_T, step, four_step_real >::in

◆ inv

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
bool ReadWriter< in_T, out_T, step, four_step_real >::inv

◆ n

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::n

◆ out

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
device out_T* ReadWriter< in_T, out_T, step, four_step_real >::out

◆ strided_device_idx

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::strided_device_idx = 0

◆ strided_shared_idx

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::strided_shared_idx = 0

◆ threads_per_tg

template<typename in_T , typename out_T , int step = 0, bool four_step_real = false>
int ReadWriter< in_T, out_T, step, four_step_real >::threads_per_tg

The documentation for this struct was generated from the following file: