8template <
typename T, 
typename U, 
typename Op>
 
   10    const device T* in [[buffer(0)]],
 
   11    device U* out [[buffer(1)]],
 
   12    const constant 
size_t& reduction_size [[buffer(2)]],
 
   13    const constant 
size_t& out_size [[buffer(3)]],
 
   14    const constant 
size_t& non_row_reductions [[buffer(4)]],
 
   15    const constant 
int* shape [[buffer(5)]],
 
   16    const constant 
size_t* strides [[buffer(6)]],
 
   17    const constant 
int& ndim [[buffer(7)]],
 
   18    uint lid [[thread_position_in_grid]]) {
 
   23  if (out_idx >= out_size) {
 
   27  U total_val = Op::init;
 
   29  for (
short r = 0; r < short(non_row_reductions); r++) {
 
   30    uint in_idx = 
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
 
   31    const device T* in_row = in + in_idx;
 
   33    for (
short i = 0; i < short(reduction_size); i++) {
 
   34      total_val = 
op(
static_cast<U
>(in_row[i]), total_val);
 
   38  out[out_idx] = total_val;
 
 
   42template <
typename T, 
typename U, 
typename Op>
 
   44    const device T* in [[buffer(0)]],
 
   45    device U* out [[buffer(1)]],
 
   46    const constant 
size_t& reduction_size [[buffer(2)]],
 
   47    const constant 
size_t& out_size [[buffer(3)]],
 
   48    const constant 
size_t& non_row_reductions [[buffer(4)]],
 
   49    const constant 
int* shape [[buffer(5)]],
 
   50    const constant 
size_t* strides [[buffer(6)]],
 
   51    const constant 
int& ndim [[buffer(7)]],
 
   52    uint tid [[threadgroup_position_in_grid]],
 
   53    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   54    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
 
   55    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   58  uint out_idx = simd_per_group * tid + simd_group_id;
 
   60  if (out_idx >= out_size) {
 
   64  U total_val = Op::init;
 
   66  if (
short(non_row_reductions) == 1) {
 
   67    uint in_idx = 
elem_to_loc(out_idx, shape, strides, ndim);
 
   68    const device T* in_row = in + in_idx;
 
   70    for (
short i = simd_lane_id; i < short(reduction_size); i += 32) {
 
   71      total_val = 
op(
static_cast<U
>(in_row[i]), total_val);
 
   75  else if (
short(non_row_reductions) >= 32) {
 
   76    for (
short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
 
   77      uint in_idx = 
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
 
   78      const device T* in_row = in + in_idx;
 
   80      for (
short i = 0; i < short(reduction_size); i++) {
 
   81        total_val = 
op(
static_cast<U
>(in_row[i]), total_val);
 
   88    const short n_reductions =
 
   89        short(reduction_size) * short(non_row_reductions);
 
   90    const short reductions_per_thread =
 
   93    const short r_st = simd_lane_id / reductions_per_thread;
 
   94    const short r_ed = short(non_row_reductions);
 
   95    const short r_jump = 
simd_size / reductions_per_thread;
 
   97    const short i_st = simd_lane_id % reductions_per_thread;
 
   98    const short i_ed = short(reduction_size);
 
   99    const short i_jump = reductions_per_thread;
 
  102      for (
short r = r_st; r < r_ed; r += r_jump) {
 
  103        uint in_idx = 
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
 
  104        const device T* in_row = in + in_idx;
 
  106        for (
short i = i_st; i < i_ed; i += i_jump) {
 
  107          total_val = 
op(
static_cast<U
>(in_row[i]), total_val);
 
  113  total_val = 
op.simd_reduce(total_val);
 
  115  if (simd_lane_id == 0) {
 
  116    out[out_idx] = total_val;
 
 
  124template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
  127    const constant 
size_t& reduction_size,
 
  128    const constant 
size_t& out_size,
 
  129    const constant 
int* shape,
 
  130    const constant 
size_t* strides,
 
  131    const constant 
int& ndim,
 
  139  int idx = tid.y * out_size + tid.x;
 
  140  int extra_offset = 
elem_to_loc(idx, shape, strides, ndim);
 
  141  in += extra_offset + lid_x * N_READS;
 
  144  U total_val = Op::init;
 
  148  for (; r < (int)
ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
 
  150    for (
int i = 0; i < N_READS; i++) {
 
  153    for (
int i = 0; i < N_READS; i++) {
 
  154      total_val = 
op(
static_cast<U
>(vals[i]), total_val);
 
  157    in += lsize_x * N_READS;
 
  161  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
 
  162  if (reduction_index < reduction_size) {
 
  163    int max_reads = reduction_size - reduction_index;
 
  166    for (
int i = 0; i < N_READS; i++) {
 
  167      int idx = min(i, max_reads - 1);
 
  168      vals[i] = 
static_cast<U
>(in[idx]);
 
  170    for (
int i = 0; i < N_READS; i++) {
 
  171      T val = i < max_reads ? vals[i] : Op::init;
 
  172      total_val = 
op(
static_cast<U
>(val), total_val);
 
 
  179template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
  181    const device T* in [[buffer(0)]],
 
  183    const constant 
size_t& reduction_size [[buffer(2)]],
 
  184    const constant 
size_t& out_size [[buffer(3)]],
 
  185    const constant 
size_t& non_row_reductions [[buffer(4)]],
 
  186    const constant 
int* shape [[buffer(5)]],
 
  187    const constant 
size_t* strides [[buffer(6)]],
 
  188    const constant 
int& ndim [[buffer(7)]],
 
  189    uint3 lid [[thread_position_in_threadgroup]],
 
  190    uint3 lsize [[threads_per_threadgroup]],
 
  191    uint3 tid [[threadgroup_position_in_grid]],
 
  192    uint simd_lane_id [[thread_index_in_simdgroup]],
 
  193    uint simd_per_group [[simdgroups_per_threadgroup]],
 
  194    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
  195  (void)non_row_reductions;
 
  200  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
 
  211  total_val = 
op.simd_reduce(total_val);
 
  214  if (simd_lane_id == 0) {
 
  215    local_vals[simd_group_id] = total_val;
 
  217  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  222    total_val = lid.x < simd_per_group ? local_vals[lid.x] : 
op.init;
 
  223    total_val = 
op.simd_reduce(total_val);
 
  227    op.atomic_update(out, total_val, tid.x);
 
 
  231template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
  233    const device T* in [[buffer(0)]],
 
  234    device U* out [[buffer(1)]],
 
  235    const constant 
size_t& reduction_size [[buffer(2)]],
 
  236    const constant 
size_t& out_size [[buffer(3)]],
 
  237    const constant 
size_t& non_row_reductions [[buffer(4)]],
 
  238    const constant 
int* shape [[buffer(5)]],
 
  239    const constant 
size_t* strides [[buffer(6)]],
 
  240    const constant 
int& ndim [[buffer(7)]],
 
  241    uint3 lid [[thread_position_in_threadgroup]],
 
  242    uint3 lsize [[threads_per_threadgroup]],
 
  243    uint3 gsize [[threads_per_grid]],
 
  244    uint3 tid [[threadgroup_position_in_grid]],
 
  245    uint simd_lane_id [[thread_index_in_simdgroup]],
 
  246    uint simd_per_group [[simdgroups_per_threadgroup]],
 
  247    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
  248  (void)non_row_reductions;
 
  253  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
 
  265  for (uint16_t i = 
simd_size / 2; i > 0; i /= 2) {
 
  270  if (simd_lane_id == 0) {
 
  271    local_vals[simd_group_id] = total_val;
 
  273  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  278    total_val = lid.x < simd_per_group ? local_vals[lid.x] : 
op.init;
 
  279    for (uint16_t i = 
simd_size / 2; i > 0; i /= 2) {
 
  285    out[(
ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
 
 
Op op
Definition binary.h:141
 
void row_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 gsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:232
 
METAL_FUNC U per_thread_row_reduce(const device T *in, const constant size_t &reduction_size, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lsize_x, uint lid_x, uint2 tid)
Definition reduce_row.h:125
 
void row_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:180
 
void row_reduce_general_med(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:43
 
void row_reduce_general_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lid)
Definition reduce_row.h:9