10 const device T* in [[buffer(0)]],
11 device U* out [[buffer(1)]],
12 const constant
size_t& in_size [[buffer(2)]],
13 const constant
size_t& row_size [[buffer(3)]],
14 uint3 gid [[threadgroup_position_in_grid]],
15 uint3 lid [[thread_position_in_threadgroup]],
16 uint3 lsize [[threads_per_threadgroup]],
17 uint simd_per_group [[simdgroups_per_threadgroup]],
18 uint simd_lane_id [[thread_index_in_simdgroup]],
19 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
24 IdxT start_idx = gid.y * IdxT(row_size);
26 (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
27 IdxT blocks = actual_row / (lsize.x * N_READS);
28 int extra = actual_row - blocks * (lsize.x * N_READS);
29 extra -= lid.x * N_READS;
30 start_idx += lid.x * N_READS;
33 if (extra >= N_READS) {
38 for (IdxT b = 0; b < blocks; b++) {
39 for (
int i = 0; i < N_READS; i++) {
40 total =
op(
static_cast<U
>(in[i]), total);
42 in += lsize.x * N_READS;
45 for (
int i = 0; i < extra; i++) {
46 total =
op(
static_cast<U
>(in[i]), total);
51 total =
op.simd_reduce(total);
52 if (simd_per_group > 1) {
53 if (simd_lane_id == 0) {
54 shared_vals[simd_group_id] = total;
58 threadgroup_barrier(mem_flags::mem_threadgroup);
59 total = lid.x < simd_per_group ? shared_vals[lid.x] :
op.init;
60 total =
op.simd_reduce(total);
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:9