10 const device T* in [[buffer(0)]],
11 device U* out [[buffer(1)]],
12 const constant
size_t& reduction_size [[buffer(2)]],
13 const constant
size_t& reduction_stride [[buffer(3)]],
14 const constant
int* shape [[buffer(4)]],
15 const constant
size_t* strides [[buffer(5)]],
16 const constant
int& ndim [[buffer(6)]],
17 const constant
int* reduce_shape [[buffer(7)]],
18 const constant
size_t* reduce_strides [[buffer(8)]],
19 const constant
int& reduce_ndim [[buffer(9)]],
20 const constant
size_t& non_col_reductions [[buffer(10)]],
21 uint3 gid [[threadgroup_position_in_grid]],
22 uint3 gsize [[threadgroups_per_grid]],
23 uint simd_lane_id [[thread_index_in_simdgroup]],
24 uint simd_group_id [[simdgroup_index_in_threadgroup]],
25 uint3 tid [[thread_position_in_grid]],
26 uint3 tsize [[threads_per_grid]]) {
32 if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
34 for (
int i = 0; i < 31; i++) {
38 short stride = reduction_stride;
39 short size = reduction_size;
40 short blocks = stride / N_READS;
41 short extra = stride - blocks * N_READS;
43 size_t out_idx = tid.x + tsize.y * size_t(tid.y);
46 for (uint r = 0; r < non_col_reductions; r++) {
47 row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
49 for (
short i = 0; i < size; i++) {
50 for (
short j = 0; j < blocks; j++) {
51 for (
short k = 0; k < N_READS; k++) {
52 totals[j * N_READS + k] =
53 op(totals[j * N_READS + k],
54 static_cast<U
>(row[i * stride + j * N_READS + k]));
57 for (
short k = 0; k < extra; k++) {
58 totals[blocks * N_READS + k] =
59 op(totals[blocks * N_READS + k],
60 static_cast<U
>(row[i * stride + blocks * N_READS + k]));
64 loop.
next(reduce_shape, reduce_strides);
66 out += out_idx * reduction_stride;
67 for (
short j = 0; j < stride; j++) {
73 else if (reduction_size * non_col_reductions < 32) {
75 for (
int i = 0; i < N_READS; i++) {
79 short size = reduction_size;
80 size_t offset = size_t(tid.x) * N_READS;
81 bool safe = offset + N_READS <= reduction_stride;
82 short extra = reduction_stride - offset;
84 size_t out_idx = tid.y + tsize.z * size_t(tid.z);
85 in +=
elem_to_loc(out_idx, shape, strides, ndim) + offset;
87 for (uint r = 0; r < non_col_reductions; r++) {
88 row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
91 for (
short i = 0; i < size; i++) {
92 for (
short j = 0; j < N_READS; j++) {
94 op(
static_cast<U
>(row[i * reduction_stride + j]), totals[j]);
98 for (
short i = 0; i < size; i++) {
99 for (
short j = 0; j < extra; j++) {
101 op(
static_cast<U
>(row[i * reduction_stride + j]), totals[j]);
106 loop.
next(reduce_shape, reduce_strides);
108 out += out_idx * reduction_stride + offset;
110 for (
short i = 0; i < N_READS; i++) {
114 for (
short i = 0; i < extra; i++) {
122 threadgroup U shared_vals[1024];
124 for (
int i = 0; i < N_READS; i++) {
125 totals[i] = Op::init;
128 short stride = reduction_stride;
129 short lid = simd_group_id *
simd_size + simd_lane_id;
130 short2 tile((stride + N_READS - 1) / N_READS, 32);
131 short2 offset((lid % tile.x) * N_READS, lid / tile.x);
132 short sm_stride = tile.x * N_READS;
133 bool safe = offset.x + N_READS <= stride;
135 size_t out_idx = gid.y + gsize.y * size_t(gid.z);
136 in +=
elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
139 size_t total = non_col_reductions * reduction_size;
140 loop.
next(offset.y, reduce_shape, reduce_strides);
141 for (
size_t r = offset.y; r < total; r +=
simd_size) {
142 row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
145 for (
int i = 0; i < N_READS; i++) {
146 totals[i] =
op(
static_cast<U
>(row[i]), totals[i]);
150 for (
int i = 0; i < N_READS; i++) {
151 vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) :
op.init;
153 for (
int i = 0; i < N_READS; i++) {
154 totals[i] =
op(vals[i], totals[i]);
164 for (
int i = 0; i < N_READS; i++) {
165 shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
167 threadgroup_barrier(mem_flags::mem_threadgroup);
168 for (
int i = 0; i < N_READS; i++) {
169 totals[i] =
op.simd_reduce(
170 shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
174 if (simd_lane_id == 0) {
175 short column = simd_group_id * N_READS;
176 out += out_idx * reduction_stride + column;
177 if (column + N_READS <= stride) {
178 for (
int i = 0; i < N_READS; i++) {
182 for (
int i = 0; column + i < stride; i++) {
203 const device T* in [[buffer(0)]],
204 device U* out [[buffer(1)]],
205 const constant
size_t& reduction_size [[buffer(2)]],
206 const constant
size_t& reduction_stride [[buffer(3)]],
207 const constant
int* shape [[buffer(4)]],
208 const constant
size_t* strides [[buffer(5)]],
209 const constant
int& ndim [[buffer(6)]],
210 const constant
int* reduce_shape [[buffer(7)]],
211 const constant
size_t* reduce_strides [[buffer(8)]],
212 const constant
int& reduce_ndim [[buffer(9)]],
213 const constant
size_t& non_col_reductions [[buffer(10)]],
214 uint3 gid [[threadgroup_position_in_grid]],
215 uint3 gsize [[threadgroups_per_grid]],
216 uint simd_lane_id [[thread_index_in_simdgroup]],
217 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
219 constexpr int n_simdgroups = 4;
220 constexpr short tgp_size = n_simdgroups *
simd_size;
221 constexpr short n_reads = (BM * BN) / tgp_size;
222 constexpr short n_read_blocks = BN / n_reads;
224 threadgroup U shared_vals[BN * BM];
229 for (
int i = 0; i < n_reads; i++) {
230 totals[i] = Op::init;
233 short lid = simd_group_id *
simd_size + simd_lane_id;
234 short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
235 size_t column = BN * gid.x + offset.x;
236 bool safe = column + n_reads <= reduction_stride;
238 size_t out_idx = gid.y + gsize.y * size_t(gid.z);
239 size_t in_idx =
elem_to_loc(out_idx, shape, strides, ndim);
240 in += in_idx + column;
242 size_t total = non_col_reductions * reduction_size;
243 loop.
next(offset.y, reduce_shape, reduce_strides);
244 for (
size_t r = offset.y; r < total; r += BM) {
245 row = in + loop.
location(r, reduce_shape, reduce_strides, reduce_ndim);
248 for (
int i = 0; i < n_reads; i++) {
249 totals[i] =
op(
static_cast<U
>(row[i]), totals[i]);
253 for (
int i = 0; i < n_reads; i++) {
255 (column + i < reduction_stride) ? static_cast<U>(row[i]) :
op.init;
257 for (
int i = 0; i < n_reads; i++) {
258 totals[i] =
op(vals[i], totals[i]);
262 loop.
next(BM, reduce_shape, reduce_strides);
269 constexpr int n_outputs = BN / n_simdgroups;
271 BM != 32 || n_outputs == n_reads,
272 "The tile should be selected such that n_outputs == n_reads");
273 for (
int i = 0; i < n_reads; i++) {
274 shared_vals[offset.y * BN + offset.x + i] = totals[i];
276 threadgroup_barrier(mem_flags::mem_threadgroup);
277 short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
278 for (
int i = 0; i < n_outputs; i++) {
280 op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
284 if (simd_lane_id == 0) {
285 size_t out_column = BN * gid.x + out_offset.x;
286 out += out_idx * reduction_stride + out_column;
287 if (out_column + n_outputs <= reduction_stride) {
288 for (
int i = 0; i < n_outputs; i++) {
292 for (
int i = 0; out_column + i < reduction_stride; i++) {
303 short x_block = offset.x / n_reads;
304 for (
int i = 0; i < n_reads; i++) {
305 shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
307 threadgroup_barrier(mem_flags::mem_threadgroup);
309 for (
int i = 0; i < n_reads; i++) {
310 for (
int j = 1; j < BM; j++) {
312 op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
319 out += out_idx * reduction_stride + column;
321 for (
int i = 0; i < n_reads; i++) {
325 for (
int i = 0; column + i < reduction_stride; i++) {
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Our approach is the following simple looped approach:
Definition reduce_col.h:202
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)
Definition reduce_col.h:9