// per_thread_row_reduce (reduce_row.h:19) -- excerpt; template header omitted.
// The threadgroup collaboratively reduces across N_WRITES rows with bounds
// checking: each thread folds N_READS elements per pass into its totals.
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // Set up the accumulator registers
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = Op::init;
  }

  // Loop over the full blocks
  for (int i = 0; i < blocks; i++) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
      inputs[j] += lsize_x * N_READS;
    }
  }

  // Separate case for the final, possibly partial block
  int index = lid_x * N_READS;
  if (index + N_READS <= extra) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  } else {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; index + i < extra; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  }
}
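To make the access pattern above concrete, here is a minimal host-side C++ sketch, not part of reduce_row.h: lsize_x, N_READS, and row_size are assumed example values, and Op is taken to be a plain sum. It walks the same block-then-ragged-tail schedule serially and checks that every element of the row is folded in exactly once.

#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int N_READS = 4;    // mirrors REDUCE_N_READS
  const int lsize_x = 8;    // threads per threadgroup along x (assumed)
  const int row_size = 103; // assumed; deliberately not a multiple of lsize_x * N_READS

  std::vector<float> row(row_size);
  std::iota(row.begin(), row.end(), 1.0f);

  const int blocks = row_size / (lsize_x * N_READS);
  const int extra = row_size - blocks * (lsize_x * N_READS);

  float total = 0.0f; // Op = Sum for this sketch, so the identity is 0
  for (int lid_x = 0; lid_x < lsize_x; lid_x++) {
    float partial = 0.0f;
    int base = lid_x * N_READS;

    // Full blocks: each thread reads N_READS contiguous values, then jumps
    // ahead by lsize_x * N_READS so the threadgroup covers the row together.
    for (int b = 0; b < blocks; b++) {
      for (int i = 0; i < N_READS; i++) {
        partial += row[base + i];
      }
      base += lsize_x * N_READS;
    }

    // Ragged tail, bounds-checked exactly like the kernel.
    const int index = lid_x * N_READS;
    if (index + N_READS <= extra) {
      for (int i = 0; i < N_READS; i++) {
        partial += row[base + i];
      }
    } else {
      for (int i = 0; index + i < extra; i++) {
        partial += row[base + i];
      }
    }

    total += partial; // stands in for the later threadgroup reduction
  }

  assert(total == std::accumulate(row.begin(), row.end(), 0.0f));
  std::printf("row sum = %g\n", total);
  return 0;
}

With lsize_x = 8 and N_READS = 4 the pass width is 32, so blocks = 3 and extra = 7: thread 0 takes four tail elements, thread 1 takes the remaining three, and the other threads take none.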
 
 
// per_thread_row_reduce, contiguous-rows overload -- excerpt; template header
// omitted. Builds the N_WRITES row pointers from a single contiguous input and
// forwards to the overload above.
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers: consecutive rows are reduction_size apart.
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  // Note: inputs[] has N_WRITES entries, so the loop is bounded by N_WRITES.
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
 
 
// per_thread_row_reduce, strided-rows overload -- excerpt; template header
// omitted. Locates each of the N_WRITES rows through the shape/strides
// description and forwards to the overload above.
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const size_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers, one per output row.
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  // Note: inputs[] has N_WRITES entries, so the loop is bounded by N_WRITES.
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
 
 
// threadgroup_reduce (reduce_row.h:129) -- excerpt; template header omitted.
// Reduce within the threadgroup: first inside each simdgroup, then across
// simdgroups through threadgroup memory.
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Simdgroup-level reduction first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }

  // Cross-simdgroup reduction through threadgroup memory
  if (simd_per_group > 1) {
    // Lane 0 of each simdgroup publishes its partial results
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each thread picks up one simdgroup's partials (or the identity) and the
    // first simdgroup folds them together.
    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : Op::init;
    }
    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}
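The two stages above can be modeled on the host. The sketch below is plain C++, not MLX code: N_WRITES, simd_size, and simd_per_group are assumed example values, the reduction is a plain sum, and the simd shuffle reductions are replaced by explicit loops. It mirrors the shared_vals layout of one set of N_WRITES partials per simdgroup.

#include <cassert>
#include <vector>

int main() {
  const int N_WRITES = 2;       // outputs carried per thread (assumed)
  const int simd_size = 32;     // lanes per simdgroup
  const int simd_per_group = 4; // simdgroups per threadgroup (assumed)

  // Every thread holds N_WRITES partial totals; model them as a flat table.
  std::vector<float> totals(simd_per_group * simd_size * N_WRITES, 1.0f);

  // Stage 1: each simdgroup folds its lanes together (the op.simd_reduce).
  // Threadgroup memory then holds one set of N_WRITES values per simdgroup,
  // laid out as shared_vals[simd_group_id * N_WRITES + i].
  std::vector<float> shared_vals(simd_per_group * N_WRITES, 0.0f);
  for (int g = 0; g < simd_per_group; g++) {
    for (int lane = 0; lane < simd_size; lane++) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[g * N_WRITES + i] +=
            totals[(g * simd_size + lane) * N_WRITES + i];
      }
    }
  }

  // Stage 2: models the first simdgroup's final simd_reduce over the
  // per-simdgroup partials.
  std::vector<float> result(N_WRITES, 0.0f);
  for (int g = 0; g < simd_per_group; g++) {
    for (int i = 0; i < N_WRITES; i++) {
      result[i] += shared_vals[g * N_WRITES + i];
    }
  }

  // With every partial equal to 1, each output is simply the thread count.
  for (int i = 0; i < N_WRITES; i++) {
    assert(result[i] == float(simd_per_group * simd_size));
  }
  return 0;
}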
 
 
// thread_reduce (reduce_row.h:166) -- a single thread reduces one contiguous row.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void thread_reduce(
    thread U& total,
    const device T* row,
    int blocks,
    int extra) {
  Op op;
  for (int i = 0; i < blocks; i++) {
    // Load a block of N_READS values, then fold them into the running total
    U vals[N_READS];
    for (int j = 0; j < N_READS; j++) {
      vals[j] = row[j];
    }
    for (int j = 0; j < N_READS; j++) {
      total = op(vals[j], total);
    }
    row += N_READS;
  }
  // Leftover elements that do not fill a whole block
  for (int i = 0; i < extra; i++) {
    total = op(*row++, total);
  }
}
 
 
// row_reduce_small (reduce_row.h:198) -- excerpt; template header omitted.
// Reduces rows that are small enough for a single thread (or a single
// simdgroup) to handle without a full threadgroup reduction.
void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;
  U total_val = Op::init;
  // ... (the nd-loop helper `loop` and the row pointer `row` are declared here)

  // Precompute the row reduction block counts
  int blocks = row_size / N_READS;
  int extra = row_size % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Tiny reductions: one thread loops over all the reduced rows serially.
    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
    // ... (offset `in` to the first row of this output)

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // One simdgroup per output: each lane reduces every simd_size-th row and
    // the lanes are then combined with a simd reduce.
    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
    // ... (offset `in` to the first row of this output)

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides); // advance a full simd_size of rows
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
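As a quick sanity check on the lane-strided loop in the second branch, the following standalone C++ snippet (not MLX code; simd_size and non_row_reductions are assumed example values) confirms that the `r = simd_lane_id; r += simd_size` schedule assigns every reduced row to exactly one lane.

#include <cassert>
#include <vector>

int main() {
  const int simd_size = 32;
  const int non_row_reductions = 70; // assumed; not a multiple of simd_size

  std::vector<int> visits(non_row_reductions, 0);
  for (int lane = 0; lane < simd_size; lane++) {
    // Mirrors: for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size)
    for (int r = lane; r < non_row_reductions; r += simd_size) {
      visits[r]++; // this lane folds row r into its private total
    }
  }

  // Every row is reduced by exactly one lane; op.simd_reduce then combines
  // the simd_size private totals into the single output value.
  for (int v : visits) {
    assert(v == 1);
  }
  return 0;
}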
 
 
// row_reduce_simple (reduce_row.h:264) -- excerpt; template header omitted.
// Each threadgroup reduces N_WRITES consecutive, contiguous rows.
void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to this threadgroup's block of N_WRITES rows, clamping the last
  // block so it never runs past out_size.
  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * reduction_size;

  // Each thread reduces its share of the rows
  int blocks = reduction_size / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Combine the per-thread partials across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // A single thread writes the N_WRITES results
  for (int i = 0; i < N_WRITES; i++) {
    // ... (store totals[i] to the corresponding output)
  }
}
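The out_idx clamp above keeps the last threadgroup in bounds by letting it overlap with its neighbor instead of writing past out_size. A small standalone C++ illustration (not MLX code; N_WRITES, out_size, and the group count are assumed example values):

#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
  const int N_WRITES = 4;          // rows written per threadgroup
  const std::size_t out_size = 10; // assumed number of output rows
  const std::size_t n_groups = (out_size + N_WRITES - 1) / N_WRITES;

  for (std::size_t g = 0; g < n_groups; g++) {
    std::size_t out_idx = N_WRITES * g;
    if (out_idx + N_WRITES > out_size) {
      out_idx = out_size - N_WRITES; // the last group overlaps instead of overflowing
    }
    std::printf("group %zu writes rows [%zu, %zu)\n", g, out_idx, out_idx + N_WRITES);
    assert(out_idx + N_WRITES <= out_size);
  }
  return 0;
}

With out_size = 10 the three groups write rows [0, 4), [4, 8), and [6, 10): the last group recomputes rows 6 and 7 rather than reading or writing out of bounds.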
 
 
// row_reduce_looped (reduce_row.h:311) -- excerpt; template header omitted.
// General case: each threadgroup produces one output by looping over all the
// non-row reductions and reducing every row collaboratively.
void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  U total = Op::init;
  // ... (threadgroup scratch `shared_vals`, the nd-loop helper `loop`, and the
  //      row pointer `row` are declared here)

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;

  int blocks = row_size / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (size_t i = 0; i < non_row_reductions; i++) {
    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);

    // Each thread reduces its slice of the current row
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across the non-row reductions
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Combine the per-thread totals and write the single result
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  if (lid.x == 0) {
    out[out_idx] = total;
  }
}
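For a concrete sense of the blocks/extra split used by both looped kernels, here is a small standalone C++ check with assumed sizes (not MLX code):

#include <cassert>

int main() {
  const int N_READS = 4;     // mirrors REDUCE_N_READS
  const int lsize_x = 32;    // threads per threadgroup along x (assumed)
  const int row_size = 1000; // assumed row length

  const int per_pass = lsize_x * N_READS;         // 128 elements per pass
  const int blocks = row_size / per_pass;         // 7 full passes
  const int extra = row_size - blocks * per_pass; // 104 leftover elements

  assert(blocks == 7 && extra == 104);
  assert(blocks * per_pass + extra == row_size); // every element is accounted for
  return 0;
}

Every pass of per_thread_row_reduce consumes lsize.x * N_READS elements, so the 104 leftovers are handled by the bounds-checked tail.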
 
 
Op op
Definition binary.h:129
 
static constexpr int REDUCE_N_READS
Definition defines.h:12
 
static constexpr int REDUCE_N_WRITES
Definition defines.h:13
 
void row_reduce_small(const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint simd_lane_id, uint3 gid, uint3 gsize, uint3 tid, uint3 tsize)
Definition reduce_row.h:198
 
METAL_FUNC void per_thread_row_reduce(thread U totals[N_WRITES], const device T *inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x)
The thread group collaboratively reduces across the rows with bounds checking.
Definition reduce_row.h:19
 
METAL_FUNC void threadgroup_reduce(thread U totals[N_WRITES], threadgroup U *shared_vals, uint3 lid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Reduce within the threadgroup.
Definition reduce_row.h:129
 
void row_reduce_simple(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:264
 
void row_reduce_looped(const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:311
 
METAL_FUNC void thread_reduce(thread U &total, const device T *row, int blocks, int extra)
Definition reduce_row.h:166
 
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:202
 
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:229