MLX
|
#include <readwrite.h>
Public Member Functions | |
METAL_FUNC | ReadWriter (const device in_T *in_, threadgroup float2 *buf_, device out_T *out_, const short n_, const int batch_size_, const short elems_per_thread_, const uint3 elem_, const uint3 grid_, const bool inv_) |
METAL_FUNC float2 | post_in (float2 elem) const |
METAL_FUNC float2 | post_in (float elem) const |
METAL_FUNC float2 | pre_out (float2 elem) const |
METAL_FUNC float2 | pre_out (float2 elem, int length) const |
METAL_FUNC bool | out_of_bounds () const |
METAL_FUNC void | load () const |
METAL_FUNC void | write () const |
METAL_FUNC void | load_padded (int length, const device float2 *w_k) const |
METAL_FUNC void | write_padded (int length, const device float2 *w_k) const |
METAL_FUNC void | compute_strided_indices (int stride, int overall_n) |
METAL_FUNC void | load_strided (int stride, int overall_n) |
METAL_FUNC void | write_strided (int stride, int overall_n) |
METAL_FUNC void | load_strided (int stride, int overall_n) |
METAL_FUNC void | write_strided (int stride, int overall_n) |
METAL_FUNC bool | out_of_bounds () const |
METAL_FUNC void | load () const |
METAL_FUNC void | write () const |
METAL_FUNC void | load_padded (int length, const device float2 *w_k) const |
METAL_FUNC void | write_padded (int length, const device float2 *w_k) const |
METAL_FUNC bool | out_of_bounds () const |
METAL_FUNC void | load () const |
METAL_FUNC void | write () const |
METAL_FUNC void | load_padded (int length, const device float2 *w_k) const |
METAL_FUNC void | write_padded (int length, const device float2 *w_k) const |
METAL_FUNC void | load_strided (int stride, int overall_n) |
METAL_FUNC void | write_strided (int stride, int overall_n) |
METAL_FUNC void | load_strided (int stride, int overall_n) |
METAL_FUNC void | load_strided (int stride, int overall_n) |
METAL_FUNC void | write_strided (int stride, int overall_n) |
Public Attributes | |
const device in_T * | in |
threadgroup float2 * | buf |
device out_T * | out |
int | n |
int | batch_size |
int | elems_per_thread |
uint3 | elem |
uint3 | grid |
int | threads_per_tg |
bool | inv |
int | strided_device_idx = 0 |
int | strided_shared_idx = 0 |
|
inline |
|
inline |
|
inline |
METAL_FUNC void ReadWriter<float, float2>::load() const
METAL_FUNC void ReadWriter<float2, float>::load() const
|
inline |
METAL_FUNC void ReadWriter<float, float2>::load_padded(int length, const device float2* w_k) const
METAL_FUNC void ReadWriter<float2, float>::load_padded(int length, const device float2* w_k) const
|
inline |
METAL_FUNC void ReadWriter<float2, float2, 1>::load_strided(int stride, int overall_n)
METAL_FUNC void ReadWriter<float2, float2, 1, true>::load_strided(int stride, int overall_n)
METAL_FUNC void ReadWriter<float2, float2, 0, true>::load_strided(int stride, int overall_n)
METAL_FUNC void ReadWriter<float2, float, 1, true>::load_strided(int stride, int overall_n)
|
inline |
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const
|
inline |
|
inline |
|
inline |
|
inline |
|
inline |
METAL_FUNC void ReadWriter<float, float2>::write() const
METAL_FUNC void ReadWriter<float2, float>::write() const
|
inline |
METAL_FUNC void ReadWriter<float, float2>::write_padded(int length, const device float2* w_k) const
METAL_FUNC void ReadWriter<float2, float>::write_padded(int length, const device float2* w_k) const
|
inline |
METAL_FUNC void ReadWriter<float2, float2, 1>::write_strided(int stride, int overall_n)
METAL_FUNC void ReadWriter<float2, float2, 1, true>::write_strided(int stride, int overall_n)
METAL_FUNC void ReadWriter<float2, float, 1, true>::write_strided(int stride, int overall_n)
int ReadWriter<in_T, out_T, step, four_step_real>::batch_size
threadgroup float2* ReadWriter<in_T, out_T, step, four_step_real>::buf
uint3 ReadWriter<in_T, out_T, step, four_step_real>::elem
int ReadWriter<in_T, out_T, step, four_step_real>::elems_per_thread
uint3 ReadWriter<in_T, out_T, step, four_step_real>::grid
const device in_T* ReadWriter<in_T, out_T, step, four_step_real>::in
bool ReadWriter<in_T, out_T, step, four_step_real>::inv
int ReadWriter<in_T, out_T, step, four_step_real>::n
device out_T* ReadWriter<in_T, out_T, step, four_step_real>::out
int ReadWriter<in_T, out_T, step, four_step_real>::strided_device_idx = 0
int ReadWriter<in_T, out_T, step, four_step_real>::strided_shared_idx = 0
int ReadWriter<in_T, out_T, step, four_step_real>::threads_per_tg