4#include <simd/vector.h>
16#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
17 __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
18 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
19 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
20 __TV_OS_VERSION_MIN_REQUIRED >= 180000
21#define MLX_SIMD_LIBRARY_VERSION 6
23#define MLX_SIMD_LIBRARY_VERSION 5
29namespace asd = ::simd;
33template <
typename T,
int N>
47 using v =
unsigned long;
54template <
typename T,
int N>
68 value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
73 return reinterpret_cast<const T*
>(&
value)[idx];
77 return reinterpret_cast<T*
>(&
value)[idx];
80 typename asd::Vector<scalar_t, N>::packed_t
value;
106#define SIMD_DEFAULT_UNARY(name, op) \
107 template <typename T, int N> \
108 Simd<T, N> name(Simd<T, N> v) { \
109 return op(v.value); \
135template <typename T,
int N>
140template <
typename T,
int N>
145template <
typename T,
int N>
147 return asd::convert<char>(v.
value != v.
value);
151template <
typename T,
int N>
153 return asd::convert<char>(!v.
value);
156#define SIMD_DEFAULT_BINARY(OP) \
157 template <typename T, typename U, int N> \
158 Simd<T, N> operator OP(Simd<T, N> x, U y) { \
159 return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
161 template <typename T1, typename T2, int N> \
162 Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
163 return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
165 template <typename T1, typename T2, int N> \
166 Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
167 return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
182#define SIMD_DEFAULT_COMPARISONS(OP) \
183 template <int N, typename T, typename U> \
184 Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
185 return asd::convert<char>(a.value OP b); \
187 template <int N, typename T, typename U> \
188 Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
189 return asd::convert<char>(a OP b.value); \
191 template <int N, typename T1, typename T2> \
192 Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
193 return asd::convert<char>(a.value OP b.value); \
203template <typename T,
int N>
205 return asd::atan2(a.value, b.value);
208template <
typename T,
int N>
214template <
typename T,
int N>
220template <
typename T,
int N>
223 if constexpr (!std::is_integral_v<T>) {
228 if constexpr (std::is_signed_v<T>) {
229 auto mask = r != 0 && (r < 0 != b < 0);
230 r =
select(mask, r + b, r);
235template <
typename MaskT,
typename T1,
typename T2,
int N>
237 if constexpr (
sizeof(T1) == 1) {
238 return asd::bitselect(y.
value, x.
value, asd::convert<char>(mask.
value));
239 }
else if constexpr (
sizeof(T1) == 2) {
240 return asd::bitselect(y.
value, x.
value, asd::convert<short>(mask.
value));
241 }
else if constexpr (
sizeof(T1) == 4) {
242 return asd::bitselect(y.
value, x.
value, asd::convert<int>(mask.
value));
244 return asd::bitselect(y.
value, x.
value, asd::convert<long>(mask.
value));
248template <
typename T,
int N>
250 if constexpr (!std::is_integral_v<T>) {
251 return asd::pow(base.
value,
exp.value);
263template <
typename T,
int N>
268template <
typename T,
typename U,
int N>
275template <
typename T,
int N>
277 return asd::all(x.
value);
279template <
typename T,
int N>
281 return asd::any(x.
value);
283template <
typename T,
int N>
285 return asd::reduce_add(x.
value);
287template <
typename T,
int N>
289 return asd::reduce_max(x.
value);
291template <
typename T,
int N>
293 return asd::reduce_min(x.
value);
296template <
typename T,
int N>
299 auto lhs =
load<T,
N / 2>(ptr);
300 auto rhs =
load<T,
N / 2>(ptr +
N / 2);
301 return prod(lhs * rhs);
306#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define SIMD_DEFAULT_UNARY(name, op)
Definition accelerate_simd.h:106
#define SIMD_DEFAULT_BINARY(OP)
Definition accelerate_simd.h:156
#define SIMD_DEFAULT_COMPARISONS(OP)
Definition accelerate_simd.h:182
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:146
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:215
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:297
Simd< T, N > operator~(Simd< T, N > v)
Definition accelerate_simd.h:141
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:28
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:276
T sum(Simd< T, N > x)
Definition accelerate_simd.h:284
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:288
Simd< bool, N > operator!(Simd< T, N > v)
Definition accelerate_simd.h:152
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:209
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:280
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:269
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:14
T min(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:264
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
char v
Definition accelerate_simd.h:39
long v
Definition accelerate_simd.h:51
char v
Definition accelerate_simd.h:43
unsigned long v
Definition accelerate_simd.h:47
Definition accelerate_simd.h:34
T v
Definition accelerate_simd.h:35
Definition accelerate_simd.h:55
T operator[](int idx) const
Definition accelerate_simd.h:72
typename ScalarT< T, N >::v scalar_t
Definition accelerate_simd.h:57
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80
Simd()
Definition accelerate_simd.h:59
static constexpr int size
Definition accelerate_simd.h:56
T & operator[](int idx)
Definition accelerate_simd.h:76