|
template<typename T , typename U , typename Op > |
void | col_reduce_small (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid) |
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
METAL_FUNC U | _contiguous_strided_reduce (const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize) |
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
void | col_reduce_general (const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize) |
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
void | col_reduce_general_no_atomics (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize) |
|