template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  constexpr int n_reads = 4;
  Op op;
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  // Running partials, one per column handled by this thread.
  U totals[n_reads];
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
  if (column >= reduction_stride) {
    return;
  }
  bool safe = column + n_reads <= reduction_stride;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Each thread accumulates every lsize.y-th row of its n_reads columns.
  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(lid.y, reduce_shape, reduce_strides);
  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(lsize.y, reduce_shape, reduce_strides);
  }

  // Combine the per-row partials across the threadgroup's y dimension.
  if (lsize.y > 1) {
    threadgroup U shared_vals[32 * 8 * n_reads];
    for (int i = 0; i < n_reads; i++) {
      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = shared_vals[lid.x * n_reads + i];
      }
      for (uint j = 1; j < lsize.y; j++) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] =
              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
                 totals[i]);
        }
      }
    }
  }

  // Write the reduced columns.
  if (lid.y == 0) {
    out += out_idx * reduction_stride + column;
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
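
// Illustration (not from reduce_col.h): a minimal CPU reference for what a
// column reduction computes, assuming a contiguous row-major input. It
// ignores the strided shape/reduce_shape bookkeeping that elem_to_loc and the
// loop helper perform in the kernel above; all names are illustrative.
#include <cstddef>
#include <vector>

template <typename T, typename U, typename Op>
void col_reduce_reference(
    const std::vector<T>& in,
    std::vector<U>& out,
    std::size_t n_out_rows,        // non-reduced rows in the output
    std::size_t reduction_size,    // rows folded into each output row
    std::size_t reduction_stride,  // columns (fastest-moving axis)
    Op op,
    U init) {
  for (std::size_t o = 0; o < n_out_rows; o++) {
    for (std::size_t col = 0; col < reduction_stride; col++) {
      U total = init;
      for (std::size_t r = 0; r < reduction_size; r++) {
        std::size_t idx = (o * reduction_size + r) * reduction_stride + col;
        total = op(static_cast<U>(in[idx]), total);
      }
      out[o * reduction_stride + col] = total;
    }
  }
}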
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  Op op;
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;
  U total = Op::init;

  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + lid.x;

  // Each gid.z block accumulates a disjoint strided subset of the rows.
  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
       r += lsize.y * gsize.z) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    total = op(static_cast<U>(*row), total);
    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup's y dimension and write one partial result
  // per gid.z block; a follow-up pass combines the blocks.
  threadgroup U shared_vals[32 * 32];
  shared_vals[lid.y * lsize.x + lid.x] = total;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (lid.y == 0) {
    for (uint i = 1; i < lsize.y; i++) {
      total = op(total, shared_vals[i * lsize.x + lid.x]);
    }
    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
  }
}
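
// Illustration (not from reduce_col.h): col_reduce_longcolumn writes one
// partial plane per gid.z block, so a follow-up reduction still has to fold
// those planes element-wise into the final output. This CPU sketch shows that
// fold; the names and the exact plane layout are assumptions for the example,
// not the library's second-pass kernel.
#include <cstddef>
#include <vector>

template <typename U, typename Op>
void combine_partial_planes(
    const std::vector<U>& partials,  // n_blocks planes, each plane_size long
    std::vector<U>& out,             // plane_size elements
    std::size_t n_blocks,
    std::size_t plane_size,
    Op op,
    U init) {
  for (std::size_t i = 0; i < plane_size; i++) {
    U total = init;
    for (std::size_t b = 0; b < n_blocks; b++) {
      total = op(partials[b * plane_size + i], total);
    }
    out[i] = total;
  }
}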
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  Op op;
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // Map the flat thread index onto an n_reads-wide column block (offset.x)
  // and a row within the BM x BN tile (offset.y).
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (size_t r = offset.y; r < total; r += BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(BM, reduce_shape, reduce_strides);
  }

  if (BM == 32) {
    // With BM == 32 each simdgroup can fold a full tile column with a simd
    // reduction and write n_outputs contiguous results.
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    if (simd_lane_id == 0) {
      size_t out_column = BN * gid.x + out_offset.x;
      out += out_idx * reduction_stride + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  } else {
    // Otherwise stage the partials in threadgroup memory and let the threads
    // of the first tile row fold the BM rows serially.
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }

      out += out_idx * reduction_stride + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}
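
// Illustration (not from reduce_col.h): how the flat thread index is mapped
// onto the BM x BN tile in col_reduce_looped. The tile sizes below are only
// an example; simd_size is 32 on Apple GPUs.
#include <cstdio>

int main() {
  constexpr int simd_size = 32;
  constexpr int n_simdgroups = 8;
  constexpr int BM = 8, BN = 128;                     // example tile
  constexpr int tgp_size = n_simdgroups * simd_size;  // 256 threads
  constexpr int n_reads = (BM * BN) / tgp_size;       // 4 columns per thread
  constexpr int n_read_blocks = BN / n_reads;         // 32 column blocks

  for (int lid = 0; lid < tgp_size; lid++) {
    int col = (lid % n_read_blocks) * n_reads;  // offset.x
    int row = lid / n_read_blocks;              // offset.y
    if (lid < 3 || lid > tgp_size - 3) {        // print a few threads
      std::printf("thread %3d -> columns [%3d, %3d), tile row %d\n",
                  lid, col, col + n_reads, row);
    }
  }
  return 0;
}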
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  static_assert(BM == 32, "BM should be equal to 32");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  Op op;
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  // Split the flat grid index into the partial block (block_idx) and the
  // actual output row (out_idx).
  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
  size_t block_idx = full_idx / out_size;
  size_t out_idx = full_idx % out_size;
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (size_t r = offset.y + block_idx * BM; r < total;
       r += outer_blocks * BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // Accumulate the tile with a simd reduction and write this block's partial
  // results; a second pass reduces across the outer_blocks partials.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  if (simd_lane_id == 0) {
    size_t out_column = BN * gid.x + out_offset.x;
    out += full_idx * reduction_stride + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
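
// Illustration (not from reduce_col.h): the index arithmetic behind
// col_reduce_2pass. Each of the outer_blocks == 32 blocks accumulates a
// strided subset of the rows, producing 32 partial planes that a second pass
// still has to fold. The numbers here are made up for the example.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t BM = 32, outer_blocks = 32;
  const std::size_t out_size = 1000;      // output rows
  const std::size_t total_rows = 100000;  // non_col_reductions * reduction_size

  // A grid of outer_blocks * out_size threadgroups along (y, z) is split into
  // a partial-block index and an output row, exactly as in the kernel.
  const std::size_t full_idx = 33333;
  const std::size_t block_idx = full_idx / out_size;  // which partial plane
  const std::size_t out_idx = full_idx % out_size;    // which output row
  std::printf("full_idx %zu -> block %zu, output row %zu\n",
              full_idx, block_idx, out_idx);

  // Rows visited by this block for tile row offset.y == 0.
  std::size_t count = 0;
  for (std::size_t r = block_idx * BM; r < total_rows;
       r += outer_blocks * BM) {
    count++;
  }
  std::printf("block %zu touches %zu of %zu rows\n",
              block_idx, count, total_rows);
  return 0;
}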