MLX
Loading...
Searching...
No Matches
reduce_all.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
// All reduce helper

// Fold this thread's share of `in` into a single accumulator of type U.
// Each thread consumes N_READS contiguous elements per pass, then strides
// forward by the whole grid (grid_size * N_READS) until the input is done.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
    const device T* in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  Op op;
  U acc = Op::init;

  if (gid * N_READS < in_size) {
    in += gid * N_READS;

    // Full passes: every one of the N_READS loads is in bounds. Loads are
    // kept in a separate loop from the accumulation, matching the original
    // two-phase pattern.
    const int full_passes = (int)ceildiv(in_size, grid_size * N_READS) - 1;
    int pass = 0;
    for (; pass < full_passes; pass++) {
      U chunk[N_READS] = {op.init};

      for (int i = 0; i < N_READS; i++) {
        chunk[i] = static_cast<U>(in[i]);
      }
      for (int i = 0; i < N_READS; i++) {
        acc = op(chunk[i], acc);
      }

      in += grid_size * N_READS;
    }

    // Final (possibly partial) pass: clamp the load index so no read goes
    // past the end, then substitute the identity for the out-of-range lanes
    // during accumulation.
    size_t base = (gid + pass * (size_t)grid_size) * N_READS;
    if (base < in_size) {
      int valid = in_size - base;
      T chunk[N_READS];

      for (int i = 0, j = 0; i < N_READS; i++, j++) {
        j = j < valid ? j : valid - 1;
        chunk[i] = in[j];
      }
      for (int i = 0; i < N_READS; i++) {
        U v = i < valid ? chunk[i] : Op::init;
        acc = op(static_cast<U>(v), acc);
      }
    }
  }

  return acc;
}
52
// All reduce kernel

// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T* in [[buffer(0)]],
    device mlx_atomic<U>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U group_partials[simd_size];

  // Per-thread partial over this thread's slice of the input.
  U partial =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Step 1: reduce within each simd group; lane 0 of each group stashes
  // its group's result in threadgroup memory.
  partial = op.simd_reduce(partial);
  if (simd_lane_id == 0) {
    group_partials[simd_group_id] = partial;
  }

  // Step 2: after all groups have written, one more simd-level reduction
  // combines the per-group values (extra lanes contribute the identity).
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (lid < simd_per_group) {
    partial = group_partials[lid];
  } else {
    partial = op.init;
  }
  partial = op.simd_reduce(partial);

  // One thread per threadgroup folds the group result into the global
  // output atomically.
  if (lid == 0) {
    op.atomic_update(out, partial);
  }
}
93
// Variant of all_reduce that avoids atomics: each threadgroup writes its
// own partial result to `out[thread_group_id]`. The per-threadgroup
// partials are presumably reduced by a follow-up step not visible here.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  Op op;
  threadgroup U group_partials[simd_size];

  U partial =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Manual shuffle-based simd reduction (simd_add isn't supported for
  // uint64/int64 types, so halve the lane offset down to 1).
  for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
    partial = op(partial, simd_shuffle_down(partial, delta));
  }
  // Lane 0 of each simd group publishes its group's result.
  if (simd_lane_id == 0) {
    group_partials[simd_group_id] = partial;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Combine the per-simd-group values within the threadgroup, again with
  // explicit shuffles; extra lanes hold the identity.
  partial = lid < simd_per_group ? group_partials[lid] : op.init;
  for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
    partial = op(partial, simd_shuffle_down(partial, delta));
  }

  // One write per threadgroup — no atomics needed.
  if (lid == 0) {
    out[thread_group_id] = partial;
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:8
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:296
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta)
Definition utils.h:329
Op op
Definition binary.h:141
METAL_FUNC U per_thread_all_reduce(const device T *in, const device size_t &in_size, uint gid, uint grid_size)
Definition reduce_all.h:8
void all_reduce_no_atomics(const device T *in, device U *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id, uint thread_group_id)
Definition reduce_all.h:95
void all_reduce(const device T *in, device mlx_atomic< U > *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:61
Definition atomic.h:25