MLX
 
Loading...
Searching...
No Matches
accelerate_simd.h
Go to the documentation of this file.
1#pragma once
2
3#include <simd/math.h>
4#include <simd/vector.h>
5
6#include <stdint.h>
7#include <cmath>
8#include <complex>
9
11
12// There seems to be a bug in simd/base.h:
13// __XROS_2_0 is not defined, so the expression evaluates
14// to true instead of false, setting the SIMD library version
15// higher than it should be, even on macOS < 15
// Select MLX_SIMD_LIBRARY_VERSION from the deployment-target macros: 6 when
// any of the minimum-required OS versions below meets its threshold,
// otherwise 5.
// NOTE(review): the __WATCH_OS_VERSION_MIN_REQUIRED clause appears twice; the
// second occurrence is redundant (or was meant to test another platform --
// confirm).
16#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
17 __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
18 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
19 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
20 __TV_OS_VERSION_MIN_REQUIRED >= 180000
21#define MLX_SIMD_LIBRARY_VERSION 6
22#else
23#define MLX_SIMD_LIBRARY_VERSION 5
24#endif
25
26namespace mlx::core::simd {
27
28// Apple simd namespace
29namespace asd = ::simd;
30
31// This indirection is needed to remap certain types to ones that accelerate
32// SIMD can handle
// Primary template: the element type used for SIMD storage is T itself.
// The N parameter is not used by the primary template or by any of the
// specializations visible here.
33template <typename T, int N>
34struct ScalarT {
35 using v = T;
36};
// bool elements are stored as char.
37template <int N>
38struct ScalarT<bool, N> {
39 using v = char;
40};
// int8_t elements are stored as char.
41template <int N>
42struct ScalarT<int8_t, N> {
43 using v = char;
44};
// uint64_t elements are stored as unsigned long.
45template <int N>
46struct ScalarT<uint64_t, N> {
47 using v = unsigned long;
48};
// int64_t elements are stored as long.
49template <int N>
50struct ScalarT<int64_t, N> {
51 using v = long;
52};
53
// Wrapper around an Apple simd packed vector holding N elements whose storage
// type is ScalarT<T, N>::v. The vector itself is kept in the public member
// `value`.
54template <typename T, int N>
55struct Simd {
 // Number of lanes.
56 static constexpr int size = N;
 // Storage element type after the ScalarT remapping.
57 using scalar_t = typename ScalarT<T, N>::v;
58
 // Default constructor: has an empty body and no initializer for `value`.
59 Simd<T, N>() {}
60
 // Converting constructor: converts the lanes of another Simd of the same
 // width to scalar_t with asd::convert.
61 template <typename U>
62 Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
63
 // Initializes `value` directly from v (any type U accepted by the
 // underlying vector type's initialization).
64 template <typename U>
65 Simd<T, N>(U v) : value(v){};
66
 // NOTE(review): the signature of this constructor (listing line 67) is
 // missing from this extract. The body builds `value` from x.value and
 // y.value with asd::make -- presumably concatenating two half-width
 // vectors; confirm against the original header.
68 value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
69 x.value, y.value);
70 };
71
 // Read lane idx by reinterpreting the storage as an array of T.
 // No bounds check is performed.
72 T operator[](int idx) const {
73 return reinterpret_cast<const T*>(&value)[idx];
74 }
75
 // Mutable lane access; same reinterpretation as the const overload.
76 T& operator[](int idx) {
77 return reinterpret_cast<T*>(&value)[idx];
78 }
79
 // Underlying packed simd vector.
80 typename asd::Vector<scalar_t, N>::packed_t value;
81};
82
83// Values chosen based on benchmarks on M3 Max
84// TODO: consider choosing these more optimally
// Explicit specializations of the max_size variable template, one per element
// type: 16 for the 8- and 16-bit integer types, 8 for int / uint32_t / float,
// and 4 for the 64-bit integer types and double. The primary template is not
// defined in this listing (the symbol index on this page points to
// base_simd.h).
85template <>
86static constexpr int max_size<int8_t> = 16;
87template <>
88static constexpr int max_size<int16_t> = 16;
89template <>
90static constexpr int max_size<int> = 8;
91template <>
92static constexpr int max_size<int64_t> = 4;
93template <>
94static constexpr int max_size<uint8_t> = 16;
95template <>
96static constexpr int max_size<uint16_t> = 16;
97template <>
98static constexpr int max_size<uint32_t> = 8;
99template <>
100static constexpr int max_size<uint64_t> = 4;
101template <>
102static constexpr int max_size<float> = 8;
103template <>
104static constexpr int max_size<double> = 4;
105
// Defines a function template `name` that takes a Simd<T, N>, applies `op` to
// its underlying vector (v.value) and returns the result as a Simd<T, N>.
// The uses of this macro (listing lines 111-133) are not present in this
// extract.
106#define SIMD_DEFAULT_UNARY(name, op) \
107 template <typename T, int N> \
108 Simd<T, N> name(Simd<T, N> v) { \
109 return op(v.value); \
110 }
111
134
// Unary minus: negates the underlying vector and wraps the result back in a
// Simd<T, N>.
135template <typename T, int N>
136Simd<T, N> operator-(Simd<T, N> v) {
137 return -v.value;
138}
139
// Bitwise NOT of the underlying vector. The signature (listing line 141) is
// missing from this extract; the symbol index on this page lists it as
// Simd<T, N> operator~(Simd<T, N> v).
140template <typename T, int N>
142 return ~v.value;
143}
144
// NaN test: compares the vector with itself using != and converts the
// comparison result to char lanes. The signature (listing line 146) is
// missing from this extract; the symbol index on this page lists it as
// Simd<bool, N> isnan(Simd<T, N> v).
145template <typename T, int N>
147 return asd::convert<char>(v.value != v.value);
148}
149
150// No simd_boolN in accelerate, use int8_t instead
// Logical NOT: applies ! to the underlying vector and converts the result to
// char lanes. The signature (listing line 152) is missing from this extract;
// the symbol index on this page lists it as
// Simd<bool, N> operator!(Simd<T, N> v).
151template <typename T, int N>
153 return asd::convert<char>(!v.value);
154}
155
// Defines three overloads of the binary operator OP:
//   (Simd<T, N>, U)            -> Simd<T, N>
//   (T1, Simd<T2, N>)          -> Simd<T2, N>
//   (Simd<T1, N>, Simd<T2, N>) -> Simd<T1, N>
// Each applies OP to the underlying vector(s) and converts the result to the
// scalar_t of the returned Simd type. The uses of this macro (listing lines
// 170-180) are not present in this extract.
156#define SIMD_DEFAULT_BINARY(OP) \
157 template <typename T, typename U, int N> \
158 Simd<T, N> operator OP(Simd<T, N> x, U y) { \
159 return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
160 } \
161 template <typename T1, typename T2, int N> \
162 Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
163 return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
164 } \
165 template <typename T1, typename T2, int N> \
166 Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
167 return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
168 }
169
181
// Defines three overloads of the comparison operator OP (Simd vs. scalar,
// scalar vs. Simd, Simd vs. Simd). Each applies OP to the underlying
// vector(s), converts the result to char lanes and returns it as a
// Simd<bool, N>. The uses of this macro (listing lines 196-201) are not
// present in this extract.
182#define SIMD_DEFAULT_COMPARISONS(OP) \
183 template <int N, typename T, typename U> \
184 Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
185 return asd::convert<char>(a.value OP b); \
186 } \
187 template <int N, typename T, typename U> \
188 Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
189 return asd::convert<char>(a OP b.value); \
190 } \
191 template <int N, typename T1, typename T2> \
192 Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
193 return asd::convert<char>(a.value OP b.value); \
194 }
195
202
// Forwards to asd::atan2 on the underlying vectors, passing a first and b
// second, and wraps the result in a Simd<T, N>.
203template <typename T, int N>
204Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
205 return asd::atan2(a.value, b.value);
206}
207
// Forwards to asd::max on the underlying vectors; per the TODO below, NaN
// handling has not been added. The signature (listing line 209) is missing
// from this extract; the symbol index on this page lists it as
// Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b).
208template <typename T, int N>
210 // TODO add isnan
211 return asd::max(a.value, b.value);
212}
213
// Forwards to asd::min on the underlying vectors; per the TODO below, NaN
// handling has not been added. The signature (listing line 215) is missing
// from this extract; the symbol index on this page lists it as
// Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b).
214template <typename T, int N>
216 // TODO add isnan
217 return asd::min(a.value, b.value);
218}
219
// Remainder of a by b. The signature (listing line 221) is missing from this
// extract -- presumably Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b);
// confirm against the original header.
220template <typename T, int N>
222 Simd<T, N> r;
 // Non-integral types use asd::remainder; integral types compute the
 // remainder from the quotient as a - b * (a / b).
223 if constexpr (!std::is_integral_v<T>) {
224 r = asd::remainder(a.value, b.value);
225 } else {
226 r = a - b * (a / b);
227 }
 // For signed T (std::is_signed_v is also true for floating-point types):
 // in lanes where r is non-zero and the sign tests of r and b disagree, add
 // b to r. `r < 0 != b < 0` parses as `(r < 0) != (b < 0)`.
228 if constexpr (std::is_signed_v<T>) {
229 auto mask = r != 0 && (r < 0 != b < 0);
230 r = select(mask, r + b, r);
231 }
232 return r;
233}
234
// Lane-wise selection between x and y driven by mask. The signature (listing
// line 236) is missing from this extract; the symbol index on this page lists
// it as Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y).
// The mask is converted to an integer vector type chosen by sizeof(T1)
// (char / short / int / long) and passed to asd::bitselect together with
// y.value and x.value.
// NOTE(review): presumably lanes with a set mask take x and the others take
// y -- confirm against the asd::bitselect documentation.
235template <typename MaskT, typename T1, typename T2, int N>
237 if constexpr (sizeof(T1) == 1) {
238 return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
239 } else if constexpr (sizeof(T1) == 2) {
240 return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
241 } else if constexpr (sizeof(T1) == 4) {
242 return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
243 } else {
244 return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
245 }
246}
247
// Lane-wise power of `base` raised to `exp`. The signature (listing line 249)
// is missing from this extract -- presumably
// Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp); confirm against the
// original header.
248template <typename T, int N>
 // Non-integral types forward to asd::pow.
250 if constexpr (!std::is_integral_v<T>) {
251 return asd::pow(base.value, exp.value);
252 } else {
 // Integral types: exponentiation by squaring, run for all lanes at once
 // until every lane's exponent is zero. Each iteration multiplies res by
 // base in lanes selected by `exp & 1`, squares base in lanes whose
 // exponent is still non-zero, then shifts exp right by one.
 // NOTE(review): for signed T, a negative exponent lane may never reach
 // zero under `exp >> 1` -- confirm callers never pass negative exponents.
253 Simd<T, N> res = 1;
254 while (any(exp)) {
255 res = select(exp & 1, res * base, res);
256 base = select(exp, base * base, base);
257 exp = exp >> 1;
258 }
259 return res;
260 }
261}
262
// Forwards to asd::clamp with the underlying vectors of v, min and max.
// The signature (listing line 264) is missing from this extract; the symbol
// index on this page lists it as
// Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max).
263template <typename T, int N>
265 return asd::clamp(v.value, min.value, max.value);
266}
267
// Forwards to asd::muladd with x.value, y.value and z; z is first converted
// to a Simd<T, N> through its constructor. The signature (listing line 269)
// is missing from this extract; the symbol index on this page lists it as
// Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z).
268template <typename T, typename U, int N>
270 return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
271}
272
273// Reductions
274
// Reduction: forwards to asd::all on the underlying vector. The signature
// (listing line 276) is missing from this extract; the symbol index on this
// page lists it as bool all(Simd<T, N> x).
275template <typename T, int N>
277 return asd::all(x.value);
278}
// Reduction: forwards to asd::any on the underlying vector. The signature
// (listing line 280) is missing from this extract; the symbol index on this
// page lists it as bool any(Simd<T, N> x).
279template <typename T, int N>
281 return asd::any(x.value);
282}
// Reduction: forwards to asd::reduce_add on the underlying vector. The
// signature (listing line 284) is missing from this extract; the symbol index
// on this page lists it as T sum(Simd<T, N> x).
283template <typename T, int N>
285 return asd::reduce_add(x.value);
286}
// Reduction: forwards to asd::reduce_max on the underlying vector. The
// signature (listing line 288) is missing from this extract; the symbol index
// on this page lists it as T max(Simd<T, N> x).
287template <typename T, int N>
289 return asd::reduce_max(x.value);
290}
// Reduction: forwards to asd::reduce_min on the underlying vector. The
// signature (listing line 292) is missing from this extract; the symbol index
// on this page lists it as T min(Simd<T, N> x).
291template <typename T, int N>
293 return asd::reduce_min(x.value);
294}
295
// Product reduction. The signature (listing line 297) is missing from this
// extract; the symbol index on this page lists it as T prod(Simd<T, N> x).
// The argument is viewed as an array of T, its two halves are loaded as
// N / 2-wide vectors, multiplied lane-wise, and prod is called again on the
// half-width result.
// NOTE(review): the overload that terminates this recursion is not visible in
// this listing -- presumably defined elsewhere; confirm.
296template <typename T, int N>
298 auto ptr = (T*)&x;
299 auto lhs = load<T, N / 2>(ptr);
300 auto rhs = load<T, N / 2>(ptr + N / 2);
301 return prod(lhs * rhs);
302}
303
304} // namespace mlx::core::simd
305
306#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
308#endif
#define SIMD_DEFAULT_UNARY(name, op)
Definition accelerate_simd.h:106
#define SIMD_DEFAULT_BINARY(OP)
Definition accelerate_simd.h:156
#define SIMD_DEFAULT_COMPARISONS(OP)
Definition accelerate_simd.h:182
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:146
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:215
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:297
Simd< T, N > operator~(Simd< T, N > v)
Definition accelerate_simd.h:141
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:28
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:276
T sum(Simd< T, N > x)
Definition accelerate_simd.h:284
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:288
Simd< bool, N > operator!(Simd< T, N > v)
Definition accelerate_simd.h:152
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:209
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:280
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:269
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:14
T min(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:264
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
char v
Definition accelerate_simd.h:39
long v
Definition accelerate_simd.h:51
char v
Definition accelerate_simd.h:43
unsigned long v
Definition accelerate_simd.h:47
Definition accelerate_simd.h:34
T v
Definition accelerate_simd.h:35
Definition accelerate_simd.h:55
T operator[](int idx) const
Definition accelerate_simd.h:72
typename ScalarT< T, N >::v scalar_t
Definition accelerate_simd.h:57
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80
Simd()
Definition accelerate_simd.h:59
static constexpr int size
Definition accelerate_simd.h:56
T & operator[](int idx)
Definition accelerate_simd.h:76