5 const device T* in [[buffer(0)]],
6 device U* out [[buffer(1)]],
7 const constant
size_t& reduction_size [[buffer(2)]],
8 const constant int64_t& reduction_stride [[buffer(3)]],
9 const constant
int* shape [[buffer(4)]],
10 const constant int64_t* strides [[buffer(5)]],
11 const constant
int& ndim [[buffer(6)]],
12 const constant
int* reduce_shape [[buffer(7)]],
13 const constant int64_t* reduce_strides [[buffer(8)]],
14 const constant
int& reduce_ndim [[buffer(9)]],
15 const constant
size_t& non_col_reductions [[buffer(10)]],
16 uint3 gid [[threadgroup_position_in_grid]],
17 uint3 gsize [[threadgroups_per_grid]],
18 uint3 lid [[thread_position_in_threadgroup]],
19 uint3 lsize [[threads_per_threadgroup]]) {
20 constexpr int n_reads = 4;
26 for (
int i = 0; i < n_reads; i++) {
30 IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
31 if (column >= reduction_stride) {
34 bool safe = column + n_reads <= reduction_stride;
36 IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
38 in += in_idx + column;
40 IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
41 loop.next(lid.y, reduce_shape, reduce_strides);
42 for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
43 row = in + loop.location();
45 for (
int i = 0; i < n_reads; i++) {
46 totals[i] = op(
static_cast<U
>(row[i]), totals[i]);
50 for (
int i = 0; i < n_reads; i++) {
52 (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
54 for (
int i = 0; i < n_reads; i++) {
55 totals[i] = op(vals[i], totals[i]);
58 loop.next(lsize.y, reduce_shape, reduce_strides);
63 threadgroup U shared_vals[32 * 8 * n_reads];
64 for (
int i = 0; i < n_reads; i++) {
65 shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
67 threadgroup_barrier(mem_flags::mem_threadgroup);
69 for (
int i = 0; i < n_reads; i++) {
70 totals[i] = shared_vals[lid.x * n_reads + i];
72 for (uint j = 1; j < lsize.y; j++) {
73 for (
int i = 0; i < n_reads; i++) {
75 op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
83 out += out_idx * IdxT(reduction_stride) + column;
85 for (
int i = 0; i < n_reads; i++) {
89 for (
int i = 0; column + i < reduction_stride; i++) {
98 const device T* in [[buffer(0)]],
99 device U* out [[buffer(1)]],
100 const constant
size_t& reduction_size [[buffer(2)]],
101 const constant
size_t& reduction_stride [[buffer(3)]],
102 const constant
int* shape [[buffer(4)]],
103 const constant int64_t* strides [[buffer(5)]],
104 const constant
int& ndim [[buffer(6)]],
105 const constant
int* reduce_shape [[buffer(7)]],
106 const constant int64_t* reduce_strides [[buffer(8)]],
107 const constant
int& reduce_ndim [[buffer(9)]],
108 const constant
size_t& non_col_reductions [[buffer(10)]],
109 const constant
size_t& out_size [[buffer(11)]],
110 uint3 gid [[threadgroup_position_in_grid]],
111 uint3 gsize [[threadgroups_per_grid]],
112 uint3 lid [[thread_position_in_threadgroup]],
113 uint3 lsize [[threads_per_threadgroup]]) {
118 IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
120 in += in_idx + lid.x;
123 IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
124 loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
125 for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
126 r += lsize.y * gsize.z) {
127 row = in + loop.location();
128 total = op(
static_cast<U
>(*row), total);
129 loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
132 threadgroup U shared_vals[32 * 32];
133 shared_vals[lid.y * lsize.x + lid.x] = total;
134 threadgroup_barrier(mem_flags::mem_threadgroup);
136 for (uint i = 1; i < lsize.y; i++) {
137 total = op(total, shared_vals[i * lsize.x + lid.x]);
139 out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
164 const device T* in [[buffer(0)]],
165 device U* out [[buffer(1)]],
166 const constant
size_t& reduction_size [[buffer(2)]],
167 const constant int64_t& reduction_stride [[buffer(3)]],
168 const constant
int* shape [[buffer(4)]],
169 const constant int64_t* strides [[buffer(5)]],
170 const constant
int& ndim [[buffer(6)]],
171 const constant
int* reduce_shape [[buffer(7)]],
172 const constant int64_t* reduce_strides [[buffer(8)]],
173 const constant
int& reduce_ndim [[buffer(9)]],
174 const constant
size_t& non_col_reductions [[buffer(10)]],
175 uint3 gid [[threadgroup_position_in_grid]],
176 uint3 gsize [[threadgroups_per_grid]],
177 uint simd_lane_id [[thread_index_in_simdgroup]],
178 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
180 constexpr int n_simdgroups = 8;
181 constexpr short tgp_size = n_simdgroups *
simd_size;
182 constexpr short n_reads = (BM * BN) / tgp_size;
183 constexpr short n_read_blocks = BN / n_reads;
185 threadgroup U shared_vals[BN * BM];
190 for (
int i = 0; i < n_reads; i++) {
191 totals[i] = Op::init;
194 short lid = simd_group_id *
simd_size + simd_lane_id;
195 short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
196 IdxT column = BN * gid.x + offset.x;
197 bool safe = column + n_reads <= reduction_stride;
199 IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
201 in += in_idx + column;
203 IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
204 loop.next(offset.y, reduce_shape, reduce_strides);
205 for (IdxT r = offset.y; r < total; r += BM) {
206 row = in + loop.location();
209 for (
int i = 0; i < n_reads; i++) {
210 totals[i] = op(
static_cast<U
>(row[i]), totals[i]);
214 for (
int i = 0; i < n_reads; i++) {
216 (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
218 for (
int i = 0; i < n_reads; i++) {
219 totals[i] = op(vals[i], totals[i]);
223 loop.next(BM, reduce_shape, reduce_strides);
230 constexpr int n_outputs = BN / n_simdgroups;
232 BM != 32 || n_outputs == n_reads,
233 "The tile should be selected such that n_outputs == n_reads");
234 for (
int i = 0; i < n_reads; i++) {
235 shared_vals[offset.y * BN + offset.x + i] = totals[i];
237 threadgroup_barrier(mem_flags::mem_threadgroup);
238 short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
239 for (
int i = 0; i < n_outputs; i++) {
241 op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
245 if (simd_lane_id == 0) {
246 IdxT out_column = BN * gid.x + out_offset.x;
247 out += out_idx * IdxT(reduction_stride) + out_column;
248 if (out_column + n_outputs <= reduction_stride) {
249 for (
int i = 0; i < n_outputs; i++) {
253 for (
int i = 0; out_column + i < reduction_stride; i++) {
264 short x_block = offset.x / n_reads;
265 for (
int i = 0; i < n_reads; i++) {
266 shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
268 threadgroup_barrier(mem_flags::mem_threadgroup);
270 for (
int i = 0; i < n_reads; i++) {
271 for (
int j = 1; j < BM; j++) {
273 op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
280 out += out_idx * IdxT(reduction_stride) + column;
282 for (
int i = 0; i < n_reads; i++) {
286 for (
int i = 0; column + i < reduction_stride; i++) {
303 const device T* in [[buffer(0)]],
304 device U* out [[buffer(1)]],
305 const constant
size_t& reduction_size [[buffer(2)]],
306 const constant int64_t& reduction_stride [[buffer(3)]],
307 const constant
int* shape [[buffer(4)]],
308 const constant int64_t* strides [[buffer(5)]],
309 const constant
int& ndim [[buffer(6)]],
310 const constant
int* reduce_shape [[buffer(7)]],
311 const constant int64_t* reduce_strides [[buffer(8)]],
312 const constant
int& reduce_ndim [[buffer(9)]],
313 const constant
size_t& non_col_reductions [[buffer(10)]],
314 const constant
size_t& out_size [[buffer(11)]],
315 uint3 gid [[threadgroup_position_in_grid]],
316 uint3 gsize [[threadgroups_per_grid]],
317 uint simd_lane_id [[thread_index_in_simdgroup]],
318 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
320 constexpr int n_simdgroups = 8;
321 constexpr short tgp_size = n_simdgroups *
simd_size;
322 constexpr short n_reads = (BM * BN) / tgp_size;
323 constexpr short n_read_blocks = BN / n_reads;
324 constexpr int n_outputs = BN / n_simdgroups;
325 constexpr short outer_blocks = 32;
326 static_assert(BM == 32,
"BM should be equal to 32");
328 threadgroup U shared_vals[BN * BM];
333 for (
int i = 0; i < n_reads; i++) {
334 totals[i] = Op::init;
337 short lid = simd_group_id *
simd_size + simd_lane_id;
338 short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
339 IdxT column = BN * gid.x + offset.x;
340 bool safe = column + n_reads <= reduction_stride;
342 IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
343 IdxT block_idx = full_idx / IdxT(out_size);
344 IdxT out_idx = full_idx % IdxT(out_size);
346 in += in_idx + column;
348 IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
349 loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
350 for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
351 row = in + loop.location();
354 for (
int i = 0; i < n_reads; i++) {
355 totals[i] = op(
static_cast<U
>(row[i]), totals[i]);
359 for (
int i = 0; i < n_reads; i++) {
361 (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
363 for (
int i = 0; i < n_reads; i++) {
364 totals[i] = op(vals[i], totals[i]);
368 loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
374 for (
int i = 0; i < n_reads; i++) {
375 shared_vals[offset.y * BN + offset.x + i] = totals[i];
377 threadgroup_barrier(mem_flags::mem_threadgroup);
378 short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
379 for (
int i = 0; i < n_outputs; i++) {
381 op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
385 if (simd_lane_id == 0) {
386 IdxT out_column = BN * gid.x + out_offset.x;
387 out += full_idx * IdxT(reduction_stride) + out_column;
388 if (out_column + n_outputs <= reduction_stride) {
389 for (
int i = 0; i < n_outputs; i++) {
393 for (
int i = 0; out_column + i < reduction_stride; i++) {
void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:97
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
We take the following simple looped approach:
Definition reduce_col.h:163
void col_reduce_2pass(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Definition reduce_col.h:302
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:4