MLX
Loading...
Searching...
No Matches
reduce_all.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
// Reduces one row of `in` per threadgroup and writes the result to
// `out[gid.y]`.
//
// Template parameters:
//   T        element type read from `in`
//   U        accumulator / output type; every input is static_cast to U
//   Op       reduction functor; this kernel uses `Op::init`,
//            `op(value, accumulator)` and `op.simd_reduce(value)`
//   IdxT     integer type used for the index arithmetic
//   N_READS  number of consecutive elements a thread reads per block
//
// Arguments:
//   in        input elements
//   out       one output element per threadgroup, indexed by gid.y
//   in_size   total number of input elements
//   row_size  number of elements assigned to each threadgroup; the last row
//             is shortened so that it does not run past in_size
//
// NOTE(review): only the .x component of lid/lsize and the .y component of
// gid are used, so this presumably expects a (1, rows, 1) grid of 1-D
// threadgroups -- confirm against the host-side dispatch.
// NOTE(review): `partials` holds simd_size entries and the second stage is a
// single simd_reduce, so this assumes simd_per_group does not exceed
// simd_size -- confirm.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& in_size [[buffer(2)]],
    const constant size_t& row_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U partials[simd_size];

  // Elements consumed by the whole threadgroup in one block, and this
  // thread's offset inside such a block.
  const uint group_stride = lsize.x * N_READS;
  const uint thread_offset = lid.x * N_READS;

  // First element of this threadgroup's row, and the row length clipped to
  // the end of the input.
  IdxT row_begin = gid.y * IdxT(row_size);
  IdxT row_len;
  if (row_begin + row_size <= in_size) {
    row_len = row_size;
  } else {
    row_len = in_size - row_begin;
  }

  // Blocks that every thread can read in full, followed by the number of
  // leftover elements still available starting at this thread's offset
  // (zero or negative when the leftover ends before this thread's slot).
  IdxT full_blocks = row_len / group_stride;
  int tail = row_len - full_blocks * group_stride;
  tail -= thread_offset;

  row_begin += thread_offset;
  in += row_begin;

  // The leftover covers this thread's entire slot: handle it as one more
  // full block instead of a partial read.
  if (tail >= N_READS) {
    full_blocks++;
    tail = 0;
  }

  U acc = Op::init;

  // Full blocks: N_READS consecutive elements, then jump one group stride.
  for (IdxT blk = 0; blk < full_blocks; blk++) {
    for (int k = 0; k < N_READS; k++) {
      acc = op(static_cast<U>(in[k]), acc);
    }
    in += group_stride;
  }

  // Partial slot at the end of the row; the loop body never runs when
  // tail <= 0.
  for (int k = 0; k < tail; k++) {
    acc = op(static_cast<U>(in[k]), acc);
  }

  // Stage 1: combine the per-thread accumulators within each simdgroup.
  acc = op.simd_reduce(acc);

  if (simd_per_group > 1) {
    // Lane 0 of every simdgroup publishes its partial result.
    if (simd_lane_id == 0) {
      partials[simd_group_id] = acc;
    }

    // Make the published partials visible to the whole threadgroup.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage 2: threads with an index below simd_per_group each pick up one
    // partial, all others contribute the identity, then reduce once more.
    if (lid.x < simd_per_group) {
      acc = partials[lid.x];
    } else {
      acc = op.init;
    }
    acc = op.simd_reduce(acc);
  }

  // A single thread per threadgroup stores the row's result.
  if (lid.x == 0) {
    out[gid.y] = acc;
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
Op op
Definition binary.h:129
static constexpr int REDUCE_N_READS
Definition defines.h:12
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:9