5 const device T* in [[buffer(0)]],
6 device U* out [[buffer(1)]],
7 const constant
size_t& in_size [[buffer(2)]],
8 const constant
size_t& row_size [[buffer(3)]],
9 uint3 gid [[threadgroup_position_in_grid]],
10 uint3 lid [[thread_position_in_threadgroup]],
11 uint3 lsize [[threads_per_threadgroup]],
12 uint simd_per_group [[simdgroups_per_threadgroup]],
13 uint simd_lane_id [[thread_index_in_simdgroup]],
14 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
19 int64_t start_idx = gid.y * row_size;
21 (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
22 int64_t blocks = actual_row / (lsize.x * N_READS);
23 int extra = actual_row - blocks * (lsize.x * N_READS);
24 extra -= lid.x * N_READS;
25 start_idx += lid.x * N_READS;
28 if (extra >= N_READS) {
33 for (int64_t b = 0; b < blocks; b++) {
34 for (
int i = 0; i < N_READS; i++) {
35 total =
op(
static_cast<U
>(in[i]), total);
37 in += lsize.x * N_READS;
40 for (
int i = 0; i < extra; i++) {
41 total =
op(
static_cast<U
>(in[i]), total);
46 total =
op.simd_reduce(total);
47 if (simd_per_group > 1) {
48 if (simd_lane_id == 0) {
49 shared_vals[simd_group_id] = total;
53 threadgroup_barrier(mem_flags::mem_threadgroup);
54 total = lid.x < simd_per_group ? shared_vals[lid.x] :
op.init;
55 total =
op.simd_reduce(total);
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:4