MLX
 
Loading...
Searching...
No Matches
base_simd.h
Go to the documentation of this file.
1#pragma once
2
3#include <stdint.h>
4#include <algorithm>
5#include <cmath>
6#include <complex>
7#include <functional>
8
9namespace mlx::core::simd {
/// Fixed-width SIMD vector of N lanes of element type T.
/// Only the single-lane (scalar) fallback is defined in this header; wider
/// specializations presumably live in backend headers — TODO confirm.
template <typename T, int N>
struct Simd;

/// Default maximum lane count for element type T. The scalar fallback is a
/// single lane; presumably specialized per type by backend headers.
template <typename T>
static constexpr int max_size = 1;

/// Scalar (one-lane) fallback: wraps a single value of type T so that the
/// element-wise helpers in this file can be written uniformly for any N.
template <typename T>
struct Simd<T, 1> {
  static constexpr int size = 1;
  T value;
  Simd() {}
  // Convert from a one-lane Simd of a different element type. Partial
  // ordering prefers this over the raw-scalar constructor for Simd inputs.
  template <typename U>
  Simd(Simd<U, 1> v) : value(v.value) {}
  // Implicitly wrap (and convert) a raw scalar.
  template <typename U>
  Simd(U v) : value(v) {}
};
26
27template <typename T, int N>
28Simd<T, N> load(const T* x) {
29 return *(Simd<T, N>*)x;
30}
31
32template <typename T, int N>
33void store(T* dst, Simd<T, N> x) {
34 // Maintain invariant that bool is either 0 or 1 as
35 // simd comparison ops set all bits in the result to 1
36 if constexpr (std::is_same_v<T, bool> && N > 1) {
37 x = x & 1;
38 }
39 *(Simd<T, N>*)dst = x;
40}
41
/// Trait reporting whether T is a complex element type. The primary template
/// answers false; partial specializations opt complex types in.
template <typename T, typename Enable = void>
constexpr bool is_complex = false;
44
45template <typename T>
47 true;
48
49template <typename T>
51 if constexpr (is_complex<T>) {
52 return Simd<T, 1>{
53 T{std::rint(in.value.real()), std::rint(in.value.imag())}};
54 } else {
55 return Simd<T, 1>{std::rint(in.value)};
56 }
57}
58
59template <typename T>
61 return T(1.0) / sqrt(in);
62}
63
64template <typename T>
66 return T(1.0) / in;
67}
68
69#define DEFAULT_UNARY(name, op) \
70 template <typename T> \
71 Simd<T, 1> name(Simd<T, 1> in) { \
72 return op(in.value); \
73 }
74
75DEFAULT_UNARY(operator-, std::negate{})
76DEFAULT_UNARY(operator!, std::logical_not{})
97
98template <typename T>
99Simd<T, 1> operator~(Simd<T, 1> in) {
100 return ~in.value;
101}
102
103template <typename T>
104auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
105 return std::real(in.value);
106}
107template <typename T>
108auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
109 return std::imag(in.value);
110}
111template <typename T>
113 return std::isnan(in.value);
114}
115
// Defines the three one-lane overloads of binary operator OP:
// Simd OP Simd, scalar OP Simd, and Simd OP scalar. Each applies OP to the
// underlying values and wraps the result, whose element type is whatever OP
// yields for those operands, in a one-lane Simd.
#define DEFAULT_BINARY(OP)                                  \
  template <typename T1, typename T2>                       \
  auto operator OP(Simd<T1, 1> lhs, Simd<T2, 1> rhs)        \
      -> Simd<decltype(lhs.value OP rhs.value), 1> {        \
    return lhs.value OP rhs.value;                          \
  }                                                         \
  template <typename T1, typename T2>                       \
  auto operator OP(T1 lhs, Simd<T2, 1> rhs)                 \
      -> Simd<decltype(lhs OP rhs.value), 1> {              \
    return lhs OP rhs.value;                                \
  }                                                         \
  template <typename T1, typename T2>                       \
  auto operator OP(Simd<T1, 1> lhs, T2 rhs)                 \
      -> Simd<decltype(lhs.value OP rhs), 1> {              \
    return lhs.value OP rhs;                                \
  }
130
142
143template <typename T>
144Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
145 T a = a_.value;
146 T b = b_.value;
147 T r;
148 if constexpr (std::is_integral_v<T>) {
149 r = a % b;
150 } else {
151 r = std::remainder(a, b);
152 }
153 if constexpr (std::is_signed_v<T>) {
154 if (r != 0 && (r < 0 != b < 0)) {
155 r += b;
156 }
157 }
158 return r;
159}
160
161template <typename T>
163 T a = a_.value;
164 T b = b_.value;
165 if constexpr (!std::is_integral_v<T>) {
166 if (std::isnan(a)) {
167 return a;
168 }
169 }
170 return (a > b) ? a : b;
171}
172
173template <typename T>
175 T a = a_.value;
176 T b = b_.value;
177 if constexpr (!std::is_integral_v<T>) {
178 if (std::isnan(a)) {
179 return a;
180 }
181 }
182 return (a < b) ? a : b;
183}
184
185template <typename T>
// NOTE(review): the function signature (source line 186) is missing from this
// extraction. The body reads the base from `a.value` and the exponent from
// `b.value`.
187 T base = a.value;
188 T exp = b.value;
// Non-integral T defers to std::pow.
189 if constexpr (!std::is_integral_v<T>) {
190 return std::pow(base, exp);
191 } else {
// Integral T: exponentiation by squaring.
// NOTE(review): for signed T with a negative exponent, `exp >>= 1` on a
// negative value (arithmetic shift on mainstream compilers) converges to -1
// and never reaches 0, so this loop does not terminate. Consider returning
// early when exp < 0. `res *= base` / `base *= base` can also overflow, which
// is undefined behavior for signed int and wider.
192 T res = 1;
193 while (exp) {
194 if (exp & 1) {
195 res *= base;
196 }
197 exp >>= 1;
198 base *= base;
199 }
200 return res;
201 }
202}
203
204template <typename T>
206 return std::atan2(a.value, b.value);
207}
208
// Defines the three one-lane overloads of comparison operator OP
// (Simd OP Simd, scalar OP Simd, Simd OP scalar). Each compares the
// underlying values and wraps the outcome in a Simd<bool, 1>.
#define DEFAULT_COMPARISONS(OP)                                 \
  template <typename T1, typename T2>                           \
  Simd<bool, 1> operator OP(Simd<T1, 1> lhs, Simd<T2, 1> rhs) { \
    return Simd<bool, 1>(lhs.value OP rhs.value);               \
  }                                                             \
  template <typename T1, typename T2>                           \
  Simd<bool, 1> operator OP(T1 lhs, Simd<T2, 1> rhs) {          \
    return Simd<bool, 1>(lhs OP rhs.value);                     \
  }                                                             \
  template <typename T1, typename T2>                           \
  Simd<bool, 1> operator OP(Simd<T1, 1> lhs, T2 rhs) {          \
    return Simd<bool, 1>(lhs.value OP rhs);                     \
  }
222
229
230template <typename MaskT, typename T>
231Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
232 return mask.value ? x.value : y.value;
233}
234
235template <typename T>
// NOTE(review): the function signature (source line 236) is missing from this
// extraction. Clamps the single lane of `v` into [min.value, max.value] via
// std::clamp, whose behavior is undefined if max.value < min.value.
237 return std::clamp(v.value, min.value, max.value);
238}
239
240template <typename T, typename U>
242 return std::fma(x.value, y.value, Simd<T, 1>(z).value);
243}
244
245// Reductions
// Defines a one-lane reduction `name`. Reducing a single lane simply yields
// that lane's value, converted to the declared return `type`.
#define DEFAULT_REDUCTION(name, type) \
  template <typename T>               \
  type name(Simd<T, 1> lane) {        \
    return lane.value;                \
  }
251
258
259} // namespace mlx::core::simd
#define DEFAULT_REDUCTION(name, type)
Definition base_simd.h:246
#define DEFAULT_UNARY(name, op)
Definition base_simd.h:69
#define DEFAULT_BINARY(OP)
Definition base_simd.h:116
#define DEFAULT_COMPARISONS(OP)
Definition base_simd.h:209
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:146
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:215
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:297
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:28
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:276
T sum(Simd< T, N > x)
Definition accelerate_simd.h:284
constexpr bool is_complex
Definition base_simd.h:43
Simd< T, 1 > conj(Simd< T, 1 > in)
Definition base_simd.h:85
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:288
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:209
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
auto imag(Simd< T, 1 > in) -> Simd< decltype(std::imag(in.value)), 1 >
Definition base_simd.h:108
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:280
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:269
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:14
T min(Simd< T, N > x)
Definition accelerate_simd.h:292
auto real(Simd< T, 1 > in) -> Simd< decltype(std::real(in.value)), 1 >
Definition base_simd.h:104
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:264
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
static constexpr int size
Definition base_simd.h:18
Simd()
Definition base_simd.h:20
Simd(Simd< U, 1 > v)
Definition base_simd.h:22
T value
Definition base_simd.h:19
Simd(U v)
Definition base_simd.h:24
Definition accelerate_simd.h:55
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80