// ---------------------------------------------------------------------------
// NOTE(review): this span is a garbled extraction of a Metal compute kernel —
// the original file's line numbers (e.g. "178") are fused into the text, the
// kernel's template header / name line precedes this view, and many interior
// lines (else-keywords, closing braces, and the declarations of `values`,
// `op`, `prefix` and `simd_groups`) are missing. Only comments are added
// here; the code bytes are left untouched.
//
// Shape of the visible logic: a scan (prefix reduction with operator `Op`)
// over a contiguous axis of length `axis_size`, one row per threadgroup,
// N_READS elements per thread per tile, with `reverse` selecting a scan from
// the end of the axis and (presumably) `inclusive` selecting the aligned vs.
// shifted-by-one write path — TODO confirm template parameters in full file.
// ---------------------------------------------------------------------------
178 const device T* in [[buffer(0)]],
179 device U* out [[buffer(1)]],
180 const constant
size_t& axis_size [[buffer(2)]],
181 uint gid [[thread_position_in_grid]],
182 uint lid [[thread_position_in_threadgroup]],
183 uint lsize [[threads_per_threadgroup]],
184 uint
simd_size [[threads_per_simdgroup]],
185 uint simd_lane_id [[thread_index_in_simdgroup]],
186 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Each threadgroup handles one row: advance both pointers to this row's start.
190 in += (gid / lsize) * axis_size;
191 out += (gid / lsize) * axis_size;
// One partial per simdgroup; 32 covers the max simdgroups per threadgroup.
199 threadgroup U simdgroup_sums[32];
// Process the axis in tiles of lsize * N_READS elements.
213 for (uint r = 0; r <
ceildiv(axis_size, N_READS * lsize); r++) {
215 uint offset = r * lsize * N_READS + lid * N_READS;
// Load N_READS values into registers. The reverse path reads mirrored from
// the end of the axis; load_safe pads out-of-range slots with Op::init.
219 if ((offset + N_READS) < axis_size) {
220 load_unsafe<T, U, N_READS, reverse>(
221 values, in + axis_size - offset - N_READS);
223 load_safe<T, U, N_READS, reverse>(
225 in + axis_size - offset - N_READS,
// Forward (non-reverse) load path.
231 if ((offset + N_READS) < axis_size) {
232 load_unsafe<T, U, N_READS, reverse>(values, in + offset);
234 load_safe<T, U, N_READS, reverse>(
235 values, in + offset, offset, axis_size, Op::init);
// Thread-local inclusive scan across the N_READS register values.
240 for (
int i = 1; i < N_READS; i++) {
241 values[i] =
op(values[i], values[i - 1]);
// Exclusive simd scan of each thread's total -> combined value of all
// preceding threads in this simdgroup.
245 U prev_thread =
op.simd_exclusive_scan(values[N_READS - 1]);
// Publish each simdgroup's total, then scan those totals with simdgroup 0.
249 simdgroup_sums[simd_group_id] =
op(prev_thread, values[N_READS - 1]);
251 threadgroup_barrier(mem_flags::mem_threadgroup);
254 if (simd_group_id == 0) {
255 U prev_simdgroup =
op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
256 simdgroup_sums[simd_lane_id] = prev_simdgroup;
258 threadgroup_barrier(mem_flags::mem_threadgroup);
// Combine the three prefixes into each register value: the carried tile
// prefix, the preceding-simdgroups total, and the preceding-threads total.
261 for (
int i = 0; i < N_READS; i++) {
262 values[i] =
op(values[i], prefix);
263 values[i] =
op(values[i], simdgroup_sums[simd_group_id]);
264 values[i] =
op(values[i], prev_thread);
// Write results (reverse, aligned path — presumably the inclusive case).
270 if ((offset + N_READS) < axis_size) {
271 write_unsafe<U, N_READS, reverse>(
272 values, out + axis_size - offset - N_READS);
274 write_safe<U, N_READS, reverse>(
275 values, out + axis_size - offset - N_READS, offset, axis_size);
// Exclusive reverse: seed the boundary element and shift writes by one.
278 if (lid == 0 && offset == 0) {
279 out[axis_size - 1] = Op::init;
281 if ((offset + N_READS + 1) < axis_size) {
282 write_unsafe<U, N_READS, reverse>(
283 values, out + axis_size - offset - 1 - N_READS);
285 write_safe<U, N_READS, reverse>(
287 out + axis_size - offset - 1 - N_READS,
// Forward write paths (aligned, then shifted-by-one with the first element
// seeded — the seeding assignment falls in a gap of this extract).
294 if ((offset + N_READS) < axis_size) {
295 write_unsafe<U, N_READS, reverse>(values, out + offset);
297 write_safe<U, N_READS, reverse>(
298 values, out + offset, offset, axis_size);
301 if (lid == 0 && offset == 0) {
304 if ((offset + N_READS + 1) < axis_size) {
305 write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
307 write_safe<U, N_READS, reverse>(
308 values, out + offset + 1, offset + 1, axis_size);
// Carry this tile's grand total into `prefix` for the next iteration via
// shared slot 0, written by the last lane of the last simdgroup.
312 threadgroup_barrier(mem_flags::mem_threadgroup);
315 if (simd_group_id == simd_groups - 1 && simd_lane_id ==
simd_size - 1) {
316 simdgroup_sums[0] = values[N_READS - 1];
318 threadgroup_barrier(mem_flags::mem_threadgroup);
319 prefix = simdgroup_sums[0];
// ---------------------------------------------------------------------------
// NOTE(review): garbled extraction of a second Metal kernel — a scan along a
// *strided* (non-contiguous) axis, staged through threadgroup memory. The
// kernel's template header / name line precedes this view and several
// interior lines (declarations of `values`, `prefix`, `op`, else-keywords,
// closing braces, and the RHS of some assignments) are missing. Only
// comments are added; code bytes are untouched.
// ---------------------------------------------------------------------------
331 const device T* in [[buffer(0)]],
332 device U* out [[buffer(1)]],
333 const constant
size_t& axis_size [[buffer(2)]],
334 const constant
size_t& stride [[buffer(3)]],
335 uint2 gid [[threadgroup_position_in_grid]],
336 uint2 lid [[thread_position_in_threadgroup]],
337 uint2 lsize [[threads_per_threadgroup]],
338 uint
simd_size [[threads_per_simdgroup]]) {
// Tile staging buffer in threadgroup memory. The extra "+ N_READS * 32"
// term looks like padding — presumably to avoid bank conflicts on the
// transposed accesses below; TODO confirm.
342 threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
// Per-lane running prefix, one slot per N_READS column; start at identity.
345 for (
int i = 0; i < N_READS; i++) {
346 prefix[i] = Op::init;
// Start of this batch row, and the x-origin of this threadgroup's tile.
350 int offset = gid.y * axis_size * stride;
351 int global_index_x = gid.x * lsize.y * N_READS;
// Walk the scan axis in chunks of simd_size rows.
353 for (uint j = 0; j < axis_size; j +=
simd_size) {
355 uint index_y = j + lid.y;
356 uint check_index_y = index_y;
357 uint index_x = global_index_x + lid.x * N_READS;
// Reverse scan mirrors the row index (the guard around this line falls in
// a gap of this extract).
359 index_y = axis_size - 1 - index_y;
// Cooperative load of the tile into shared memory: fast path when the whole
// N_READS run is in-bounds, otherwise element-wise with bounds checks (the
// pad value for out-of-range slots is in a gap — presumably Op::init).
363 if (check_index_y < axis_size && (index_x + N_READS) < stride) {
364 for (
int i = 0; i < N_READS; i++) {
365 read_buffer[lid.y *
simd_size * N_READS + lid.x * N_READS + i] =
366 in[offset + index_y * stride + index_x + i];
369 for (
int i = 0; i < N_READS; i++) {
370 if (check_index_y < axis_size && (index_x + i) < stride) {
371 read_buffer[lid.y *
simd_size * N_READS + lid.x * N_READS + i] =
372 in[offset + index_y * stride + index_x + i];
374 read_buffer[lid.y *
simd_size * N_READS + lid.x * N_READS + i] =
379 threadgroup_barrier(mem_flags::mem_threadgroup);
// Read back transposed (lid.x / lid.y swapped) so the scan axis lies along
// simd lanes; the destination registers (`values[i] = ...`) are in a gap.
382 for (
int i = 0; i < N_READS; i++) {
384 read_buffer[lid.x *
simd_size * N_READS + lid.y * N_READS + i];
// NOTE(review): a simdgroup-scope barrier given a *threadgroup* memory
// flag — verify the intended barrier scope/flag combination.
388 simdgroup_barrier(mem_flags::mem_threadgroup);
// Inclusive simd scan per register, fold in the carried prefix, then pull
// the new running total from the last lane for the next chunk.
391 for (
int i = 0; i < N_READS; i++) {
392 values[i] =
op.simd_scan(values[i]);
393 values[i] =
op(values[i], prefix[i]);
394 prefix[i] = simd_shuffle(values[i],
simd_size - 1);
// Write scanned values back to shared memory, again transposed (the RHS of
// the assignment falls in a gap).
398 for (
int i = 0; i < N_READS; i++) {
399 read_buffer[lid.x *
simd_size * N_READS + lid.y * N_READS + i] =
402 threadgroup_barrier(mem_flags::mem_threadgroup);
// Boundary row (presumably the exclusive-scan case): seed the first output
// row with the identity Op::init.
406 if (check_index_y == 0) {
407 if ((index_x + N_READS) < stride) {
408 for (
int i = 0; i < N_READS; i++) {
409 out[offset + index_y * stride + index_x + i] = Op::init;
412 for (
int i = 0; i < N_READS; i++) {
413 if ((index_x + i) < stride) {
414 out[offset + index_y * stride + index_x + i] = Op::init;
// Cooperative store of the scanned tile from shared memory to `out`
// (fast whole-run path, then element-wise bounds-checked path).
427 if (check_index_y < axis_size && (index_x + N_READS) < stride) {
428 for (
int i = 0; i < N_READS; i++) {
429 out[offset + index_y * stride + index_x + i] =
430 read_buffer[lid.y *
simd_size * N_READS + lid.x * N_READS + i];
433 for (
int i = 0; i < N_READS; i++) {
434 if (check_index_y < axis_size && (index_x + i) < stride) {
435 out[offset + index_y * stride + index_x + i] =
436 read_buffer[lid.y *
simd_size * N_READS + lid.x * N_READS + i];