MLX
 
Loading...
Searching...
No Matches
accelerate_simd.h
Go to the documentation of this file.
1#pragma once
2
3#include <simd/math.h>
4#include <simd/vector.h>
5
6#include <stdint.h>
7#include <cmath>
8#include <complex>
9
11
// There seems to be a bug in <simd/base.h>: because __XROS_2_0 is not
// defined, its OS-version expression evaluates to true instead of false,
// selecting a SIMD library version higher than it should be even on
// macOS < 15. Derive the version here from the deployment-target minimums
// instead.
// NOTE(review): visionOS/xrOS has no clause here, so it always gets version
// 5 -- presumably deliberate given the __XROS_2_0 issue; confirm.
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 ||  \
    __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
    __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 ||  \
    __TV_OS_VERSION_MIN_REQUIRED >= 180000
#define MLX_SIMD_LIBRARY_VERSION 6
#else
#define MLX_SIMD_LIBRARY_VERSION 5
#endif
25
26namespace mlx::core::simd {
27
28// Apple simd namespace
29namespace asd = ::simd;
30
// Maps an element type T to a scalar type that Apple's simd vector templates
// (asd::Vector<scalar, N>) can be instantiated with:
//   bool, int8_t -> char
//   int64_t      -> long
//   uint64_t     -> unsigned long
//   anything else is used unchanged.
// The lane count N does not influence the mapping.
template <typename T, int N>
struct ScalarT {
  using v = std::conditional_t<
      std::is_same_v<T, bool> || std::is_same_v<T, int8_t>,
      char,
      std::conditional_t<
          std::is_same_v<T, uint64_t>,
          unsigned long,
          std::conditional_t<std::is_same_v<T, int64_t>, long, T>>>;
};
53
54template <typename T, int N>
55struct Simd {
56 static constexpr int size = N;
57 using scalar_t = typename ScalarT<T, N>::v;
58
59 Simd<T, N>() {}
60
61 template <typename U>
62 Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
63
64 template <typename U>
65 Simd<T, N>(U v) : value(v){};
66
68 value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
69 x.value, y.value);
70 };
71
72 T operator[](int idx) const {
73 return reinterpret_cast<const T*>(&value)[idx];
74 }
75
76 T& operator[](int idx) {
77 return reinterpret_cast<T*>(&value)[idx];
78 }
79
80 typename asd::Vector<scalar_t, N>::packed_t value;
81};
82
83// Values chosen based on benchmarks on M3 Max
84// TODO: consider choosing these more optimally
85template <>
86static constexpr int max_size<int8_t> = 16;
87template <>
88static constexpr int max_size<int16_t> = 16;
89template <>
90static constexpr int max_size<int> = 8;
91template <>
92static constexpr int max_size<int64_t> = 4;
93template <>
94static constexpr int max_size<uint8_t> = 16;
95template <>
96static constexpr int max_size<uint16_t> = 16;
97template <>
98static constexpr int max_size<uint32_t> = 8;
99template <>
100static constexpr int max_size<uint64_t> = 4;
101template <>
102static constexpr int max_size<float> = 8;
103template <>
104static constexpr int max_size<double> = 4;
105
// Generates a lane-wise unary function `name` for every Simd<T, N>: `op` is
// applied to the underlying packed vector and the result is wrapped back up
// as a Simd<T, N>.
#define SIMD_DEFAULT_UNARY(name, op) \
  template <typename T, int N>       \
  Simd<T, N> name(Simd<T, N> in) {   \
    return Simd<T, N>(op(in.value)); \
  }
111
134
135template <typename T, int N>
136Simd<T, N> operator-(Simd<T, N> v) {
137 return -v.value;
138}
139
140template <typename T, int N>
142 return asd::convert<char>(v.value != v.value);
143}
144
145// No simd_boolN in accelerate, use int8_t instead
146template <typename T, int N>
148 return asd::convert<char>(!v.value);
149}
150
// Generates the three overloads of binary operator OP: vector OP scalar,
// scalar OP vector, and vector OP vector. The packed result is converted back
// to the scalar_t of the Simd operand (the left one when both are vectors),
// so the result keeps that operand's element type.
#define SIMD_DEFAULT_BINARY(OP)                           \
  template <typename T, typename U, int N>                \
  Simd<T, N> operator OP(Simd<T, N> x, U y) {             \
    using lane_t = typename Simd<T, N>::scalar_t;         \
    return asd::convert<lane_t>(x.value OP y);            \
  }                                                       \
  template <typename T1, typename T2, int N>              \
  Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) {          \
    using lane_t = typename Simd<T2, N>::scalar_t;        \
    return asd::convert<lane_t>(x OP y.value);            \
  }                                                       \
  template <typename T1, typename T2, int N>              \
  Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
    using lane_t = typename Simd<T1, N>::scalar_t;        \
    return asd::convert<lane_t>(x.value OP y.value);      \
  }
164
176
// Generates the three overloads of comparison operator OP (vector/scalar,
// scalar/vector, vector/vector). Accelerate has no boolean vector type, so
// the lane-wise comparison result is converted to char lanes and returned as
// a Simd<bool, N> mask.
#define SIMD_DEFAULT_COMPARISONS(OP)                          \
  template <int N, typename T, typename U>                    \
  Simd<bool, N> operator OP(Simd<T, N> a, U b) {              \
    auto lanes = a.value OP b;                                \
    return asd::convert<char>(lanes);                         \
  }                                                           \
  template <int N, typename T, typename U>                    \
  Simd<bool, N> operator OP(T a, Simd<U, N> b) {              \
    auto lanes = a OP b.value;                                \
    return asd::convert<char>(lanes);                         \
  }                                                           \
  template <int N, typename T1, typename T2>                  \
  Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) {   \
    auto lanes = a.value OP b.value;                          \
    return asd::convert<char>(lanes);                         \
  }
190
197
198template <typename T, int N>
199Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
200 return asd::atan2(a.value, b.value);
201}
202
203template <typename T, int N>
205 // TODO add isnan
206 return asd::max(a.value, b.value);
207}
208
209template <typename T, int N>
211 // TODO add isnan
212 return asd::min(a.value, b.value);
213}
214
215template <typename T, int N>
217 Simd<T, N> r;
218 if constexpr (!std::is_integral_v<T>) {
219 r = asd::remainder(a.value, b.value);
220 } else {
221 r = a - b * (a / b);
222 }
223 if constexpr (std::is_signed_v<T>) {
224 auto mask = r != 0 && (r < 0 != b < 0);
225 r = select(mask, r + b, r);
226 }
227 return r;
228}
229
230template <typename MaskT, typename T1, typename T2, int N>
232 if constexpr (sizeof(T1) == 1) {
233 return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
234 } else if constexpr (sizeof(T1) == 2) {
235 return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
236 } else if constexpr (sizeof(T1) == 4) {
237 return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
238 } else {
239 return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
240 }
241}
242
243template <typename T, int N>
245 if constexpr (!std::is_integral_v<T>) {
246 return asd::pow(base.value, exp.value);
247 } else {
248 Simd<T, N> res = 1;
249 while (any(exp)) {
250 res = select(exp & 1, res * base, res);
251 base = select(exp, base * base, base);
252 exp = exp >> 1;
253 }
254 return res;
255 }
256}
257
258template <typename T, int N>
260 return asd::clamp(v.value, min.value, max.value);
261}
262
263template <typename T, typename U, int N>
265 return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
266}
267
268// Reductions
269
270template <typename T, int N>
272 return asd::all(x.value);
273}
274template <typename T, int N>
276 return asd::any(x.value);
277}
278template <typename T, int N>
280 return asd::reduce_add(x.value);
281}
282template <typename T, int N>
284 return asd::reduce_max(x.value);
285}
286template <typename T, int N>
288 return asd::reduce_min(x.value);
289}
290
291template <typename T, int N>
293 auto ptr = (T*)&x;
294 auto lhs = load<T, N / 2>(ptr);
295 auto rhs = load<T, N / 2>(ptr + N / 2);
296 return prod(lhs * rhs);
297}
298
299} // namespace mlx::core::simd
300
301#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
303#endif
#define SIMD_DEFAULT_UNARY(name, op)
Definition accelerate_simd.h:106
#define SIMD_DEFAULT_BINARY(OP)
Definition accelerate_simd.h:151
#define SIMD_DEFAULT_COMPARISONS(OP)
Definition accelerate_simd.h:177
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:141
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:210
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:27
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:271
T sum(Simd< T, N > x)
Definition accelerate_simd.h:279
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:283
Simd< bool, N > operator!(Simd< T, N > v)
Definition accelerate_simd.h:147
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:204
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:275
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:264
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:13
T min(Simd< T, N > x)
Definition accelerate_simd.h:287
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:259
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:231
char v
Definition accelerate_simd.h:39
long v
Definition accelerate_simd.h:51
char v
Definition accelerate_simd.h:43
unsigned long v
Definition accelerate_simd.h:47
Definition accelerate_simd.h:34
T v
Definition accelerate_simd.h:35
Definition accelerate_simd.h:55
T operator[](int idx) const
Definition accelerate_simd.h:72
typename ScalarT< T, N >::v scalar_t
Definition accelerate_simd.h:57
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80
Simd()
Definition accelerate_simd.h:59
static constexpr int size
Definition accelerate_simd.h:56
T & operator[](int idx)
Definition accelerate_simd.h:76