MLX
Loading...
Searching...
No Matches
reduce_all.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
/// Reduce one contiguous row of `in` per threadgroup and write one value per
/// row to `out`.
///
/// Threadgroup `gid.y` reduces the elements in
/// `[gid.y * row_size, min((gid.y + 1) * row_size, in_size))` with `Op` and
/// writes the result to `out[gid.y]`. If the row starts at or beyond `in_size`,
/// no input is read and the written value is the reduction of `Op::init`.
///
/// Work split: each pass over the row consumes `lsize.x * N_READS` contiguous
/// elements, and thread `lid.x` owns the `N_READS` elements at offset
/// `lid.x * N_READS` of every pass. The partial results are then combined with
/// `op.simd_reduce` inside each simdgroup. When there is more than one
/// simdgroup, they are also combined across simdgroups through the threadgroup
/// array `shared_vals`.
///
/// NOTE(review): only `lid.x` / `lsize.x` are used, so this assumes a 1-D
/// threadgroup (lsize.y == lsize.z == 1). Confirm against the dispatch code.
/// NOTE(review): `shared_vals` has `simd_size` slots, so this assumes
/// `simd_per_group <= simd_size`.
///
/// @param in        Input buffer holding `in_size` elements of type T.
/// @param out       Output buffer; threadgroup `gid.y` writes `out[gid.y]`.
/// @param in_size   Total number of elements in `in`.
/// @param row_size  Number of elements each threadgroup reduces (the last row
///                  may be shorter).
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& in_size [[buffer(2)]],
    const constant size_t& row_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];

  U total = Op::init;

  // Length of this threadgroup's row. Clamp explicitly so that a row starting
  // at or past the end of the input is empty. Otherwise `in_size - row_start`
  // would wrap around in size_t.
  const size_t row_start = size_t(gid.y) * row_size;
  int64_t actual_row = 0;
  if (row_start < in_size) {
    const size_t remaining = in_size - row_start;
    actual_row = int64_t(remaining < row_size ? remaining : row_size);
  }

  // Number of elements the whole threadgroup consumes per pass.
  const int64_t stride = int64_t(lsize.x) * N_READS;
  int64_t blocks = actual_row / stride;
  // Elements left after the full passes, minus those that precede this
  // thread's chunk. A value <= 0 means this thread has no tail work.
  int64_t extra = actual_row - blocks * stride - int64_t(lid.x) * N_READS;

  if (extra >= N_READS) {
    // This thread's share of the tail is a complete chunk: do it as one more
    // full pass instead of a partial read.
    blocks++;
    extra = 0;
  }

  in += row_start + size_t(lid.x) * N_READS;

  for (int64_t b = 0; b < blocks; b++) {
    for (int i = 0; i < N_READS; i++) {
      total = op(static_cast<U>(in[i]), total);
    }
    in += stride;
  }
  // Partial tail chunk. The loop body does not run when extra <= 0.
  for (int i = 0; i < extra; i++) {
    total = op(static_cast<U>(in[i]), total);
  }

  // Combine the per-thread partials within each simdgroup.
  total = op.simd_reduce(total);
  if (simd_per_group > 1) {
    // Lane 0 of each simdgroup publishes its partial result.
    if (simd_lane_id == 0) {
      shared_vals[simd_group_id] = total;
    }

    // After the barrier, the first simd_per_group threads reload the
    // published partials and everyone else contributes the identity. The
    // second simd_reduce leaves the full result in thread lid.x == 0.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    total = lid.x < simd_per_group ? shared_vals[lid.x] : Op::init;
    total = op.simd_reduce(total);
  }

  if (lid.x == 0) {
    out[gid.y] = total;
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
Op op
Definition binary.h:141
void all_reduce(const device T *in, device U *out, const constant size_t &in_size, const constant size_t &row_size, uint3 gid, uint3 lid, uint3 lsize, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:4