|
template<typename T, typename U, typename Op>
void row_reduce_general_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lid)
|
template<typename T, typename U, typename Op>
void row_reduce_general_med(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
|
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(const device T *in, const constant size_t &reduction_size, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lsize_x, uint lid_x, uint2 tid)
|
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general(const device T *in, device mlx_atomic<U> *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
|
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 gsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
|