10    const device T* in [[buffer(0)]],
 
   11    device U* out [[buffer(1)]],
 
   12    const constant 
size_t& in_size [[buffer(2)]],
 
   13    const constant 
size_t& row_size [[buffer(3)]],
 
   14    uint3 gid [[threadgroup_position_in_grid]],
 
   15    uint3 lid [[thread_position_in_threadgroup]],
 
   16    uint3 lsize [[threads_per_threadgroup]],
 
   17    uint simd_per_group [[simdgroups_per_threadgroup]],
 
   18    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   19    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   24  IdxT start_idx = gid.y * IdxT(row_size);
 
   26      (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
 
   27  IdxT blocks = actual_row / (lsize.x * N_READS);
 
   28  int extra = actual_row - blocks * (lsize.x * N_READS);
 
   29  extra -= lid.x * N_READS;
 
   30  start_idx += lid.x * N_READS;
 
   33  if (extra >= N_READS) {
 
   38  for (IdxT b = 0; b < blocks; b++) {
 
   39    for (
int i = 0; i < N_READS; i++) {
 
   40      total = 
op(
static_cast<U
>(in[i]), total);
 
   42    in += lsize.x * N_READS;
 
   45    for (
int i = 0; i < extra; i++) {
 
   46      total = 
op(
static_cast<U
>(in[i]), total);
 
   51  total = 
op.simd_reduce(total);
 
   52  if (simd_per_group > 1) {
 
   53    if (simd_lane_id == 0) {
 
   54      shared_vals[simd_group_id] = total;
 
   58    threadgroup_barrier(mem_flags::mem_threadgroup);
 
   59    total = lid.x < simd_per_group ? shared_vals[lid.x] : 
op.init;
 
   60    total = 
op.simd_reduce(total);
 
 
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:9