MLX
 
Loading...
Searching...
No Matches
base_simd.h
Go to the documentation of this file.
1#pragma once
2
3#include <stdint.h>
4#include <algorithm>
5#include <cmath>
6#include <complex>
7
8namespace mlx::core::simd {
/// Generic SIMD vector of N lanes of element type T. Only declared here; the
/// single-lane (N == 1) scalar fallback is defined below. Wider lane counts
/// are presumably provided by backend-specific headers (TODO confirm).
template <typename T, int N>
struct Simd;

/// Widest lane count available for element type T. Defaults to 1 (the scalar
/// fallback); presumably specialized per element type by backend headers.
template <typename T>
static constexpr int max_size = 1;

/// Scalar (single-lane) fallback: wraps exactly one value of type T.
template <typename T>
struct Simd<T, 1> {
  static constexpr int size = 1;

  // The wrapped scalar. Deliberately left uninitialized by the default
  // constructor.
  T value;

  Simd() {}

  // Converts from a single-lane Simd of another element type. Partial
  // ordering prefers this overload over Simd(U) when the argument is a
  // Simd<U, 1>.
  template <typename U>
  Simd(Simd<U, 1> v) : value(v.value) {}

  // Implicitly wraps a raw scalar (converted to T). It is intentionally
  // non-explicit, so functions can simply `return some_scalar;`.
  template <typename U>
  Simd(U v) : value(v) {}
};
25
26template <typename T, int N>
27Simd<T, N> load(const T* x) {
28 return *(Simd<T, N>*)x;
29}
30
31template <typename T, int N>
32void store(T* dst, Simd<T, N> x) {
33 // Maintain invariant that bool is either 0 or 1 as
34 // simd comparison ops set all bits in the result to 1
35 if constexpr (std::is_same_v<T, bool> && N > 1) {
36 x = x & 1;
37 }
38 *(Simd<T, N>*)dst = x;
39}
40
/// Trait telling whether T is a complex number type.
///
/// This primary template answers false for every type, and a partial
/// specialization opts matching types in. The defaulted second parameter is
/// presumably the hook that specialization uses to select types (e.g. a
/// void_t-style check — confirm).
template <typename T, typename Enable = void>
constexpr bool is_complex{false};
43
44template <typename T>
46 true;
47
48template <typename T>
50 if constexpr (is_complex<T>) {
51 return Simd<T, 1>{
52 T{std::rint(in.value.real()), std::rint(in.value.imag())}};
53 } else {
54 return Simd<T, 1>{std::rint(in.value)};
55 }
56}
57
58template <typename T>
60 return T(1.0) / sqrt(in);
61}
62
63template <typename T>
65 return T(1.0) / in;
66}
67
68#define DEFAULT_UNARY(name, op) \
69 template <typename T> \
70 Simd<T, 1> name(Simd<T, 1> in) { \
71 return op(in.value); \
72 }
73
74DEFAULT_UNARY(operator-, std::negate{})
75DEFAULT_UNARY(operator!, std::logical_not{})
96
97template <typename T>
98auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
99 return std::real(in.value);
100}
101template <typename T>
102auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
103 return std::imag(in.value);
104}
105template <typename T>
107 return std::isnan(in.value);
108}
109
// Stamps out the three single-lane overloads of binary operator OP:
// Simd OP Simd, scalar OP Simd, and Simd OP scalar. Each result is wrapped
// in a single-lane Simd whose element type is whatever `lhs OP rhs` yields
// (for built-in arithmetic types, the promoted type).
#define DEFAULT_BINARY(OP)                             \
  template <typename Lhs, typename Rhs>                \
  auto operator OP(Simd<Lhs, 1> lhs, Simd<Rhs, 1> rhs) \
      -> Simd<decltype(lhs.value OP rhs.value), 1> {   \
    return lhs.value OP rhs.value;                     \
  }                                                    \
  template <typename Lhs, typename Rhs>                \
  auto operator OP(Lhs lhs, Simd<Rhs, 1> rhs)          \
      -> Simd<decltype(lhs OP rhs.value), 1> {         \
    return lhs OP rhs.value;                           \
  }                                                    \
  template <typename Lhs, typename Rhs>                \
  auto operator OP(Simd<Lhs, 1> lhs, Rhs rhs)          \
      -> Simd<decltype(lhs.value OP rhs), 1> {         \
    return lhs.value OP rhs;                           \
  }
124
136
137template <typename T>
138Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
139 T a = a_.value;
140 T b = b_.value;
141 T r;
142 if constexpr (std::is_integral_v<T>) {
143 r = a % b;
144 } else {
145 r = std::remainder(a, b);
146 }
147 if constexpr (std::is_signed_v<T>) {
148 if (r != 0 && (r < 0 != b < 0)) {
149 r += b;
150 }
151 }
152 return r;
153}
154
155template <typename T>
157 T a = a_.value;
158 T b = b_.value;
159 if constexpr (!std::is_integral_v<T>) {
160 if (std::isnan(a)) {
161 return a;
162 }
163 }
164 return (a > b) ? a : b;
165}
166
167template <typename T>
169 T a = a_.value;
170 T b = b_.value;
171 if constexpr (!std::is_integral_v<T>) {
172 if (std::isnan(a)) {
173 return a;
174 }
175 }
176 return (a < b) ? a : b;
177}
178
179template <typename T>
181 T base = a.value;
182 T exp = b.value;
183 if constexpr (!std::is_integral_v<T>) {
184 return std::pow(base, exp);
185 } else {
186 T res = 1;
187 while (exp) {
188 if (exp & 1) {
189 res *= base;
190 }
191 exp >>= 1;
192 base *= base;
193 }
194 return res;
195 }
196}
197
198template <typename T>
200 return std::atan2(a.value, b.value);
201}
202
// Stamps out the three single-lane overloads of comparison operator OP:
// Simd/Simd, scalar/Simd and Simd/scalar. Each one wraps the comparison
// result in a Simd<bool, 1>.
#define DEFAULT_COMPARISONS(OP)                                   \
  template <typename Lhs, typename Rhs>                           \
  Simd<bool, 1> operator OP(Simd<Lhs, 1> lhs, Simd<Rhs, 1> rhs) { \
    return lhs.value OP rhs.value;                                \
  }                                                               \
  template <typename Lhs, typename Rhs>                           \
  Simd<bool, 1> operator OP(Lhs lhs, Simd<Rhs, 1> rhs) {          \
    return lhs OP rhs.value;                                      \
  }                                                               \
  template <typename Lhs, typename Rhs>                           \
  Simd<bool, 1> operator OP(Simd<Lhs, 1> lhs, Rhs rhs) {          \
    return lhs.value OP rhs;                                      \
  }
216
223
224template <typename MaskT, typename T>
225Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
226 return mask.value ? x.value : y.value;
227}
228
229template <typename T>
231 return std::clamp(v.value, min.value, max.value);
232}
233
234template <typename T, typename U>
236 return std::fma(x.value, y.value, Simd<T, 1>(z).value);
237}
238
239// Reductions
// Stamps out a single-lane reduction NAME(Simd<T, 1>) returning RESULT.
// With only one lane, reducing is just returning that lane, converted to
// RESULT. RESULT may name the template parameter T, so that name is kept.
#define DEFAULT_REDUCTION(NAME, RESULT) \
  template <typename T>                 \
  RESULT NAME(Simd<T, 1> lane) {        \
    return lane.value;                  \
  }
245
252
253} // namespace mlx::core::simd
#define DEFAULT_REDUCTION(name, type)
Definition base_simd.h:240
#define DEFAULT_UNARY(name, op)
Definition base_simd.h:68
#define DEFAULT_BINARY(OP)
Definition base_simd.h:110
#define DEFAULT_COMPARISONS(OP)
Definition base_simd.h:203
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:141
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
constexpr int N
Definition neon_fp16_simd.h:9
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:210
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
T prod(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:27
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
bool all(Simd< T, N > x)
Definition accelerate_simd.h:271
T sum(Simd< T, N > x)
Definition accelerate_simd.h:279
constexpr bool is_complex
Definition base_simd.h:42
Simd< T, 1 > conj(Simd< T, 1 > in)
Definition base_simd.h:84
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
T max(Simd< T, N > x)
Definition accelerate_simd.h:283
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:204
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
auto imag(Simd< T, 1 > in) -> Simd< decltype(std::imag(in.value)), 1 >
Definition base_simd.h:102
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
bool any(Simd< T, N > x)
Definition accelerate_simd.h:275
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:264
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
static constexpr int max_size
Definition base_simd.h:13
T min(Simd< T, N > x)
Definition accelerate_simd.h:287
auto real(Simd< T, 1 > in) -> Simd< decltype(std::real(in.value)), 1 >
Definition base_simd.h:98
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:259
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:32
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:231
static constexpr int size
Definition base_simd.h:17
Simd()
Definition base_simd.h:19
Simd(Simd< U, 1 > v)
Definition base_simd.h:21
T value
Definition base_simd.h:18
Simd(U v)
Definition base_simd.h:23
Definition accelerate_simd.h:55
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80