/**
 * The thread group collaboratively reduces across the rows with bounds
 * checking.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // Set up the accumulator registers
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = Op::init;
  }

  // Loop over the reduction size within the thread group
  for (int i = 0; i < blocks; i++) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }

      inputs[j] += lsize_x * N_READS;
    }
  }

  // Separate case for the final, possibly partial, block with bounds checking
  int index = lid_x * N_READS;
  if (index + N_READS <= extra) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  } else {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; index + i < extra; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  }
}
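// Illustration only (not part of the kernel): in the block loop above, thread
// `lid_x` visits row elements (lid_x + b * lsize_x) * N_READS + i for each
// block b and lane offset i in [0, N_READS). A rough host-side C++ sketch of
// one thread's partial reduction over a single row, with a hypothetical
// reducer `op` and identity `init`:
//
//   U partial = init;
//   for (int b = 0; b < blocks; b++)
//     for (int i = 0; i < N_READS; i++)
//       partial = op(static_cast<U>(row[(lid_x + b * lsize_x) * N_READS + i]), partial);
//   // the remaining `extra` elements are folded in with bounds checks, as above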
template <typename T, typename U, typename Op,
          int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Consecutive rows in a contiguous array: one input pointer per output row
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
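// Illustration only: with lid_x = 2, N_READS = 4 and reduction_size = 100, the
// pointers set up above are
//   inputs[0] = in + 8          (row 0, this thread's first column)
//   inputs[1] = in + 8 + 100    (row 1)
//   inputs[2] = in + 8 + 200    (row 2), and so on for all N_WRITES rows,
// i.e. one pointer per consecutive row, each starting at column lid_x * N_READS.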
template <typename T, typename U, typename Op,
          int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const size_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Consecutive rows in an arbitrarily ordered array: locate each row
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
/**
 * Reduce within the threadgroup.
 */
template <typename T, typename U, typename Op,
          int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Reduce within each simdgroup first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }

  // Then reduce across simdgroups
  if (simd_per_group > 1) {
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : Op::init;
    }

    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}
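// Illustration only: with simd_per_group = 4 and N_WRITES = 2, lane 0 of each
// simdgroup stores its two partials at
//   shared_vals[0..1] (simdgroup 0), shared_vals[2..3] (simdgroup 1),
//   shared_vals[4..5] (simdgroup 2), shared_vals[6..7] (simdgroup 3),
// and after the barrier threads 0..3 each reload one simdgroup's partials so a
// second simd_reduce can fold them into the final totals.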
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;
  for (int i = 0; i < blocks; i++) {
    U vals[N_READS];
    for (int j = 0; j < N_READS; j++) {
      vals[j] = row[j];
    }
    for (int j = 0; j < N_READS; j++) {
      total = op(vals[j], total);
    }
    row += N_READS;
  }
  for (int i = 0; i < extra; i++) {
    total = op(*row++, total);
  }
}
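// Hedged usage sketch (hypothetical helper, not part of this header): a full
// row reduction for a single thread could wrap thread_reduce with the same
// blocks/extra split used by the kernels below:
//
//   template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
//   METAL_FUNC U reduce_full_row(const device T* row, int row_size) {
//     U total = Op::init;
//     thread_reduce<T, U, Op, N_READS>(
//         total, row, row_size / N_READS, row_size % N_READS);
//     return total;
//   }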
template <
    typename T,
    typename U,
    typename Op,
    int NDIMS = 0,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  // Loop helper that walks the reduced dimensions (see utils.h)
  looped_elem_to_loc<NDIMS> loop;

  // Precompute some row reduction numbers
  const device T* row;
  int blocks = row_size / N_READS;
  int extra = row_size % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Simple loop over non_row_reductions; this thread reduces each row itself
    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Collaboratively reduce over non_row_reductions in the simdgroup: each
    // lane reduces every 32nd row, then a simple simd reduction finishes up
    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
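// Illustration only: the branch above takes the per-thread path when the
// problem is small in both directions or has few rows to accumulate, e.g.
//   row_size = 6,    non_row_reductions = 16   -> per-thread loop (16 < 32 and 6 <= 8)
//   row_size = 4096, non_row_reductions = 5    -> per-thread loop (5 <= 8)
//   row_size = 6,    non_row_reductions = 1000 -> simdgroup path; each lane
//                                                 handles every 32nd row (r += simd_size)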
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to the rows this threadgroup is responsible for
  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * reduction_size;
  out += out_idx;

  // Each thread reduces across the row
  int blocks = reduction_size / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    for (int i = 0; i < N_WRITES; i++) {
      out[i] = totals[i];
    }
  }
}
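// Illustration only: with reduction_size = 1000, lsize.x = 32 and N_READS = 4,
// blocks = 1000 / 128 = 7 and extra = 1000 - 7 * 128 = 104, so every thread
// makes 7 full strided passes and the bounds-checked tail covers the last 104
// elements of the row.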
template <
    typename T,
    typename U,
    typename Op,
    int NDIMS = 0,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);

  // Move to this output's row, offset by this thread's starting column
  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;

  // Loop helper that walks the reduced dimensions (see utils.h)
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;
  int blocks = row_size / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (size_t i = 0; i < non_row_reductions; i++) {
    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);

    // Each thread reduces across the row
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across rows
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}
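// Illustration only: ignoring the threadgroup parallelism, for one output the
// kernel above computes the equivalent of this serial sketch (hypothetical
// host-side code; `loc(i)` stands in for the loop.location(...) offset):
//
//   U total = Op::init;
//   for (size_t i = 0; i < non_row_reductions; i++) {
//     const T* row = in + loc(i);
//     for (size_t k = 0; k < row_size; k++)
//       total = op(static_cast<U>(row[k]), total);
//   }
//   out[out_idx] = total;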