212    const device T* in [[buffer(0)]],
213    device U* out [[buffer(1)]],
214    const constant size_t& axis_size [[buffer(2)]],
215    uint3 gid [[threadgroup_position_in_grid]],
216    uint3 gsize [[threadgroups_per_grid]],
217    uint3 lid [[thread_position_in_threadgroup]],
218    uint3 lsize [[threads_per_threadgroup]],
219    uint simd_lane_id [[thread_index_in_simdgroup]],
220    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
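// Note (added commentary, not part of scan.h): from the body below, each
// threadgroup scans one contiguous row of length axis_size, selected from
// gid.y/gid.z. Every thread handles N_READS consecutive elements per block,
// and a running block prefix is carried across iterations of the main loop.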
 
225  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
235  threadgroup U simdgroup_sums[32];
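// Note (added): `offset` selects this threadgroup's row. The 32-entry
// simdgroup_sums buffer presumably suffices because a threadgroup of at most
// 1024 threads holds at most 32 simdgroups of width 32.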
 
249  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
251    uint offset = r * lsize.x * N_READS + lid.x * N_READS;
255      if ((offset + N_READS) < axis_size) {
257            values, in + axis_size - offset - N_READS);
261            in + axis_size - offset - N_READS,
267      if ((offset + N_READS) < axis_size) {
271            values, in + offset, offset, axis_size, Op::init);
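// Note (added): the elided calls above load N_READS elements per thread from
// the current block. The `(offset + N_READS) < axis_size` fast path is taken
// when the whole vector is in bounds; otherwise a bounds-checked variant pads
// the tail with Op::init (visible from the offset/axis_size/Op::init
// arguments). The `axis_size - offset - N_READS` addressing suggests these
// branches belong to a reversed-scan case.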
 
276    for (int i = 1; i < N_READS; i++) {
277      values[i] = op(values[i], values[i - 1]);
281    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
285      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
287    threadgroup_barrier(mem_flags::mem_threadgroup);
290    if (simd_group_id == 0) {
291      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
292      simdgroup_sums[simd_lane_id] = prev_simdgroup;
294    threadgroup_barrier(mem_flags::mem_threadgroup);
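// Note (added): the block scan is built hierarchically. Lines 276-277 form an
// inclusive scan of each thread's N_READS registers; line 281 turns the
// per-thread totals into an exclusive per-lane prefix within the simdgroup;
// line 285 (under an elided guard, presumably the last lane) publishes each
// simdgroup's total to threadgroup memory; lines 290-292 then exclusive-scan
// those totals so simdgroup_sums[g] holds the combined total of all
// simdgroups before g.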
 
297    for (int i = 0; i < N_READS; i++) {
298      values[i] = op(values[i], prefix);
299      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
300      values[i] = op(values[i], prev_thread);
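// Note (added): each element's final value combines four independent pieces.
// For an addition scan the decomposition reads roughly
//   out[i] = local_scan[i]        // this thread's running total (lines 276-277)
//          + prefix               // total of all previous blocks
//          + simdgroup_sums[g]    // total of earlier simdgroups in this block
//          + prev_thread          // total of earlier lanes in this simdgroup
// with `+` standing in for the generic Op.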
 
306        if ((offset + N_READS) < axis_size) {
308              values, out + axis_size - offset - N_READS);
311              values, out + axis_size - offset - N_READS, offset, axis_size);
314        if (lid.x == 0 && offset == 0) {
315          out[axis_size - 1] = Op::init;
317        if ((offset + N_READS + 1) < axis_size) {
319              values, out + axis_size - offset - 1 - N_READS);
323              out + axis_size - offset - 1 - N_READS,
330        if ((offset + N_READS) < axis_size) {
334              values, out + offset, offset, axis_size);
337        if (lid.x == 0 && offset == 0) {
340        if ((offset + N_READS + 1) < axis_size) {
344              values, out + offset + 1, offset + 1, axis_size);
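// Note (added): the elided branches around lines 305-345 appear to select
// between inclusive and exclusive output as well as forward and reverse
// addressing. The exclusive paths seed the boundary element with Op::init
// (lines 314-315, with a matching guard at 337) and shift the stores by one
// position (the `+ 1` / `- 1 -` adjustments); the bounds checks again choose
// a full-vector store versus a bounds-checked one.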
 
348    threadgroup_barrier(mem_flags::mem_threadgroup);
351    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
352      simdgroup_sums[0] = values[N_READS - 1];
354    threadgroup_barrier(mem_flags::mem_threadgroup);
355    prefix = simdgroup_sums[0];
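// Note (added): after the block is written, the thread holding the block's
// last element (last lane of the last simdgroup) republishes its inclusive
// total through simdgroup_sums[0]; once the barrier at 354 completes, every
// thread reads it back as `prefix`, seeding the next iteration of the loop.

The sketch below is added commentary, not part of scan.h: a serial C++ model of the blocked recurrence that contiguous_scan implements, scanning each block locally and folding in a carried prefix. The helper name blocked_inclusive_scan and its signature are illustrative assumptions, not MLX API.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

template <typename U, typename Op>
std::vector<U> blocked_inclusive_scan(
    const std::vector<U>& in, std::size_t block, Op op, U init) {
  std::vector<U> out(in.size());
  U prefix = init;  // combined total of all previous blocks
  for (std::size_t start = 0; start < in.size(); start += block) {
    std::size_t end = std::min(start + block, in.size());
    U local = init;  // inclusive scan restricted to this block
    for (std::size_t i = start; i < end; ++i) {
      local = op(local, in[i]);
      out[i] = op(prefix, local);
    }
    prefix = op(prefix, local);  // carry the block total forward
  }
  return out;
}

// Example: blocked_inclusive_scan<float>({1, 2, 3, 4}, 2, std::plus<float>{}, 0.0f)
// yields {1, 3, 6, 10}.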
 
 
367    const device T* in [[buffer(0)]],
368    device U* out [[buffer(1)]],
369    const constant size_t& axis_size [[buffer(2)]],
370    const constant size_t& stride [[buffer(3)]],
371    const constant size_t& stride_blocks [[buffer(4)]],
372    uint3 gid [[threadgroup_position_in_grid]],
373    uint3 gsize [[threadgroups_per_grid]],
374    uint3 lid [[thread_position_in_threadgroup]],
375    uint simd_lane_id [[thread_index_in_simdgroup]],
376    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
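// Note (added commentary, not part of scan.h): strided_scan handles scans
// along an axis whose consecutive elements are `stride` apart in memory. From
// the indexing below, each threadgroup owns a BM x BN tile: BM positions
// along the scanned axis by BN adjacent positions along the stride, with
// stride_blocks counting how many BN-wide column blocks cover one stride.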
 
378  constexpr int BM = 32;
379  constexpr int BN = 32;
380  constexpr int BN_pad = 32 + 16 / sizeof(U);
381  constexpr int n_simds = BN / N_READS;
382  constexpr int n_scans = BN / n_simds;
385  threadgroup U read_buffer[BM * BN_pad];
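// Note (added): the tile is staged in threadgroup memory with a padded row
// pitch (BN_pad > BN), presumably so column-wise accesses during the scan do
// not hit the same memory banks. n_simds is the number of simdgroups needed
// to cover the BN columns when each simdgroup scans n_scans of them
// (n_simds * n_scans == BN).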
 
388  for (int i = 0; i < n_scans; i++) {
389    prefix[i] = Op::init;
393  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
394  size_t offset = full_gid / stride_blocks * axis_size * stride;
395  size_t global_index_x = full_gid % stride_blocks * BN;
396  uint read_offset_y = (lid.x * N_READS) / BN;
397  uint read_offset_x = (lid.x * N_READS) % BN;
398  uint scan_offset_y = simd_lane_id;
399  uint scan_offset_x = simd_group_id * n_scans;
401  uint stride_limit = stride - global_index_x;
402  in += offset + global_index_x + read_offset_x;
403  out += offset + global_index_x + read_offset_x;
404  threadgroup U* read_into =
405      read_buffer + read_offset_y * BN_pad + read_offset_x;
406  threadgroup U* read_from =
407      read_buffer + scan_offset_y * BN_pad + scan_offset_x;
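// Note (added): full_gid / stride_blocks selects which axis_size-by-stride
// slab of the array this threadgroup works on, and full_gid % stride_blocks
// selects the BN-wide column block inside the stride. Two thread layouts
// cover the tile: (read_offset_y, read_offset_x) tiles the loads so each
// thread reads N_READS contiguous elements of one row, while (scan_offset_y,
// scan_offset_x) gives each simdgroup n_scans columns with its lanes running
// down the BM rows. stride_limit bounds the last, possibly partial, column
// block.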
 
409  for (uint j = 0; j < axis_size; j += BM) {
411    uint index_y = j + read_offset_y;
412    uint check_index_y = index_y;
414      index_y = axis_size - 1 - index_y;
418    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
419      for (int i = 0; i < N_READS; i++) {
420        read_into[i] = in[index_y * stride + i];
423      for (int i = 0; i < N_READS; i++) {
424        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
425          read_into[i] = in[index_y * stride + i];
427          read_into[i] = Op::init;
431    threadgroup_barrier(mem_flags::mem_threadgroup);
434    for (int i = 0; i < n_scans; i++) {
435      values[i] = read_from[i];
437    simdgroup_barrier(mem_flags::mem_threadgroup);
440    for (int i = 0; i < n_scans; i++) {
441      values[i] = op.simd_scan(values[i]);
442      values[i] = op(values[i], prefix[i]);
447    for (int i = 0; i < n_scans; i++) {
448      read_from[i] = values[i];
450    threadgroup_barrier(mem_flags::mem_threadgroup);
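// Note (added): once the tile is resident, each simdgroup runs an inclusive
// simd_scan down every column it owns (its 32 lanes span the BM rows), folds
// in that column's running prefix carried over from previous tiles, and
// writes the column back to threadgroup memory before the output phase.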
 
454      if (check_index_y == 0) {
455        if ((read_offset_x + N_READS) < stride_limit) {
456          for (int i = 0; i < N_READS; i++) {
457            out[index_y * stride + i] = Op::init;
460          for (int i = 0; i < N_READS; i++) {
461            if ((read_offset_x + i) < stride_limit) {
462              out[index_y * stride + i] = Op::init;
475    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
476      for (int i = 0; i < N_READS; i++) {
477        out[index_y * stride + i] = read_into[i];
480      for (int i = 0; i < N_READS; i++) {
481        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
482          out[index_y * stride + i] = read_into[i];
 
void contiguous_scan(const device T *in, device U *out, const constant size_t &axis_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_group_id)
Definition scan.h:211
 
void strided_scan(const device T *in, device U *out, const constant size_t &axis_size, const constant size_t &stride, const constant size_t &stride_blocks, uint3 gid, uint3 gsize, uint3 lid, uint simd_lane_id, uint simd_group_id)
Definition scan.h:366