7template <
typename T, 
typename U, 
typename Op>
 
    9    const device T* in [[buffer(0)]],
 
   10    device U* out [[buffer(1)]],
 
   11    const constant 
size_t& reduction_size [[buffer(2)]],
 
   12    const constant 
size_t& reduction_stride [[buffer(3)]],
 
   13    const constant 
size_t& out_size [[buffer(4)]],
 
   14    const constant 
int* shape [[buffer(5)]],
 
   15    const constant 
size_t* strides [[buffer(6)]],
 
   16    const constant 
int& ndim [[buffer(7)]],
 
   17    const constant 
size_t& non_col_reductions [[buffer(8)]],
 
   18    const constant 
int* non_col_shapes [[buffer(9)]],
 
   19    const constant 
size_t* non_col_strides [[buffer(10)]],
 
   20    const constant 
int& non_col_ndim [[buffer(11)]],
 
   21    uint tid [[thread_position_in_grid]]) {
 
   26  U total_val = Op::init;
 
   33      strides + non_col_ndim,
 
   36  for (uint i = 0; i < non_col_reductions; i++) {
 
   38        elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
 
   40    for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
 
   41      U val = 
static_cast<U
>(in[in_idx]);
 
   42      total_val = 
op(total_val, val);
 
   46  out[out_idx] = total_val;
 
 
   53template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
   56    threadgroup U* local_data,
 
   59    uint reduction_stride,
 
   64  U total_val = Op::init;
 
   66  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
 
   67  for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
 
   68    uint offset = base_offset + r;
 
   70        op(
static_cast<U
>(total_val), in[in_idx + offset * reduction_stride]);
 
   72  local_data[lsize.y * lid.x + lid.y] = total_val;
 
   73  threadgroup_barrier(mem_flags::mem_threadgroup);
 
   78    for (uint i = 0; i < lsize.y; i++) {
 
   79      val = 
op(val, local_data[lsize.y * lid.x + i]);
 
 
   90template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
   92    const device T* in [[buffer(0)]],
 
   94    const constant 
size_t& reduction_size [[buffer(2)]],
 
   95    const constant 
size_t& reduction_stride [[buffer(3)]],
 
   96    const constant 
size_t& out_size [[buffer(4)]],
 
   97    const constant 
int* shape [[buffer(5)]],
 
   98    const constant 
size_t* strides [[buffer(6)]],
 
   99    const constant 
int& ndim [[buffer(7)]],
 
  100    threadgroup U* local_data [[threadgroup(0)]],
 
  101    uint3 tid [[threadgroup_position_in_grid]],
 
  102    uint3 lid [[thread_position_in_threadgroup]],
 
  103    uint3 lsize [[threads_per_threadgroup]]) {
 
  104  auto out_idx = tid.x * lsize.x + lid.x;
 
  105  auto in_idx = 
elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
 
  108  if (out_idx < out_size) {
 
  109    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
 
  122      op.atomic_update(out, val, out_idx);
 
 
  127template <
typename T, 
typename U, 
typename Op, 
int N_READS = REDUCE_N_READS>
 
  129    const device T* in [[buffer(0)]],
 
  130    device U* out [[buffer(1)]],
 
  131    const constant 
size_t& reduction_size [[buffer(2)]],
 
  132    const constant 
size_t& reduction_stride [[buffer(3)]],
 
  133    const constant 
size_t& out_size [[buffer(4)]],
 
  134    const constant 
int* shape [[buffer(5)]],
 
  135    const constant 
size_t* strides [[buffer(6)]],
 
  136    const constant 
int& ndim [[buffer(7)]],
 
  137    threadgroup U* local_data [[threadgroup(0)]],
 
  138    uint3 tid [[threadgroup_position_in_grid]],
 
  139    uint3 lid [[thread_position_in_threadgroup]],
 
  140    uint3 gid [[thread_position_in_grid]],
 
  141    uint3 lsize [[threads_per_threadgroup]],
 
  142    uint3 gsize [[threads_per_grid]]) {
 
  143  auto out_idx = tid.x * lsize.x + lid.x;
 
  144  auto in_idx = 
elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
 
  146  if (out_idx < out_size) {
 
  147    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
 
  160      uint tgsize_y = 
ceildiv(gsize.y, lsize.y);
 
  161      uint tgsize_z = 
ceildiv(gsize.z, lsize.z);
 
  162      out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
 
 
Op op
Definition binary.h:141
 
void col_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize)
Definition reduce_col.h:91
 
METAL_FUNC U _contiguous_strided_reduce(const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize)
Definition reduce_col.h:54
 
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid)
Definition reduce_col.h:8
 
void col_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize)
Definition reduce_col.h:128