10    const device T* in [[buffer(0)]],
 
   11    device U* out [[buffer(1)]],
 
   12    const constant 
size_t& reduction_size [[buffer(2)]],
 
   13    const constant 
size_t& reduction_stride [[buffer(3)]],
 
   14    const constant 
int* shape [[buffer(4)]],
 
   15    const constant 
size_t* strides [[buffer(5)]],
 
   16    const constant 
int& ndim [[buffer(6)]],
 
   17    const constant 
int* reduce_shape [[buffer(7)]],
 
   18    const constant 
size_t* reduce_strides [[buffer(8)]],
 
   19    const constant 
int& reduce_ndim [[buffer(9)]],
 
   20    const constant 
size_t& non_col_reductions [[buffer(10)]],
 
   21    uint3 gid [[threadgroup_position_in_grid]],
 
   22    uint3 gsize [[threadgroups_per_grid]],
 
   23    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   24    uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
   25    uint3 tid [[thread_position_in_grid]],
 
   26    uint3 tsize [[threads_per_grid]]) {
 
   32  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
 
   34    for (
int i = 0; i < 31; i++) {
 
   38    short stride = reduction_stride;
 
   39    short size = reduction_size;
 
   40    short blocks = stride / N_READS;
 
   41    short extra = stride - blocks * N_READS;
 
   43    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
 
   46    for (uint r = 0; r < non_col_reductions; r++) {
 
   47      row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
 
   49      for (
short i = 0; i < size; i++) {
 
   50        for (
short j = 0; j < blocks; j++) {
 
   51          for (
short k = 0; k < N_READS; k++) {
 
   52            totals[j * N_READS + k] =
 
   53                op(totals[j * N_READS + k],
 
   54                   static_cast<U
>(row[i * stride + j * N_READS + k]));
 
   57        for (
short k = 0; k < extra; k++) {
 
   58          totals[blocks * N_READS + k] =
 
   59              op(totals[blocks * N_READS + k],
 
   60                 static_cast<U
>(row[i * stride + blocks * N_READS + k]));
 
   64      loop.
next(reduce_shape, reduce_strides);
 
   66    out += out_idx * reduction_stride;
 
   67    for (
short j = 0; j < stride; j++) {
 
   73  else if (reduction_size * non_col_reductions < 32) {
 
   75    for (
int i = 0; i < N_READS; i++) {
 
   79    short size = reduction_size;
 
   80    size_t offset = size_t(tid.x) * N_READS;
 
   81    bool safe = offset + N_READS <= reduction_stride;
 
   82    short extra = reduction_stride - offset;
 
   84    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
 
   85    in += 
elem_to_loc(out_idx, shape, strides, ndim) + offset;
 
   87    for (uint r = 0; r < non_col_reductions; r++) {
 
   88      row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
 
   91        for (
short i = 0; i < size; i++) {
 
   92          for (
short j = 0; j < N_READS; j++) {
 
   94                op(
static_cast<U
>(row[i * reduction_stride + j]), totals[j]);
 
   98        for (
short i = 0; i < size; i++) {
 
   99          for (
short j = 0; j < extra; j++) {
 
  101                op(
static_cast<U
>(row[i * reduction_stride + j]), totals[j]);
 
  106      loop.
next(reduce_shape, reduce_strides);
 
  108    out += out_idx * reduction_stride + offset;
 
  110      for (
short i = 0; i < N_READS; i++) {
 
  114      for (
short i = 0; i < extra; i++) {
 
  122    threadgroup U shared_vals[1024];
 
  124    for (
int i = 0; i < N_READS; i++) {
 
  125      totals[i] = Op::init;
 
  128    short stride = reduction_stride;
 
  129    short lid = simd_group_id * 
simd_size + simd_lane_id;
 
  130    short2 tile((stride + N_READS - 1) / N_READS, 32);
 
  131    short2 offset((lid % tile.x) * N_READS, lid / tile.x);
 
  132    short sm_stride = tile.x * N_READS;
 
  133    bool safe = offset.x + N_READS <= stride;
 
  135    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
 
  136    in += 
elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
 
  139    size_t total = non_col_reductions * reduction_size;
 
  140    loop.
next(offset.y, reduce_shape, reduce_strides);
 
  141    for (
size_t r = offset.y; r < total; r += 
simd_size) {
 
  142      row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
 
  145        for (
int i = 0; i < N_READS; i++) {
 
  146          totals[i] = 
op(
static_cast<U
>(row[i]), totals[i]);
 
  150        for (
int i = 0; i < N_READS; i++) {
 
  151          vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : 
op.init;
 
  153        for (
int i = 0; i < N_READS; i++) {
 
  154          totals[i] = 
op(vals[i], totals[i]);
 
  164    for (
int i = 0; i < N_READS; i++) {
 
  165      shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
 
  167    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  168    for (
int i = 0; i < N_READS; i++) {
 
  169      totals[i] = 
op.simd_reduce(
 
  170          shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
 
  174    if (simd_lane_id == 0) {
 
  175      short column = simd_group_id * N_READS;
 
  176      out += out_idx * reduction_stride + column;
 
  177      if (column + N_READS <= stride) {
 
  178        for (
int i = 0; i < N_READS; i++) {
 
  182        for (
int i = 0; column + i < stride; i++) {
 
 
  209    const device T* in [[buffer(0)]],
 
  210    device U* out [[buffer(1)]],
 
  211    const constant 
size_t& reduction_size [[buffer(2)]],
 
  212    const constant 
size_t& reduction_stride [[buffer(3)]],
 
  213    const constant 
int* shape [[buffer(4)]],
 
  214    const constant 
size_t* strides [[buffer(5)]],
 
  215    const constant 
int& ndim [[buffer(6)]],
 
  216    const constant 
int* reduce_shape [[buffer(7)]],
 
  217    const constant 
size_t* reduce_strides [[buffer(8)]],
 
  218    const constant 
int& reduce_ndim [[buffer(9)]],
 
  219    const constant 
size_t& non_col_reductions [[buffer(10)]],
 
  220    uint3 gid [[threadgroup_position_in_grid]],
 
  221    uint3 gsize [[threadgroups_per_grid]],
 
  222    uint simd_lane_id [[thread_index_in_simdgroup]],
 
  223    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
  225  constexpr int n_simdgroups = 4;
 
  226  constexpr short tgp_size = n_simdgroups * 
simd_size;
 
  227  constexpr short n_reads = (BM * BN) / tgp_size;
 
  228  constexpr short n_read_blocks = BN / n_reads;
 
  230  threadgroup U shared_vals[BN * BM];
 
  235  for (
int i = 0; i < n_reads; i++) {
 
  236    totals[i] = Op::init;
 
  239  short lid = simd_group_id * 
simd_size + simd_lane_id;
 
  240  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
 
  241  size_t column = BN * gid.x + offset.x;
 
  242  bool safe = column + n_reads <= reduction_stride;
 
  244  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
 
  245  size_t in_idx = 
elem_to_loc(out_idx, shape, strides, ndim);
 
  246  in += in_idx + column;
 
  248  size_t total = non_col_reductions * reduction_size;
 
  249  loop.
next(offset.y, reduce_shape, reduce_strides);
 
  250  for (
size_t r = offset.y; r < total; r += BM) {
 
  251    row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
 
  254      for (
int i = 0; i < n_reads; i++) {
 
  255        totals[i] = 
op(
static_cast<U
>(row[i]), totals[i]);
 
  259      for (
int i = 0; i < n_reads; i++) {
 
  261            (column + i < reduction_stride) ? static_cast<U>(row[i]) : 
op.init;
 
  263      for (
int i = 0; i < n_reads; i++) {
 
  264        totals[i] = 
op(vals[i], totals[i]);
 
  268    loop.
next(BM, reduce_shape, reduce_strides);
 
  275    constexpr int n_outputs = BN / n_simdgroups;
 
  277        BM != 32 || n_outputs == n_reads,
 
  278        "The tile should be selected such that n_outputs == n_reads");
 
  279    for (
int i = 0; i < n_reads; i++) {
 
  280      shared_vals[offset.y * BN + offset.x + i] = totals[i];
 
  282    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  283    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
 
  284    for (
int i = 0; i < n_outputs; i++) {
 
  286          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
 
  290    if (simd_lane_id == 0) {
 
  291      size_t out_column = BN * gid.x + out_offset.x;
 
  292      out += out_idx * reduction_stride + out_column;
 
  293      if (out_column + n_outputs <= reduction_stride) {
 
  294        for (
int i = 0; i < n_outputs; i++) {
 
  298        for (
int i = 0; out_column + i < reduction_stride; i++) {
 
  309    short x_block = offset.x / n_reads;
 
  310    for (
int i = 0; i < n_reads; i++) {
 
  311      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
 
  313    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  315      for (
int i = 0; i < n_reads; i++) {
 
  316        for (
int j = 1; j < BM; j++) {
 
  318              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
 
  325      out += out_idx * reduction_stride + column;
 
  327        for (
int i = 0; i < n_reads; i++) {
 
  331        for (
int i = 0; column + i < reduction_stride; i++) {
 
 
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Our approach is the following simple looped approach:
Definition reduce_col.h:208
 
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)
Definition reduce_col.h:9