5    const device T* in [[buffer(0)]],
 
    6    device U* out [[buffer(1)]],
 
    7    const constant 
size_t& in_size [[buffer(2)]],
 
    8    const constant 
size_t& row_size [[buffer(3)]],
 
    9    uint3 gid [[threadgroup_position_in_grid]],
 
   10    uint3 lid [[thread_position_in_threadgroup]],
 
   11    uint3 lsize [[threads_per_threadgroup]],
 
   12    uint simd_per_group [[simdgroups_per_threadgroup]],
 
   13    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   14    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   19  int64_t start_idx = gid.y * row_size;
 
   21      (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
 
   22  int64_t blocks = actual_row / (lsize.x * N_READS);
 
   23  int extra = actual_row - blocks * (lsize.x * N_READS);
 
   24  extra -= lid.x * N_READS;
 
   25  start_idx += lid.x * N_READS;
 
   28  if (extra >= N_READS) {
 
   33  for (int64_t b = 0; b < blocks; b++) {
 
   34    for (
int i = 0; i < N_READS; i++) {
 
   35      total = 
op(
static_cast<U
>(in[i]), total);
 
   37    in += lsize.x * N_READS;
 
   40    for (
int i = 0; i < extra; i++) {
 
   41      total = 
op(
static_cast<U
>(in[i]), total);
 
   46  total = 
op.simd_reduce(total);
 
   47  if (simd_per_group > 1) {
 
   48    if (simd_lane_id == 0) {
 
   49      shared_vals[simd_group_id] = total;
 
   53    threadgroup_barrier(mem_flags::mem_threadgroup);
 
   54    total = lid.x < simd_per_group ? shared_vals[lid.x] : 
op.init;
 
   55    total = 
op.simd_reduce(total);
 
 
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:4