212 const device T* in [[buffer(0)]],
213 device U* out [[buffer(1)]],
214 const constant
size_t& axis_size [[buffer(2)]],
215 uint3 gid [[threadgroup_position_in_grid]],
216 uint3 gsize [[threadgroups_per_grid]],
217 uint3 lid [[thread_position_in_threadgroup]],
218 uint3 lsize [[threads_per_threadgroup]],
219 uint simd_lane_id [[thread_index_in_simdgroup]],
220 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
225 size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
235 threadgroup U simdgroup_sums[32];
249 for (uint r = 0; r <
ceildiv(axis_size, N_READS * lsize.x); r++) {
251 uint offset = r * lsize.x * N_READS + lid.x * N_READS;
255 if ((offset + N_READS) < axis_size) {
257 values, in + axis_size - offset - N_READS);
261 in + axis_size - offset - N_READS,
267 if ((offset + N_READS) < axis_size) {
271 values, in + offset, offset, axis_size, Op::init);
276 for (
int i = 1; i < N_READS; i++) {
277 values[i] = op(values[i], values[i - 1]);
281 U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
285 simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
287 threadgroup_barrier(mem_flags::mem_threadgroup);
290 if (simd_group_id == 0) {
291 U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
292 simdgroup_sums[simd_lane_id] = prev_simdgroup;
294 threadgroup_barrier(mem_flags::mem_threadgroup);
297 for (
int i = 0; i < N_READS; i++) {
298 values[i] = op(values[i], prefix);
299 values[i] = op(values[i], simdgroup_sums[simd_group_id]);
300 values[i] = op(values[i], prev_thread);
306 if ((offset + N_READS) < axis_size) {
308 values, out + axis_size - offset - N_READS);
311 values, out + axis_size - offset - N_READS, offset, axis_size);
314 if (lid.x == 0 && offset == 0) {
315 out[axis_size - 1] = Op::init;
317 if ((offset + N_READS + 1) < axis_size) {
319 values, out + axis_size - offset - 1 - N_READS);
323 out + axis_size - offset - 1 - N_READS,
330 if ((offset + N_READS) < axis_size) {
334 values, out + offset, offset, axis_size);
337 if (lid.x == 0 && offset == 0) {
340 if ((offset + N_READS + 1) < axis_size) {
344 values, out + offset + 1, offset + 1, axis_size);
348 threadgroup_barrier(mem_flags::mem_threadgroup);
351 if (simd_group_id == simd_groups - 1 && simd_lane_id ==
simd_size - 1) {
352 simdgroup_sums[0] = values[N_READS - 1];
354 threadgroup_barrier(mem_flags::mem_threadgroup);
355 prefix = simdgroup_sums[0];
367 const device T* in [[buffer(0)]],
368 device U* out [[buffer(1)]],
369 const constant
size_t& axis_size [[buffer(2)]],
370 const constant
size_t& stride [[buffer(3)]],
371 const constant
size_t& stride_blocks [[buffer(4)]],
372 uint3 gid [[threadgroup_position_in_grid]],
373 uint3 gsize [[threadgroups_per_grid]],
374 uint3 lid [[thread_position_in_threadgroup]],
375 uint simd_lane_id [[thread_index_in_simdgroup]],
376 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
378 constexpr int BM = 32;
379 constexpr int BN = 32;
380 constexpr int BN_pad = 32 + 16 /
sizeof(U);
381 constexpr int n_simds = BN / N_READS;
382 constexpr int n_scans = BN / n_simds;
385 threadgroup U read_buffer[BM * BN_pad];
388 for (
int i = 0; i < n_scans; i++) {
389 prefix[i] = Op::init;
393 size_t full_gid = gid.y + gsize.y * size_t(gid.z);
394 size_t offset = full_gid / stride_blocks * axis_size * stride;
395 size_t global_index_x = full_gid % stride_blocks * BN;
396 uint read_offset_y = (lid.x * N_READS) / BN;
397 uint read_offset_x = (lid.x * N_READS) % BN;
398 uint scan_offset_y = simd_lane_id;
399 uint scan_offset_x = simd_group_id * n_scans;
401 uint stride_limit = stride - global_index_x;
402 in += offset + global_index_x + read_offset_x;
403 out += offset + global_index_x + read_offset_x;
404 threadgroup U* read_into =
405 read_buffer + read_offset_y * BN_pad + read_offset_x;
406 threadgroup U* read_from =
407 read_buffer + scan_offset_y * BN_pad + scan_offset_x;
409 for (uint j = 0; j < axis_size; j += BM) {
411 uint index_y = j + read_offset_y;
412 uint check_index_y = index_y;
414 index_y = axis_size - 1 - index_y;
418 if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
419 for (
int i = 0; i < N_READS; i++) {
420 read_into[i] = in[index_y * stride + i];
423 for (
int i = 0; i < N_READS; i++) {
424 if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
425 read_into[i] = in[index_y * stride + i];
427 read_into[i] = Op::init;
431 threadgroup_barrier(mem_flags::mem_threadgroup);
434 for (
int i = 0; i < n_scans; i++) {
435 values[i] = read_from[i];
437 simdgroup_barrier(mem_flags::mem_threadgroup);
440 for (
int i = 0; i < n_scans; i++) {
441 values[i] = op.simd_scan(values[i]);
442 values[i] = op(values[i], prefix[i]);
447 for (
int i = 0; i < n_scans; i++) {
448 read_from[i] = values[i];
450 threadgroup_barrier(mem_flags::mem_threadgroup);
454 if (check_index_y == 0) {
455 if ((read_offset_x + N_READS) < stride_limit) {
456 for (
int i = 0; i < N_READS; i++) {
457 out[index_y * stride + i] = Op::init;
460 for (
int i = 0; i < N_READS; i++) {
461 if ((read_offset_x + i) < stride_limit) {
462 out[index_y * stride + i] = Op::init;
475 if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
476 for (
int i = 0; i < N_READS; i++) {
477 out[index_y * stride + i] = read_into[i];
480 for (
int i = 0; i < N_READS; i++) {
481 if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
482 out[index_y * stride + i] = read_into[i];
void contiguous_scan(const device T *in, device U *out, const constant size_t &axis_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_group_id)
Definition scan.h:211
void strided_scan(const device T *in, device U *out, const constant size_t &axis_size, const constant size_t &stride, const constant size_t &stride_blocks, uint3 gid, uint3 gsize, uint3 lid, uint simd_lane_id, uint simd_group_id)
Definition scan.h:366