MLX
Loading...
Searching...
No Matches
sort.h File Reference

Go to the source code of this file.

Classes

struct  LessThan< T >
 
struct  ThreadSort< val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp >
 
struct  BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 
struct  KernelMergeSort< T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 
struct  KernelMultiBlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
 

Macros

#define MLX_MTL_CONST   static constant constexpr const
 
#define MLX_MTL_LOOP_UNROLL   _Pragma("clang loop unroll(full)")
 

Functions

template<typename T >
METAL_FUNC void thread_swap (thread T &a, thread T &b)
 
template<typename T , typename U , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort (const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
 
template<typename T , typename U , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc (const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *in_nc_strides, const constant size_t *out_nc_strides, uint3 tid, uint3 lid)
 
template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort (const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *nc_strides, uint3 tid, uint3 lid)
 
template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition (device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
 
template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<val_t>>
void mb_block_merge (const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
 

Variables

constant constexpr const int zero_helper = 0
 

Macro Definition Documentation

◆ MLX_MTL_CONST

#define MLX_MTL_CONST   static constant constexpr const

◆ MLX_MTL_LOOP_UNROLL

#define MLX_MTL_LOOP_UNROLL   _Pragma("clang loop unroll(full)")

Function Documentation

◆ block_sort()

template<typename T , typename U , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort ( const device T * inp,
device U * out,
const constant int & size_sorted_axis,
const constant int & in_stride_sorted_axis,
const constant int & out_stride_sorted_axis,
const constant int & in_stride_segment_axis,
const constant int & out_stride_segment_axis,
uint3 tid,
uint3 lid )

◆ block_sort_nc()

template<typename T , typename U , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void block_sort_nc ( const device T * inp,
device U * out,
const constant int & size_sorted_axis,
const constant int & in_stride_sorted_axis,
const constant int & out_stride_sorted_axis,
const constant int & nc_dim,
const constant int * nc_shape,
const constant size_t * in_nc_strides,
const constant size_t * out_nc_strides,
uint3 tid,
uint3 lid )

◆ mb_block_merge()

template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp = LessThan<val_t>>
void mb_block_merge ( const device idx_t * block_partitions,
const device val_t * dev_vals_in,
const device idx_t * dev_idxs_in,
device val_t * dev_vals_out,
device idx_t * dev_idxs_out,
const constant int & size_sorted_axis,
const constant int & merge_tiles,
const constant int & num_tiles,
uint3 tid,
uint3 lid )

◆ mb_block_partition()

template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_partition ( device idx_t * block_partitions,
const device val_t * dev_vals,
const device idx_t * dev_idxs,
const constant int & size_sorted_axis,
const constant int & merge_tiles,
const constant int & n_blocks,
uint3 tid,
uint3 lid,
uint3 tgp_dims )

◆ mb_block_sort()

template<typename val_t , typename idx_t , bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD>
void mb_block_sort ( const device val_t * inp,
device val_t * out_vals,
device idx_t * out_idxs,
const constant int & size_sorted_axis,
const constant int & stride_sorted_axis,
const constant int & nc_dim,
const constant int * nc_shape,
const constant size_t * nc_strides,
uint3 tid,
uint3 lid )

◆ thread_swap()

template<typename T >
METAL_FUNC void thread_swap ( thread T & a,
thread T & b )

Variable Documentation

◆ zero_helper

constant constexpr const int zero_helper = 0   [constexpr]