MLX
 
Loading...
Searching...
No Matches
hadamard.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2#include <metal_common>
3#include <metal_compute>
4
6
7using namespace metal;
8
// Thread-local Hadamard transform of length R, done in place on x[0 .. R).
// R is expected to be a power of two: the number of butterfly stages is
// log2(R) = __builtin_ctz(R).
//
// Stage s pairs elements that are `span = 2^s` apart and replaces each pair
// (a, b) with (a + b, a - b).
template <short R>
METAL_FUNC void radix_func(thread float* x) {
  constexpr short logR = __builtin_ctz(R);
  for (short stage = 0, span = 1; stage < logR; stage++, span <<= 1) {
    // R / 2 butterflies per stage; `pair` enumerates them.
    for (short pair = 0; pair < R / 2; pair++) {
      // Split `pair` into (block, offset-in-block) and map it to the
      // lower element `lo` of its butterfly; the partner is `span` away.
      short offset = pair & (span - 1);
      short lo = ((pair - offset) << 1) + offset;
      short hi = lo + span;
      float sum = x[lo] + x[hi];
      float diff = x[lo] - x[hi];
      x[lo] = sum;
      x[hi] = diff;
    }
  }
}
28
// Computes a Hadamard transform of length N = 2^k for one row of `in`,
// staging the row in threadgroup memory, and writes `scale * result` to `out`.
//
// Equivalent (per row) to:
//   from scipy.linalg import hadamard
//   y = scale * (hadamard(len(x)) @ x)
//
// Template parameters:
//   T          element type of `in` / `out` and of the threadgroup buffer
//   N          transform length (power of two)
//   max_radix  size of the per-thread radix step (power of two)
//   read_width consecutive elements each thread loads/stores per chunk
//
// The row is selected by elem.x and the thread's slot within the row by
// elem.y. NOTE(review): this assumes each threadgroup covers exactly one row
// with num_threads = N / max_radix threads along y — confirm against the host
// dispatch.
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const float& scale,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  constexpr short num_threads = N / max_radix;
  constexpr short logN = __builtin_ctz(N);
  constexpr short logR = __builtin_ctz(max_radix);
  // Number of full max_radix stages; any leftover log2 factor is handled by
  // a smaller final radix.
  constexpr short num_steps = logN / logR;
  constexpr short logFinal = logN % logR;
  constexpr short final_radix = 1 << (logFinal);

  int batch_idx = elem.x * N;
  short i = elem.y;

  threadgroup T buf[N];

  // Read values from device
  for (short j = 0; j < max_radix / read_width; j++) {
    short index = j * read_width * num_threads + i * read_width;
    for (short r = 0; r < read_width; r++) {
      buf[index + r] = in[batch_idx + index + r];
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  float x[max_radix];
  // Current butterfly stride; grows by a factor of max_radix per stage.
  short h = 1;

  for (short s = 0; s < num_steps; s++) {
    // Map this thread to the first element of its radix group at stride h.
    short k = i & (h - 1);
    short j = ((i - k) << logR) + k;

    for (short r = 0; r < max_radix; r++) {
      x[r] = buf[j + h * r];
    }

    // Apply the size-max_radix transform to the gathered group. Without this
    // call the stage would just copy buf back onto itself.
    radix_func<max_radix>(x);

    for (short r = 0; r < max_radix; r++) {
      buf[j + h * r] = T(x[r]);
    }

    h <<= logR;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Do the final radix
  // e.g. max_radix = 16
  //      N = 1024 = 16 * 16 * 4
  if (final_radix > 1) {
    // Each thread does multiple butterflies
    for (int t = 0; t < max_radix / final_radix; t++) {
      short index = i + t * num_threads;
      short k = index & (h - 1);
      short j = ((index - k) << logFinal) + k;
      for (short r = 0; r < final_radix; r++) {
        x[r] = buf[j + h * r];
      }

      // Apply the remaining size-final_radix transform to this group.
      radix_func<final_radix>(x);

      for (short r = 0; r < final_radix; r++) {
        buf[j + h * r] = T(x[r]);
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write values to device
  for (short j = 0; j < max_radix / read_width; j++) {
    short index = j * read_width * num_threads + i * read_width;
    for (short r = 0; r < read_width; r++) {
      out[batch_idx + index + r] = T(buf[index + r] * scale);
    }
  }
}
125
// Second stage of a Hadamard transform of size M * N, where N = 2^k: applies
// a naive O(M^2) size-M Hadamard codelet along the strided "M" axis.
//
// Each thread owns `read_width` adjacent columns of one batch entry. For a
// column at offset q, its M inputs sit N elements apart:
//   in[batch_start + c * N + q],  c = 0 .. M-1
// and the transformed values, multiplied by `scale`, are written to the same
// positions of `out`.
//
// `hadamard_radix_m` is not defined in this file; it is JIT compiled for M
// using the Hadamard matrix strings in `metal/hadamard.cpp`.
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const float& scale,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Threads needed to cover the N columns of one batch entry.
  constexpr int threads_per_row = N / read_width;

  int flat_id = elem.x * grid.y + elem.y;
  short lane = flat_id % threads_per_row;
  int row_offset = flat_id / threads_per_row * M * N;
  int col_offset = row_offset + lane * read_width;

  // vals[r][c] = c-th strided element of this thread's r-th column.
  float vals[read_width][M];

  // Gather from device memory.
  for (short c = 0; c < M; c++) {
    int src = col_offset + c * N;
    for (short r = 0; r < read_width; r++) {
      vals[r][c] = in[src + r];
    }
  }

  // Transform each column independently.
  for (short r = 0; r < read_width; r++) {
    hadamard_radix_m(vals[r]);
  }

  // Scatter back to device memory, applying the scale.
  for (short c = 0; c < M; c++) {
    int dst = col_offset + c * N;
    for (short r = 0; r < read_width; r++) {
      out[dst + r] = T(vals[r][c] * scale);
    }
  }
}
METAL_FUNC void radix_func(thread float *x)
Definition hadamard.h:11
void hadamard_n(const device T *in, device T *out, constant const float &scale, uint3 elem, uint3 grid)
Definition hadamard.h:30
void hadamard_m(const device T *in, device T *out, constant const float &scale, uint3 elem, uint3 grid)
Definition hadamard.h:127
Definition bf16_math.h:226
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4