MLX

sort.h File Reference

Go to the source code of this file.

Classes

struct  LessThan< T >
 
struct  ThreadSort< ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp >
 
struct  BlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 
struct  KernelMergeSort< T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 
struct  KernelMultiBlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 

Macros

#define MLX_MTL_CONST   static constant constexpr const
 
#define MLX_MTL_LOOP_UNROLL   _Pragma("clang loop unroll(full)")
 

Functions

template<typename T>
METAL_FUNC void thread_swap (thread T &a, thread T &b)
 
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort (const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
 
template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc (const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *in_nc_strides, const constant int64_t *out_nc_strides, uint3 tid, uint3 lid)
 
template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort (const device ValT *inp, device ValT *out_vals, device IdxT *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *nc_strides, uint3 tid, uint3 lid)
 
template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition (device IdxT *block_partitions, const device ValT *dev_vals, const device IdxT *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
 
template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<ValT>>
void mb_block_merge (const device IdxT *block_partitions, const device ValT *dev_vals_in, const device IdxT *dev_idxs_in, device ValT *dev_vals_out, device IdxT *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
 

Variables

constant constexpr const int zero_helper = 0
 

Macro Definition Documentation

◆ MLX_MTL_CONST

#define MLX_MTL_CONST   static constant constexpr const

◆ MLX_MTL_LOOP_UNROLL

#define MLX_MTL_LOOP_UNROLL   _Pragma("clang loop unroll(full)")

Function Documentation

◆ block_sort()

template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort ( const device T * inp,
device U * out,
const constant int & size_sorted_axis,
const constant int & in_stride_sorted_axis,
const constant int & out_stride_sorted_axis,
const constant int & in_stride_segment_axis,
const constant int & out_stride_segment_axis,
uint3 tid,
uint3 lid )

◆ block_sort_nc()

template<typename T, typename U, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc ( const device T * inp,
device U * out,
const constant int & size_sorted_axis,
const constant int & in_stride_sorted_axis,
const constant int & out_stride_sorted_axis,
const constant int & nc_dim,
const constant int * nc_shape,
const constant int64_t * in_nc_strides,
const constant int64_t * out_nc_strides,
uint3 tid,
uint3 lid )

◆ mb_block_merge()

template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<ValT>>
void mb_block_merge ( const device IdxT * block_partitions,
const device ValT * dev_vals_in,
const device IdxT * dev_idxs_in,
device ValT * dev_vals_out,
device IdxT * dev_idxs_out,
const constant int & size_sorted_axis,
const constant int & merge_tiles,
const constant int & num_tiles,
uint3 tid,
uint3 lid )

◆ mb_block_partition()

template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition ( device IdxT * block_partitions,
const device ValT * dev_vals,
const device IdxT * dev_idxs,
const constant int & size_sorted_axis,
const constant int & merge_tiles,
const constant int & n_blocks,
uint3 tid,
uint3 lid,
uint3 tgp_dims )

◆ mb_block_sort()

template<typename ValT, typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort ( const device ValT * inp,
device ValT * out_vals,
device IdxT * out_idxs,
const constant int & size_sorted_axis,
const constant int & stride_sorted_axis,
const constant int & nc_dim,
const constant int * nc_shape,
const constant int64_t * nc_strides,
uint3 tid,
uint3 lid )

◆ thread_swap()

template<typename T>
METAL_FUNC void thread_swap ( thread T & a,
thread T & b )

Variable Documentation

◆ zero_helper

constant constexpr const int zero_helper = 0   [constexpr]