178    const device T* in [[buffer(0)]],
 
  179    device U* out [[buffer(1)]],
 
  180    const constant 
size_t& axis_size [[buffer(2)]],
 
  181    uint gid [[thread_position_in_grid]],
 
  182    uint lid [[thread_position_in_threadgroup]],
 
  183    uint lsize [[threads_per_threadgroup]],
 
  184    uint 
simd_size [[threads_per_simdgroup]],
 
  185    uint simd_lane_id [[thread_index_in_simdgroup]],
 
  186    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
  190  in += (gid / lsize) * axis_size;
 
  191  out += (gid / lsize) * axis_size;
 
  199  threadgroup U simdgroup_sums[32];
 
  213  for (uint r = 0; r < 
ceildiv(axis_size, N_READS * lsize); r++) {
 
  215    uint offset = r * lsize * N_READS + lid * N_READS;
 
  219      if ((offset + N_READS) < axis_size) {
 
  220        load_unsafe<T, U, N_READS, reverse>(
 
  221            values, in + axis_size - offset - N_READS);
 
  223        load_safe<T, U, N_READS, reverse>(
 
  225            in + axis_size - offset - N_READS,
 
  231      if ((offset + N_READS) < axis_size) {
 
  232        load_unsafe<T, U, N_READS, reverse>(values, in + offset);
 
  234        load_safe<T, U, N_READS, reverse>(
 
  235            values, in + offset, offset, axis_size, Op::init);
 
  240    for (
int i = 1; i < N_READS; i++) {
 
  241      values[i] = 
op(values[i], values[i - 1]);
 
  245    U prev_thread = 
op.simd_exclusive_scan(values[N_READS - 1]);
 
  249      simdgroup_sums[simd_group_id] = 
op(prev_thread, values[N_READS - 1]);
 
  251    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  254    if (simd_group_id == 0) {
 
  255      U prev_simdgroup = 
op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
 
  256      simdgroup_sums[simd_lane_id] = prev_simdgroup;
 
  258    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  261    for (
int i = 0; i < N_READS; i++) {
 
  262      values[i] = 
op(values[i], prefix);
 
  263      values[i] = 
op(values[i], simdgroup_sums[simd_group_id]);
 
  264      values[i] = 
op(values[i], prev_thread);
 
  270        if ((offset + N_READS) < axis_size) {
 
  271          write_unsafe<U, N_READS, reverse>(
 
  272              values, out + axis_size - offset - N_READS);
 
  274          write_safe<U, N_READS, reverse>(
 
  275              values, out + axis_size - offset - N_READS, offset, axis_size);
 
  278        if (lid == 0 && offset == 0) {
 
  279          out[axis_size - 1] = Op::init;
 
  281        if ((offset + N_READS + 1) < axis_size) {
 
  282          write_unsafe<U, N_READS, reverse>(
 
  283              values, out + axis_size - offset - 1 - N_READS);
 
  285          write_safe<U, N_READS, reverse>(
 
  287              out + axis_size - offset - 1 - N_READS,
 
  294        if ((offset + N_READS) < axis_size) {
 
  295          write_unsafe<U, N_READS, reverse>(values, out + offset);
 
  297          write_safe<U, N_READS, reverse>(
 
  298              values, out + offset, offset, axis_size);
 
  301        if (lid == 0 && offset == 0) {
 
  304        if ((offset + N_READS + 1) < axis_size) {
 
  305          write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
 
  307          write_safe<U, N_READS, reverse>(
 
  308              values, out + offset + 1, offset + 1, axis_size);
 
  312    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  315    if (simd_group_id == simd_groups - 1 && simd_lane_id == 
simd_size - 1) {
 
  316      simdgroup_sums[0] = values[N_READS - 1];
 
  318    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  319    prefix = simdgroup_sums[0];
 
 
  331    const device T* in [[buffer(0)]],
 
  332    device U* out [[buffer(1)]],
 
  333    const constant 
size_t& axis_size [[buffer(2)]],
 
  334    const constant 
size_t& stride [[buffer(3)]],
 
  335    uint2 gid [[threadgroup_position_in_grid]],
 
  336    uint2 lid [[thread_position_in_threadgroup]],
 
  337    uint2 lsize [[threads_per_threadgroup]],
 
  338    uint 
simd_size [[threads_per_simdgroup]]) {
 
  342  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
 
  345  for (
int i = 0; i < N_READS; i++) {
 
  346    prefix[i] = Op::init;
 
  350  int offset = gid.y * axis_size * stride;
 
  351  int global_index_x = gid.x * lsize.y * N_READS;
 
  353  for (uint j = 0; j < axis_size; j += 
simd_size) {
 
  355    uint index_y = j + lid.y;
 
  356    uint check_index_y = index_y;
 
  357    uint index_x = global_index_x + lid.x * N_READS;
 
  359      index_y = axis_size - 1 - index_y;
 
  363    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
 
  364      for (
int i = 0; i < N_READS; i++) {
 
  365        read_buffer[lid.y * 
simd_size * N_READS + lid.x * N_READS + i] =
 
  366            in[offset + index_y * stride + index_x + i];
 
  369      for (
int i = 0; i < N_READS; i++) {
 
  370        if (check_index_y < axis_size && (index_x + i) < stride) {
 
  371          read_buffer[lid.y * 
simd_size * N_READS + lid.x * N_READS + i] =
 
  372              in[offset + index_y * stride + index_x + i];
 
  374          read_buffer[lid.y * 
simd_size * N_READS + lid.x * N_READS + i] =
 
  379    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  382    for (
int i = 0; i < N_READS; i++) {
 
  384          read_buffer[lid.x * 
simd_size * N_READS + lid.y * N_READS + i];
 
  388    simdgroup_barrier(mem_flags::mem_threadgroup);
 
  391    for (
int i = 0; i < N_READS; i++) {
 
  392      values[i] = 
op.simd_scan(values[i]);
 
  393      values[i] = 
op(values[i], prefix[i]);
 
  394      prefix[i] = simd_shuffle(values[i], 
simd_size - 1);
 
  398    for (
int i = 0; i < N_READS; i++) {
 
  399      read_buffer[lid.x * 
simd_size * N_READS + lid.y * N_READS + i] =
 
  402    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  406      if (check_index_y == 0) {
 
  407        if ((index_x + N_READS) < stride) {
 
  408          for (
int i = 0; i < N_READS; i++) {
 
  409            out[offset + index_y * stride + index_x + i] = Op::init;
 
  412          for (
int i = 0; i < N_READS; i++) {
 
  413            if ((index_x + i) < stride) {
 
  414              out[offset + index_y * stride + index_x + i] = Op::init;
 
  427    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
 
  428      for (
int i = 0; i < N_READS; i++) {
 
  429        out[offset + index_y * stride + index_x + i] =
 
  430            read_buffer[lid.y * 
simd_size * N_READS + lid.x * N_READS + i];
 
  433      for (
int i = 0; i < N_READS; i++) {
 
  434        if (check_index_y < axis_size && (index_x + i) < stride) {
 
  435          out[offset + index_y * stride + index_x + i] =
 
  436              read_buffer[lid.y * 
simd_size * N_READS + lid.x * N_READS + i];