20 thread U totals[N_WRITES],
21 const device T* inputs[N_WRITES],
29 for (
int i = 0; i < N_WRITES; i++) {
34 for (
int i = 0; i < blocks; i++) {
35 for (
int j = 0; j < N_WRITES; j++) {
36 for (
int i = 0; i < N_READS; i++) {
37 totals[j] = op(
static_cast<U
>(inputs[j][i]), totals[j]);
40 inputs[j] += lsize_x * N_READS;
45 int index = lid_x * N_READS;
46 if (index + N_READS <= extra) {
47 for (
int j = 0; j < N_WRITES; j++) {
48 for (
int i = 0; i < N_READS; i++) {
49 totals[j] = op(
static_cast<U
>(inputs[j][i]), totals[j]);
53 for (
int j = 0; j < N_WRITES; j++) {
54 for (
int i = 0; index + i < extra; i++) {
55 totals[j] = op(
static_cast<U
>(inputs[j][i]), totals[j]);
71 thread U totals[N_WRITES],
73 const constant
size_t& reduction_size,
79 const device T* inputs[N_WRITES];
80 inputs[0] = in + lid_x * N_READS;
81 for (
int i = 1; i < N_READS; i++) {
82 inputs[i] = inputs[i - 1] + reduction_size;
86 totals, inputs, blocks, extra, lsize_x, lid_x);
99 thread U totals[N_WRITES],
101 const int64_t row_idx,
104 const constant
int* shape,
105 const constant int64_t* strides,
106 const constant
int& ndim,
110 const device T* inputs[N_WRITES];
111 in += lid_x * N_READS;
112 for (
int i = 0; i < N_READS; i++) {
113 inputs[i] = in +
elem_to_loc(row_idx + i, shape, strides, ndim);
117 totals, inputs, blocks, extra, lsize_x, lid_x);
130 thread U totals[N_WRITES],
131 threadgroup U* shared_vals,
132 uint3 lid [[thread_position_in_threadgroup]],
133 uint simd_lane_id [[thread_index_in_simdgroup]],
134 uint simd_per_group [[simdgroups_per_threadgroup]],
135 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
139 for (
int i = 0; i < N_WRITES; i++) {
140 totals[i] = op.simd_reduce(totals[i]);
144 if (simd_per_group > 1) {
145 if (simd_lane_id == 0) {
146 for (
int i = 0; i < N_WRITES; i++) {
147 shared_vals[simd_group_id * N_WRITES + i] = totals[i];
150 threadgroup_barrier(mem_flags::mem_threadgroup);
153 for (
int i = 0; i < N_WRITES; i++) {
154 values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
158 for (
int i = 0; i < N_WRITES; i++) {
159 totals[i] = op.simd_reduce(values[i]);
164template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
168 for (
int i = 0; i < blocks; i++) {
170 for (
int j = 0; j < N_READS; j++) {
173 for (
int j = 0; j < N_READS; j++) {
174 total = op(vals[j], total);
178 for (
int i = 0; i < extra; i++) {
179 total = op(*row++, total);
200 const device T* in [[buffer(0)]],
201 device U* out [[buffer(1)]],
202 const constant int64_t& row_size [[buffer(2)]],
203 const constant int64_t& non_row_reductions [[buffer(3)]],
204 const constant
int* shape [[buffer(4)]],
205 const constant int64_t* strides [[buffer(5)]],
206 const constant
int& ndim [[buffer(6)]],
207 const constant
int* reduce_shape [[buffer(7)]],
208 const constant int64_t* reduce_strides [[buffer(8)]],
209 const constant
int& reduce_ndim [[buffer(9)]],
210 uint simd_lane_id [[thread_index_in_simdgroup]],
211 uint3 gid [[threadgroup_position_in_grid]],
212 uint3 gsize [[threadgroups_per_grid]],
213 uint3 tid [[thread_position_in_grid]],
214 uint3 tsize [[threads_per_grid]]) {
217 U total_val = Op::init;
222 int blocks = IdxT(row_size) / N_READS;
223 int extra = IdxT(row_size) % N_READS;
225 if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
227 IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
230 for (uint r = 0; r < non_row_reductions; r++) {
231 row = in + loop.location();
233 loop.next(reduce_shape, reduce_strides);
236 out[out_idx] = total_val;
240 IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
243 loop.next(simd_lane_id, reduce_shape, reduce_strides);
245 for (uint r = simd_lane_id; r < non_row_reductions; r +=
simd_size) {
246 row = in + loop.location();
248 loop.next(
simd_size, reduce_shape, reduce_strides);
251 total_val = op.simd_reduce(total_val);
253 if (simd_lane_id == 0) {
254 out[out_idx] = total_val;
263 typename IdxT = int64_t,
267 const device T* in [[buffer(0)]],
268 device U* out [[buffer(1)]],
269 const constant
size_t& reduction_size [[buffer(2)]],
270 const constant int64_t& out_size [[buffer(3)]],
271 uint3 gid [[threadgroup_position_in_grid]],
272 uint3 gsize [[threadgroups_per_grid]],
273 uint3 lid [[thread_position_in_threadgroup]],
274 uint3 lsize [[threads_per_threadgroup]],
275 uint simd_lane_id [[thread_index_in_simdgroup]],
276 uint simd_per_group [[simdgroups_per_threadgroup]],
277 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
278 threadgroup U shared_vals[
simd_size * N_WRITES];
282 IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
283 if (out_idx + N_WRITES > out_size) {
284 out_idx = out_size - N_WRITES;
286 in += out_idx * IdxT(reduction_size);
290 int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
291 int extra = reduction_size - blocks * (lsize.x * N_READS);
293 totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
297 totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
301 for (
int i = 0; i < N_WRITES; i++) {
315 const device T* in [[buffer(0)]],
316 device U* out [[buffer(1)]],
317 const constant int64_t& row_size [[buffer(2)]],
318 const constant int64_t& non_row_reductions [[buffer(3)]],
319 const constant
int* shape [[buffer(4)]],
320 const constant int64_t* strides [[buffer(5)]],
321 const constant
int& ndim [[buffer(6)]],
322 const constant
int* reduce_shape [[buffer(7)]],
323 const constant int64_t* reduce_strides [[buffer(8)]],
324 const constant
int& reduce_ndim [[buffer(9)]],
325 uint3 gid [[threadgroup_position_in_grid]],
326 uint3 gsize [[threadgroups_per_grid]],
327 uint3 lid [[thread_position_in_threadgroup]],
328 uint3 lsize [[threads_per_threadgroup]],
329 uint simd_lane_id [[thread_index_in_simdgroup]],
330 uint simd_per_group [[simdgroups_per_threadgroup]],
331 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
336 IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
344 int blocks = IdxT(row_size) / (lsize.x * N_READS);
345 int extra = row_size - blocks * (lsize.x * N_READS);
347 for (IdxT i = 0; i < non_row_reductions; i++) {
348 row = in + loop.location();
353 &row_total, &row, blocks, extra, lsize.x, lid.x);
356 total = op(total, row_total);
358 loop.next(reduce_shape, reduce_strides);
363 &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
367 out[out_idx] = total;
static constexpr int REDUCE_N_READS
Definition defines.h:12
static constexpr int REDUCE_N_WRITES
Definition defines.h:13
void row_reduce_small(const device T *in, device U *out, const constant int64_t &row_size, const constant int64_t &non_row_reductions, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, uint simd_lane_id, uint3 gid, uint3 gsize, uint3 tid, uint3 tsize)
Definition reduce_row.h:199
void row_reduce_looped(const device T *in, device U *out, const constant int64_t &row_size, const constant int64_t &non_row_reductions, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:314
void row_reduce_simple(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:266
METAL_FUNC void per_thread_row_reduce(thread U totals[N_WRITES], const device T *inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x)
The thread group collaboratively reduces across the rows with bounds checking.
Definition reduce_row.h:19
METAL_FUNC void threadgroup_reduce(thread U totals[N_WRITES], threadgroup U *shared_vals, uint3 lid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Reduce within the threadgroup.
Definition reduce_row.h:129
METAL_FUNC void thread_reduce(thread U &total, const device T *row, int blocks, int extra)
Definition reduce_row.h:166