MLX
Loading...
Searching...
No Matches
radix.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
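/* Radix kernels

We provide optimized, single-threaded radix codelets; each one computes a
complete length-n DFT on its own, for
n = 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13.

For n = 2, 3, 4, 5, 6 the codelets are hand-written.
For n = 8, 10, 12 we combine smaller codelets: two half-length DFTs
followed by a radix-2 butterfly stage.
For n = 7, 11, 13 we use Rader's algorithm, which reduces a prime-length
DFT to a length n - 1 = 6, 10, 12 cyclic convolution evaluated with the
codelets above. */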
12
13#pragma once
14
15#include <metal_common>
16#include <metal_math>
17#include <metal_stdlib>
18
// Complex product a * b; each float2 stores (real, imag).
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
  float re = a.x * b.x - a.y * b.y;
  float im = a.x * b.y + a.y * b.x;
  return float2(re, im);
}
22
23// Complex mul followed by conjugate
24METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
25 return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
26}
27
// FFT twiddle factor exp(-2*pi*i*k/p), returned as (cos, sin) of the angle
// -2*pi*k/p. The angle is formed in float and evaluated with the
// metal::fast trig variants.
METAL_FUNC float2 get_twiddle(int k, int p) {
  const float angle = -2.0f * k * M_PI_F / p;
  return float2(metal::fast::cos(angle), metal::fast::sin(angle));
}
35
// Length-2 DFT: a single butterfly, (x0, x1) -> (x0 + x1, x0 - x1).
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
  const float2 first = x[0];
  const float2 second = x[1];
  y[0] = first + second;
  y[1] = first - second;
}
40
// Length-3 DFT of x[0..2] into y[0..2]. x and y must not overlap
// (y[0] is written before x[0] is read for the last time).
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
  // -sin(2*pi/3)
  float neg_sin_2pi_3 = -0.8660254037844387;

  float2 sum12 = x[1] + x[2];
  float2 diff12 = x[1] - x[2];

  y[0] = x[0] + sum12;

  // Shared real part x0 + cos(2*pi/3) * (x1 + x2) of bins 1 and 2.
  float2 re = x[0] - 0.5 * sum12;
  float2 im = neg_sin_2pi_3 * diff12;
  float2 im_rot = float2(-im.y, im.x); // i * im

  y[1] = re + im_rot;
  y[2] = re - im_rot;
}
55
// Length-4 DFT of x[0..3] into y[0..3]: two radix-2 stages whose only
// non-trivial twiddle is -i.
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
  float2 even_sum = x[0] + x[2];
  float2 even_diff = x[0] - x[2];
  float2 odd_sum = x[1] + x[3];
  float2 odd_diff = x[1] - x[3];
  float2 odd_diff_rot = float2(odd_diff.y, -odd_diff.x); // -i * odd_diff

  y[0] = even_sum + odd_sum;
  y[2] = even_sum - odd_sum;
  y[1] = even_diff + odd_diff_rot;
  y[3] = even_diff - odd_diff_rot;
}
68
// Length-5 DFT of x[0..4] into y[0..4]. Inputs are folded into symmetric
// sums (x1 + x4, x2 + x3) and differences (x1 - x4, x2 - x3). The sums feed
// the real cosine terms and the differences feed the -i * sine terms.
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
  float sqrt5_over_4 = 0.5590169943749475;
  float sin_2pi_5 = 0.9510565162951535;
  float sin_pi_5 = 0.5877852522924731;

  float2 s14 = x[1] + x[4];
  float2 s23 = x[2] + x[3];
  float2 d14 = x[1] - x[4];
  float2 d23 = x[2] - x[3];

  float2 s_all = s14 + s23;
  // cos(2*pi/5) and cos(4*pi/5) are -1/4 +/- sqrt(5)/4.
  float2 cos_part = sqrt5_over_4 * (s14 - s23);
  float2 base = x[0] - s_all / 4;
  float2 re1 = base + cos_part;
  float2 re2 = base - cos_part;

  float2 im1 = sin_2pi_5 * d14 + sin_pi_5 * d23;
  float2 im2 = sin_pi_5 * d14 - sin_2pi_5 * d23;
  float2 im1_rot = {im1.y, -im1.x}; // -i * im1
  float2 im2_rot = {im2.y, -im2.x}; // -i * im2

  y[0] = x[0] + s_all;
  y[1] = re1 + im1_rot;
  y[4] = re1 - im1_rot;
  y[2] = re2 + im2_rot;
  y[3] = re2 - im2_rot;
}
95
// Length-6 DFT of x[0..5] into y[0..5]. It is computed as two length-3
// DFTs: one over the even samples (x0, x2, x4) and one over the odd samples
// taken in the order (x3, x5, x1). They are then combined with
// sums/differences only, with no twiddle multiplies.
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
  float sin_pi_3 = 0.8660254037844387;

  // Length-3 DFT of (x0, x2, x4).
  float2 ev_sum = x[2] + x[4];
  float2 ev_re = x[0] - ev_sum / 2;
  float2 ev_im = sin_pi_3 * (x[2] - x[4]);
  float2 ev_im_rot = {ev_im.y, -ev_im.x}; // -i * ev_im
  float2 ev0 = x[0] + ev_sum;
  float2 ev1 = ev_re + ev_im_rot;
  float2 ev2 = ev_re - ev_im_rot;

  // Length-3 DFT of (x3, x5, x1).
  float2 od_sum = x[5] + x[1];
  float2 od_re = x[3] - od_sum / 2;
  float2 od_im = sin_pi_3 * (x[5] - x[1]);
  float2 od_im_rot = {od_im.y, -od_im.x}; // -i * od_im
  float2 od0 = x[3] + od_sum;
  float2 od1 = od_re + od_im_rot;
  float2 od2 = od_re - od_im_rot;

  y[0] = ev0 + od0;
  y[3] = ev0 - od0;
  y[1] = ev1 - od1;
  y[4] = ev1 + od1;
  y[2] = ev2 + od2;
  y[5] = ev2 - od2;
}
121
// Length-7 DFT by Rader's algorithm.
// - The six nonzero-index inputs are gathered in generator order x[3^q mod 7].
// - They are convolved with a precomputed kernel using two radix6 calls.
// - The result is scattered back to bin 5^q mod 7 (5 = 3^-1 mod 7).
//
// x[1..6] are overwritten (reused as scratch); x[0] is preserved.
// x and y must not overlap.
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
  // Conjugate and divide by 6: turns the second forward DFT into an inverse.
  float2 conj_over_n = {1 / 6.0, -1 / 6.0};

  // Forward length-6 DFT of the generator-ordered inputs.
  float2 gathered[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
  radix6(gathered, y + 1);

  // y[1] now holds x[1] + ... + x[6], so this is output bin 0.
  y[0] = y[1] + x[0];

  // Precomputed Rader kernel: entry 0 is -1 and the rest have magnitude
  // sqrt(7). complex_mul_conj multiplies and conjugates in one step.
  const float2 rader_b[6] = {
      float2(-1, 0),
      float2(2.44013336, -1.02261879),
      float2(2.37046941, -1.17510629),
      float2(0, -2.64575131),
      float2(2.37046941, 1.17510629),
      float2(-2.44013336, -1.02261879)};
  for (int q = 0; q < 6; q++) {
    y[q + 1] = complex_mul_conj(y[q + 1], rader_b[q]);
  }

  // Forward DFT of the conjugated product. Conjugating and scaling by 1/6
  // below completes the inverse transform.
  radix6(y + 1, x + 1);

  // Scatter: x[q + 1] lands in output bin 5^q mod 7.
  y[1] = x[1] * conj_over_n + x[0];
  y[2] = x[5] * conj_over_n + x[0];
  y[3] = x[6] * conj_over_n + x[0];
  y[4] = x[3] * conj_over_n + x[0];
  y[5] = x[2] * conj_over_n + x[0];
  y[6] = x[4] * conj_over_n + x[0];
}
150
// Length-8 DFT by radix-2 decimation in time. Two length-4 DFTs (radix4)
// run over the even and odd samples, then are combined with the twiddles
// exp(-2*pi*i*k/8).
//
// x is overwritten with the two intermediate length-4 DFTs; x and y must
// not overlap.
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
  float inv_sqrt2 = 0.7071067811865476;

  float2 split[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
  radix4(split, x);         // x[0..3]: DFT of the even samples
  radix4(split + 4, x + 4); // x[4..7]: DFT of the odd samples

  // Odd-half terms times exp(-2*pi*i*k/8) for k = 1, 2, 3.
  float2 odd1 = complex_mul(x[5], float2(inv_sqrt2, -inv_sqrt2));
  float2 odd2 = float2(x[6].y, -x[6].x); // -i * x[6]
  float2 odd3 = complex_mul(x[7], float2(-inv_sqrt2, -inv_sqrt2));

  y[0] = x[0] + x[4];
  y[1] = x[1] + odd1;
  y[2] = x[2] + odd2;
  y[3] = x[3] + odd3;
  y[4] = x[0] - x[4];
  y[5] = x[1] - odd1;
  y[6] = x[2] - odd2;
  y[7] = x[3] - odd3;
}
171
// Length-10 DFT. Two length-5 DFTs (radix5) run over the two halves of a
// permuted input, then a radix-2 butterfly stage applies the twiddles
// exp(-2*pi*i*k/10).
//
// raders_perm: when true, the input is gathered in the Rader's-algorithm
// order that radix11 relies on (radix11 calls radix10<true> on its x + 1).
// When false, it is plain even/odd decimation in time.
//
// x is overwritten with the two intermediate length-5 DFTs; x and y must
// not overlap.
template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
  // exp(-2*pi*i*k/10) for k = 1..4.
  const float2 tw[4] = {
      float2(0.8090169943749475, -0.5877852522924731),
      float2(0.30901699437494745, -0.9510565162951535),
      float2(-0.30901699437494745, -0.9510565162951535),
      float2(-0.8090169943749475, -0.5877852522924731)};

  if (raders_perm) {
    float2 halves[10] = {
        x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
    radix5(halves, x);
    radix5(halves + 5, x + 5);
  } else {
    // Evens first, then odds.
    float2 halves[10] = {
        x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
    radix5(halves, x);
    radix5(halves + 5, x + 5);
  }

  // y[k] = E[k] + w^k * O[k],  y[k + 5] = E[k] - w^k * O[k].
  y[0] = x[0] + x[5];
  y[5] = x[0] - x[5];
  for (int k = 1; k <= 4; ++k) {
    float2 odd = complex_mul(x[5 + k], tw[k - 1]);
    y[k] = x[k] + odd;
    y[5 + k] = x[k] - odd;
  }
}
200
// Length-11 DFT by Rader's algorithm.
// - The ten nonzero-index inputs are gathered in generator order
//   (g = 2; the gather happens inside radix10<true>).
// - They are convolved with a precomputed kernel using two length-10 DFTs.
// - The result is scattered back to bin 6^q mod 11 (6 = 2^-1 mod 11).
//
// x[1..10] are overwritten (used as scratch); x[0] is preserved.
// x and y must not overlap.
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
  // Conjugate and divide by 10: turns the second forward DFT into an inverse.
  float2 conj_over_n = {1 / 10.0, -1 / 10.0};

  // Forward length-10 DFT of the generator-ordered inputs.
  radix10<true>(x + 1, y + 1);

  // y[1] now holds x[1] + ... + x[10], so this is output bin 0.
  y[0] = y[1] + x[0];

  // Precomputed Rader kernel: entry 0 is -1 and the rest have magnitude
  // sqrt(11). complex_mul_conj multiplies and conjugates in one step.
  const float2 rader_b[10] = {
      float2(-1, 0),
      float2(0.955301878, -3.17606649),
      float2(2.63610556, 2.01269656),
      float2(2.54127802, 2.13117479),
      float2(2.07016210, 2.59122150),
      float2(0, -3.31662479),
      float2(2.07016210, -2.59122150),
      float2(-2.54127802, 2.13117479),
      float2(2.63610556, -2.01269656),
      float2(-0.955301878, -3.17606649)};
  for (int q = 0; q < 10; q++) {
    y[q + 1] = complex_mul_conj(y[q + 1], rader_b[q]);
  }

  // Forward DFT of the conjugated product. Conjugating and scaling by 1/10
  // below completes the inverse transform.
  radix10<false>(y + 1, x + 1);

  // Scatter: x[q + 1] lands in output bin 6^q mod 11.
  y[1] = x[1] * conj_over_n + x[0];
  y[2] = x[10] * conj_over_n + x[0];
  y[3] = x[3] * conj_over_n + x[0];
  y[4] = x[9] * conj_over_n + x[0];
  y[5] = x[7] * conj_over_n + x[0];
  y[6] = x[2] * conj_over_n + x[0];
  y[7] = x[4] * conj_over_n + x[0];
  y[8] = x[8] * conj_over_n + x[0];
  y[9] = x[5] * conj_over_n + x[0];
  y[10] = x[6] * conj_over_n + x[0];
}
236
// Length-12 DFT. Two length-6 DFTs (radix6) run over the two halves of a
// permuted input, then a radix-2 butterfly stage applies the twiddles
// exp(-2*pi*i*k/12).
//
// raders_perm: when true, the input is gathered in the Rader's-algorithm
// order that radix13 relies on (radix13 calls radix12<true> on its x + 1).
// When false, it is plain even/odd decimation in time.
//
// x is overwritten with the two intermediate length-6 DFTs; x and y must
// not overlap.
template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
  const float sin_pi_3 = 0.8660254037844387;

  // exp(-2*pi*i*k/12) for k = 1..5. These are exactly the entries the
  // butterfly loop reads (w[t - 1] for t = 1..5); no unused slot.
  const float2 w[5] = {
      float2(sin_pi_3, -0.5),
      float2(0.5, -sin_pi_3),
      float2(0, -1),
      float2(-0.5, -sin_pi_3),
      float2(-sin_pi_3, -0.5)};

  if (raders_perm) {
    float2 temp[12] = {
        x[0], x[3], x[2], x[11], x[8], x[9],
        x[1], x[7], x[5], x[10], x[4], x[6]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  } else {
    // Evens first, then odds.
    float2 temp[12] = {
        x[0], x[2], x[4], x[6], x[8], x[10],
        x[1], x[3], x[5], x[7], x[9], x[11]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  }

  // y[t] = E[t] + w^t * O[t],  y[t + 6] = E[t] - w^t * O[t].
  y[0] = x[0] + x[6];
  y[6] = x[0] - x[6];
  for (int t = 1; t < 6; t++) {
    float2 a = complex_mul(x[t + 6], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 6] = x[t] - a;
  }
}
289
// Length-13 DFT by Rader's algorithm.
// - The twelve nonzero-index inputs are gathered in generator order
//   (g = 2; the gather happens inside radix12<true>).
// - They are convolved with a precomputed kernel using two length-12 DFTs.
// - The result is scattered back to bin 7^q mod 13 (7 = 2^-1 mod 13).
//
// x[1..12] are overwritten (used as scratch); x[0] is preserved.
// x and y must not overlap.
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
  // Conjugate and divide by 12: turns the second forward DFT into an inverse.
  float2 conj_over_n = {1 / 12.0, -1 / 12.0};

  // Forward length-12 DFT of the generator-ordered inputs.
  radix12<true>(x + 1, y + 1);

  // y[1] now holds x[1] + ... + x[12], so this is output bin 0.
  y[0] = y[1] + x[0];

  // Precomputed Rader kernel: entry 0 is -1 and the rest have magnitude
  // sqrt(13). complex_mul_conj multiplies and conjugates in one step.
  const float2 rader_b[12] = {
      float2(-1, 0),
      float2(3.07497206, -1.88269669),
      float2(3.09912468, 1.84266823),
      float2(3.45084438, -1.04483161),
      float2(0.91083583, 3.48860690),
      float2(-3.60286363, 0.139189267),
      float2(3.60555128, 0),
      float2(3.60286363, 0.139189267),
      float2(0.91083583, -3.48860690),
      float2(-3.45084438, -1.04483161),
      float2(3.09912468, -1.84266823),
      float2(-3.07497206, -1.88269669)};
  for (int q = 0; q < 12; q++) {
    y[q + 1] = complex_mul_conj(y[q + 1], rader_b[q]);
  }

  // Forward DFT of the conjugated product. Conjugating and scaling by 1/12
  // below completes the inverse transform.
  radix12<false>(y + 1, x + 1);

  // Scatter: x[q + 1] lands in output bin 7^q mod 13.
  y[1] = x[1] * conj_over_n + x[0];
  y[2] = x[12] * conj_over_n + x[0];
  y[3] = x[9] * conj_over_n + x[0];
  y[4] = x[11] * conj_over_n + x[0];
  y[5] = x[4] * conj_over_n + x[0];
  y[6] = x[8] * conj_over_n + x[0];
  y[7] = x[2] * conj_over_n + x[0];
  y[8] = x[10] * conj_over_n + x[0];
  y[9] = x[5] * conj_over_n + x[0];
  y[10] = x[3] * conj_over_n + x[0];
  y[11] = x[6] * conj_over_n + x[0];
  y[12] = x[7] * conj_over_n + x[0];
}
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:242
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:242
METAL_FUNC void radix5(thread float2 *x, thread float2 *y)
Definition radix.h:69
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b)
Definition radix.h:24
METAL_FUNC void radix4(thread float2 *x, thread float2 *y)
Definition radix.h:56
METAL_FUNC void radix10(thread float2 *x, thread float2 *y)
Definition radix.h:173
METAL_FUNC void radix11(thread float2 *x, thread float2 *y)
Definition radix.h:201
METAL_FUNC void radix12(thread float2 *x, thread float2 *y)
Definition radix.h:238
METAL_FUNC void radix3(thread float2 *x, thread float2 *y)
Definition radix.h:41
METAL_FUNC float2 complex_mul(float2 a, float2 b)
Definition radix.h:19
METAL_FUNC void radix8(thread float2 *x, thread float2 *y)
Definition radix.h:151
METAL_FUNC void radix7(thread float2 *x, thread float2 *y)
Definition radix.h:122
METAL_FUNC void radix2(thread float2 *x, thread float2 *y)
Definition radix.h:36
METAL_FUNC void radix13(thread float2 *x, thread float2 *y)
Definition radix.h:290
METAL_FUNC float2 get_twiddle(int k, int p)
Definition radix.h:29
METAL_FUNC void radix6(thread float2 *x, thread float2 *y)
Definition radix.h:96