31    const device T* in [[buffer(0)]],
 
   32    device T* out [[buffer(1)]],
 
   33    constant 
const float& scale,
 
   34    uint3 elem [[thread_position_in_grid]],
 
   35    uint3 grid [[threads_per_grid]]) {
 
   42  constexpr short num_threads = N / max_radix;
 
   43  constexpr short logN = __builtin_ctz(N);
 
   44  constexpr short logR = __builtin_ctz(max_radix);
 
   45  constexpr short num_steps = logN / logR;
 
   46  constexpr short logFinal = logN % logR;
 
   47  constexpr short final_radix = 1 << (logFinal);
 
   49  int batch_idx = elem.x * N;
 
   56  for (
short j = 0; j < max_radix / read_width; j++) {
 
   57    short index = j * read_width * num_threads + i * read_width;
 
   59    for (
short r = 0; r < read_width; r++) {
 
   60      buf[index + r] = in[batch_idx + index + r];
 
   64  threadgroup_barrier(mem_flags::mem_threadgroup);
 
   70  for (
short s = 0; s < num_steps; s++) {
 
   71    short k = i & (h - 1);
 
   72    short j = ((i - k) << logR) + k;
 
   75    for (
short r = 0; r < max_radix; r++) {
 
   76      x[r] = 
buf[j + h * r];
 
   79    radix_func<max_radix>(x);
 
   82    for (
short r = 0; r < max_radix; r++) {
 
   83      buf[j + h * r] = T(x[r]);
 
   87    threadgroup_barrier(mem_flags::mem_threadgroup);
 
   93  if (final_radix > 1) {
 
   96    for (
int t = 0; t < max_radix / final_radix; t++) {
 
   97      short index = i + t * num_threads;
 
   98      short k = index & (h - 1);
 
   99      short j = ((index - k) << logFinal) + k;
 
  101      for (
short r = 0; r < final_radix; r++) {
 
  102        x[r] = 
buf[j + h * r];
 
  105      radix_func<final_radix>(x);
 
  108      for (
short r = 0; r < final_radix; r++) {
 
  109        buf[j + h * r] = T(x[r]);
 
  112    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  117  for (
short j = 0; j < max_radix / read_width; j++) {
 
  118    short index = j * read_width * num_threads + i * read_width;
 
  120    for (
short r = 0; r < read_width; r++) {
 
  121      out[batch_idx + index + r] = T(
buf[index + r] * scale);
 
 
  128    const device T* in [[buffer(0)]],
 
  129    device T* out [[buffer(1)]],
 
  130    constant 
const float& scale,
 
  131    uint3 elem [[thread_position_in_grid]],
 
  132    uint3 grid [[threads_per_grid]]) {
 
  139  int index = elem.x * grid.y + elem.y;
 
  140  short i = index % (N / read_width);
 
  141  int batch_idx = index / (N / read_width) * M * N;
 
  143  float x[read_width][M];
 
  145  for (
short c = 0; c < M; c++) {
 
  147    for (
short r = 0; r < read_width; r++) {
 
  148      x[r][c] = in[batch_idx + c * N + i * read_width + r];
 
  153  for (
short r = 0; r < read_width; r++) {
 
  156    hadamard_radix_m(x[r]);
 
  161  for (
short c = 0; c < M; c++) {
 
  163    for (
short r = 0; r < read_width; r++) {
 
  164      out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);