31 const device T* in [[buffer(0)]],
32 device T* out [[buffer(1)]],
33 constant
const float& scale,
34 uint3 elem [[thread_position_in_grid]],
35 uint3 grid [[threads_per_grid]]) {
42 constexpr short num_threads = N / max_radix;
43 constexpr short logN = __builtin_ctz(N);
44 constexpr short logR = __builtin_ctz(max_radix);
45 constexpr short num_steps = logN / logR;
46 constexpr short logFinal = logN % logR;
47 constexpr short final_radix = 1 << (logFinal);
49 int batch_idx = elem.x * N;
56 for (
short j = 0; j < max_radix / read_width; j++) {
57 short index = j * read_width * num_threads + i * read_width;
59 for (
short r = 0; r < read_width; r++) {
60 buf[index + r] = in[batch_idx + index + r];
64 threadgroup_barrier(mem_flags::mem_threadgroup);
70 for (
short s = 0; s < num_steps; s++) {
71 short k = i & (h - 1);
72 short j = ((i - k) << logR) + k;
75 for (
short r = 0; r < max_radix; r++) {
76 x[r] =
buf[j + h * r];
82 for (
short r = 0; r < max_radix; r++) {
83 buf[j + h * r] = T(x[r]);
87 threadgroup_barrier(mem_flags::mem_threadgroup);
93 if (final_radix > 1) {
96 for (
int t = 0; t < max_radix / final_radix; t++) {
97 short index = i + t * num_threads;
98 short k = index & (h - 1);
99 short j = ((index - k) << logFinal) + k;
101 for (
short r = 0; r < final_radix; r++) {
102 x[r] =
buf[j + h * r];
108 for (
short r = 0; r < final_radix; r++) {
109 buf[j + h * r] = T(x[r]);
112 threadgroup_barrier(mem_flags::mem_threadgroup);
117 for (
short j = 0; j < max_radix / read_width; j++) {
118 short index = j * read_width * num_threads + i * read_width;
120 for (
short r = 0; r < read_width; r++) {
121 out[batch_idx + index + r] = T(
buf[index + r] * scale);
128 const device T* in [[buffer(0)]],
129 device T* out [[buffer(1)]],
130 constant
const float& scale,
131 uint3 elem [[thread_position_in_grid]],
132 uint3 grid [[threads_per_grid]]) {
139 int index = elem.x * grid.y + elem.y;
140 short i = index % (N / read_width);
141 int batch_idx = index / (N / read_width) * M * N;
143 float x[read_width][M];
145 for (
short c = 0; c < M; c++) {
147 for (
short r = 0; r < read_width; r++) {
148 x[r][c] = in[batch_idx + c * N + i * read_width + r];
153 for (
short r = 0; r < read_width; r++) {
156 hadamard_radix_m(x[r]);
161 for (
short c = 0; c < M; c++) {
163 for (
short r = 0; r < read_width; r++) {
164 out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);