MLX
Loading...
Searching...
No Matches
reduce_row.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
4// Small row reductions
6
// Each thread reduces for one output.
//
// The thread at grid position `lid` folds `non_row_reductions` rows of
// `reduction_size` consecutive elements into a single value with `Op` and
// stores it in `out[lid]`. Row `r` starts at the offset that elem_to_loc
// yields for element index `lid + r * out_size` given `shape`, `strides`
// and `ndim`.
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]) {
  Op op;

  uint out_idx = lid;

  // Threads beyond the last output element have nothing to compute.
  if (out_idx >= out_size) {
    return;
  }

  U total_val = Op::init;

  // NOTE(review): both loop bounds are narrowed to short. Presumably the host
  // only selects this kernel when the counts are small - confirm at the
  // dispatch site.
  for (short r = 0; r < short(non_row_reductions); r++) {
    // Keep the offset in 64 bits; narrowing it to uint could wrap for inputs
    // whose strided offsets exceed 32 bits.
    size_t in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
    const device T* in_row = in + in_idx;

    for (short i = 0; i < short(reduction_size); i++) {
      total_val = op(static_cast<U>(in_row[i]), total_val);
    }
  }

  out[out_idx] = total_val;
}
40
// Each simdgroup reduces for one output.
//
// The output index is `simd_per_group * tid + simd_group_id`. The lanes of the
// simdgroup split the work in one of three ways, depending on
// `non_row_reductions`:
//   * exactly one row: lanes stride across the elements of that row;
//   * at least simd_size rows: lanes stride across rows and each lane walks
//     its rows in full;
//   * otherwise: several lanes cooperate on each row.
// The per-lane partials are then combined with `op.simd_reduce` and lane 0
// writes the result to `out[out_idx]`.
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  uint out_idx = simd_per_group * tid + simd_group_id;

  // out_idx is the same for every lane of a simdgroup, so the whole simdgroup
  // leaves together and the simd_reduce below never sees a partial group.
  if (out_idx >= out_size) {
    return;
  }

  U total_val = Op::init;

  // NOTE(review): the short(...) narrowing of the size_t counts presumably
  // relies on the host choosing this kernel only for moderate sizes - confirm.
  if (short(non_row_reductions) == 1) {
    // Offsets stay in 64 bits so that large strided offsets cannot wrap.
    size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
    const device T* in_row = in + in_idx;

    // Lanes interleave over the single row. The stride is simd_size rather
    // than a literal so it matches the lane math in the last branch.
    for (short i = simd_lane_id; i < short(reduction_size); i += simd_size) {
      total_val = op(static_cast<U>(in_row[i]), total_val);
    }
  } else if (short(non_row_reductions) >= simd_size) {
    // Lanes interleave over rows; each lane reduces its rows completely.
    for (short r = simd_lane_id; r < short(non_row_reductions);
         r += simd_size) {
      size_t in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T* in_row = in + in_idx;

      for (short i = 0; i < short(reduction_size); i++) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }
    }
  } else {
    // Fewer rows than lanes: give each row a team of `reductions_per_thread`
    // adjacent lanes that interleave over the row's elements.
    const short n_reductions =
        short(reduction_size) * short(non_row_reductions);
    const short reductions_per_thread =
        (n_reductions + simd_size - 1) / simd_size;

    const short r_st = simd_lane_id / reductions_per_thread;
    const short r_ed = short(non_row_reductions);
    const short r_jump = simd_size / reductions_per_thread;

    const short i_st = simd_lane_id % reductions_per_thread;
    const short i_ed = short(reduction_size);
    const short i_jump = reductions_per_thread;

    // Lanes left over when simd_size is not a multiple of the team size would
    // alias a row owned by another team, so they sit this out.
    if (r_st < r_jump) {
      for (short r = r_st; r < r_ed; r += r_jump) {
        size_t in_idx =
            elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
        const device T* in_row = in + in_idx;

        for (short i = i_st; i < i_ed; i += i_jump) {
          total_val = op(static_cast<U>(in_row[i]), total_val);
        }
      }
    }
  }

  total_val = op.simd_reduce(total_val);

  if (simd_lane_id == 0) {
    out[out_idx] = total_val;
  }
}
119
121// Large row reductions
123
// Computes one thread's partial result for a large row reduction.
//
// The row is selected by `tid` (element index `tid.y * out_size + tid.x`,
// mapped to an input offset by elem_to_loc). Within the row, thread `lid_x`
// of `lsize_x` reads blocks of N_READS consecutive elements, advancing by
// `lsize_x * N_READS` elements per step, and folds them with `Op`.
//
// Returns the thread's accumulated value; `Op::init` if the thread has no
// element to read.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(
    const device T* in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid) {
  Op op;

  // Each threadgroup handles 1 reduction
  // TODO: Specializing elem_to_loc would be slightly faster
  int idx = tid.y * out_size + tid.x;
  // Keep the offset in 64 bits; narrowing it to int could wrap for inputs
  // whose strided offsets exceed 31 bits.
  size_t extra_offset = elem_to_loc(idx, shape, strides, ndim);
  in += extra_offset + lid_x * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;

  // Number of steps in which every thread can read a full block of N_READS
  // elements without a bounds check (everything except the last step).
  const int full_steps =
      static_cast<int>(ceildiv(reduction_size, N_READS * lsize_x)) - 1;

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < full_steps; r++) {
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for (int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize_x * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  const size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
  if (reduction_index < reduction_size) {
    const int max_reads = reduction_size - reduction_index;

    // Stage the values in the accumulator type U. Staging them as T would
    // round-trip both the converted inputs and Op::init through T.
    U vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      // Clamp the index so slots past the end re-read the last valid element
      // instead of reading beyond the row; those slots are replaced below.
      const int read_idx = min(i, max_reads - 1);
      vals[i] = static_cast<U>(in[read_idx]);
    }
    for (int i = 0; i < N_READS; i++) {
      // Out-of-range slots contribute Op::init, the starting value.
      const U val = i < max_reads ? vals[i] : Op::init;
      total_val = op(val, total_val);
    }
  }

  return total_val;
}
178
// Large row reduction whose result is merged into `out` atomically.
//
// Each thread builds a partial with per_thread_row_reduce, the partials are
// folded across every simdgroup, and lane 0 of each simdgroup stages its value
// in threadgroup memory. If `reduction_size` exceeds simd_size, the staged
// values are folded a second time. The thread with lid.x == 0 finally calls
// `op.atomic_update(out, value, tid.x)`.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
    const device T* in [[buffer(0)]],
    device mlx_atomic<U>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Bound but not used by this kernel.
  (void)non_row_reductions;

  // One staging slot per simdgroup.
  threadgroup U simd_partials[simd_size];
  Op op;

  // Stage 1: per-thread partial, folded across the simdgroup.
  U acc = op.simd_reduce(per_thread_row_reduce<T, U, Op, N_READS>(
      in,
      reduction_size,
      out_size,
      shape,
      strides,
      ndim,
      lsize.x,
      lid.x,
      tid.xy));

  // Publish each simdgroup's value and wait until all of them are visible.
  const bool is_simd_leader = (simd_lane_id == 0);
  if (is_simd_leader) {
    simd_partials[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 2: fold the staged values. Threads without a matching slot start
  // from the operator's init value.
  if (reduction_size > simd_size) {
    acc = lid.x >= simd_per_group ? op.init : simd_partials[lid.x];
    acc = op.simd_reduce(acc);
  }

  // A single thread merges the threadgroup's result into the output.
  if (lid.x == 0) {
    op.atomic_update(out, acc, tid.x);
  }
}
230
// Large row reduction that writes one plain (non-atomic) value per
// threadgroup.
//
// Each thread builds a partial with per_thread_row_reduce. Partials are folded
// across the simdgroup with a simd_shuffle_down ladder, lane 0 of every
// simdgroup stages its value in threadgroup memory, and - when more than
// simd_size threads are needed to cover the row - the staged values are folded
// once more. The thread with lid.x == 0 stores the result at
// `out[ceildiv(gsize.y, lsize.y) * tid.x + tid.y]`.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Bound but not used by this kernel.
  (void)non_row_reductions;

  Op op;

  // One staging slot per simdgroup.
  threadgroup U local_vals[simd_size];
  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in,
      reduction_size,
      out_size,
      shape,
      strides,
      ndim,
      lsize.x,
      lid.x,
      tid.xy);

  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, i));
  }

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if thread group has multiple simd groups
  if (ceildiv(reduction_size, N_READS) > simd_size) {
    // Threads without a matching staged slot start from the init value.
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
      total_val = op(total_val, simd_shuffle_down(total_val, i));
    }
  }
  // Write row reduce output for threadgroup with 1st thread in thread group
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:8
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:77
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:296
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta)
Definition utils.h:329
Op op
Definition binary.h:141
void row_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 gsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:232
METAL_FUNC U per_thread_row_reduce(const device T *in, const constant size_t &reduction_size, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lsize_x, uint lid_x, uint2 tid)
Definition reduce_row.h:125
void row_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:180
void row_reduce_general_med(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:43
void row_reduce_general_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lid)
Definition reduce_row.h:9
Definition atomic.h:25