MLX
|
Go to the source code of this file.
Macros
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
Functions
template<typename T>
METAL_FUNC void thread_swap(thread T &a, thread T &b)
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *in_nc_strides, const constant size_t *out_nc_strides, uint3 tid, uint3 lid)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *nc_strides, uint3 tid, uint3 lid)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition(device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<val_t>>
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Variables
constant constexpr const int zero_helper = 0
#define MLX_MTL_CONST static constant constexpr const |
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") |
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort(const device T *inp,
                device U *out,
                const constant int &size_sorted_axis,
                const constant int &in_stride_sorted_axis,
                const constant int &out_stride_sorted_axis,
                const constant int &in_stride_segment_axis,
                const constant int &out_stride_segment_axis,
                uint3 tid,
                uint3 lid)
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc(const device T *inp,
                   device U *out,
                   const constant int &size_sorted_axis,
                   const constant int &in_stride_sorted_axis,
                   const constant int &out_stride_sorted_axis,
                   const constant int &nc_dim,
                   const constant int *nc_shape,
                   const constant size_t *in_nc_strides,
                   const constant size_t *out_nc_strides,
                   uint3 tid,
                   uint3 lid)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<val_t>>
void mb_block_merge(const device idx_t *block_partitions,
                    const device val_t *dev_vals_in,
                    const device idx_t *dev_idxs_in,
                    device val_t *dev_vals_out,
                    device idx_t *dev_idxs_out,
                    const constant int &size_sorted_axis,
                    const constant int &merge_tiles,
                    const constant int &num_tiles,
                    uint3 tid,
                    uint3 lid)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition(device idx_t *block_partitions,
                        const device val_t *dev_vals,
                        const device idx_t *dev_idxs,
                        const constant int &size_sorted_axis,
                        const constant int &merge_tiles,
                        const constant int &n_blocks,
                        uint3 tid,
                        uint3 lid,
                        uint3 tgp_dims)
template<typename val_t, typename idx_t, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort(const device val_t *inp,
                   device val_t *out_vals,
                   device idx_t *out_idxs,
                   const constant int &size_sorted_axis,
                   const constant int &stride_sorted_axis,
                   const constant int &nc_dim,
                   const constant int *nc_shape,
                   const constant size_t *nc_strides,
                   uint3 tid,
                   uint3 lid)
template<typename T>
METAL_FUNC void thread_swap(thread T &a,
                            thread T &b)
|
constexpr |