4#include <simd/vector.h>
16#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
17 __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
18 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
19 __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
20 __TV_OS_VERSION_MIN_REQUIRED >= 180000
21#define MLX_SIMD_LIBRARY_VERSION 6
23#define MLX_SIMD_LIBRARY_VERSION 5
29namespace asd = ::simd;
33template <
typename T,
int N>
47 using v =
unsigned long;
54template <
typename T,
int N>
68 value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
73 return reinterpret_cast<const T*
>(&
value)[idx];
77 return reinterpret_cast<T*
>(&
value)[idx];
80 typename asd::Vector<scalar_t, N>::packed_t
value;
106#define SIMD_DEFAULT_UNARY(name, op) \
107 template <typename T, int N> \
108 Simd<T, N> name(Simd<T, N> v) { \
109 return op(v.value); \
135template <typename T,
int N>
140template <
typename T,
int N>
142 return asd::convert<char>(v.
value != v.
value);
146template <
typename T,
int N>
148 return asd::convert<char>(!v.
value);
151#define SIMD_DEFAULT_BINARY(OP) \
152 template <typename T, typename U, int N> \
153 Simd<T, N> operator OP(Simd<T, N> x, U y) { \
154 return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
156 template <typename T1, typename T2, int N> \
157 Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
158 return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
160 template <typename T1, typename T2, int N> \
161 Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
162 return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
177#define SIMD_DEFAULT_COMPARISONS(OP) \
178 template <int N, typename T, typename U> \
179 Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
180 return asd::convert<char>(a.value OP b); \
182 template <int N, typename T, typename U> \
183 Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
184 return asd::convert<char>(a OP b.value); \
186 template <int N, typename T1, typename T2> \
187 Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
188 return asd::convert<char>(a.value OP b.value); \
198template <typename T,
int N>
200 return asd::atan2(a.value, b.value);
203template <
typename T,
int N>
209template <
typename T,
int N>
215template <
typename T,
int N>
218 if constexpr (!std::is_integral_v<T>) {
223 if constexpr (std::is_signed_v<T>) {
224 auto mask = r != 0 && (r < 0 != b < 0);
225 r =
select(mask, r + b, r);
230template <
typename MaskT,
typename T1,
typename T2,
int N>
232 if constexpr (
sizeof(T1) == 1) {
233 return asd::bitselect(y.
value, x.
value, asd::convert<char>(mask.
value));
234 }
else if constexpr (
sizeof(T1) == 2) {
235 return asd::bitselect(y.
value, x.
value, asd::convert<short>(mask.
value));
236 }
else if constexpr (
sizeof(T1) == 4) {
237 return asd::bitselect(y.
value, x.
value, asd::convert<int>(mask.
value));
239 return asd::bitselect(y.
value, x.
value, asd::convert<long>(mask.
value));
243template <
typename T,
int N>
245 if constexpr (!std::is_integral_v<T>) {
246 return asd::pow(base.
value,
exp.value);
258template <
typename T,
int N>
263template <
typename T,
typename U,
int N>
270template <
typename T,
int N>
272 return asd::all(x.
value);
274template <
typename T,
int N>
276 return asd::any(x.
value);
278template <
typename T,
int N>
280 return asd::reduce_add(x.
value);
282template <
typename T,
int N>
284 return asd::reduce_max(x.
value);
286template <
typename T,
int N>
288 return asd::reduce_min(x.
value);
291template <
typename T,
int N>
294 auto lhs =
load<T,
N / 2>(ptr);
295 auto rhs =
load<T,
N / 2>(ptr +
N / 2);
296 return prod(lhs * rhs);
301#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define SIMD_DEFAULT_UNARY(name, op)
Definition accelerate_simd.h:106
#define SIMD_DEFAULT_BINARY(OP)
Definition accelerate_simd.h:151
#define SIMD_DEFAULT_COMPARISONS(OP)
Definition accelerate_simd.h:177
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:141
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:210
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:27
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:271
T sum(Simd< T, N > x)
Definition accelerate_simd.h:279
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:283
Simd< bool, N > operator!(Simd< T, N > v)
Definition accelerate_simd.h:147
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:204
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:275
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:264
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:13
T min(Simd< T, N > x)
Definition accelerate_simd.h:287
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:259
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:231
char v
Definition accelerate_simd.h:39
long v
Definition accelerate_simd.h:51
char v
Definition accelerate_simd.h:43
unsigned long v
Definition accelerate_simd.h:47
Definition accelerate_simd.h:34
T v
Definition accelerate_simd.h:35
Definition accelerate_simd.h:55
T operator[](int idx) const
Definition accelerate_simd.h:72
typename ScalarT< T, N >::v scalar_t
Definition accelerate_simd.h:57
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80
Simd()
Definition accelerate_simd.h:59
static constexpr int size
Definition accelerate_simd.h:56
T & operator[](int idx)
Definition accelerate_simd.h:76