MLX

unary_ops.h — source listing (Doxygen page; the symbol index follows the code).
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_integer>
6#include <metal_math>
7
10
// File-local (anonymous-namespace) float constant holding positive infinity,
// stored in Metal's `constant` address space.
// NOTE(review): no operator visible in this listing references `inf`, and the
// symbol index links `inf` to a `constexpr float inf` in math.h — confirm the
// two definitions cannot make an unqualified `inf` ambiguous where both are seen.
11namespace {
12constant float inf = metal::numeric_limits<float>::infinity();
13}
14
15struct Abs {
16 template <typename T>
17 T operator()(T x) {
18 return metal::abs(x);
19 };
20 template <>
21 uint8_t operator()(uint8_t x) {
22 return x;
23 };
24 template <>
25 uint16_t operator()(uint16_t x) {
26 return x;
27 };
28 template <>
29 uint32_t operator()(uint32_t x) {
30 return x;
31 };
32 template <>
33 uint64_t operator()(uint64_t x) {
34 return x;
35 };
36 template <>
37 bool operator()(bool x) {
38 return x;
39 };
40 template <>
42 return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
43 };
44};
45
46struct ArcCos {
47 template <typename T>
48 T operator()(T x) {
49 return metal::precise::acos(x);
50 };
51};
52
53struct ArcCosh {
54 template <typename T>
55 T operator()(T x) {
56 return metal::precise::acosh(x);
57 };
58};
59
60struct ArcSin {
61 template <typename T>
62 T operator()(T x) {
63 return metal::precise::asin(x);
64 };
65};
66
67struct ArcSinh {
68 template <typename T>
69 T operator()(T x) {
70 return metal::precise::asinh(x);
71 };
72};
73
74struct ArcTan {
75 template <typename T>
76 T operator()(T x) {
77 return metal::precise::atan(x);
78 };
79};
80
81struct ArcTanh {
82 template <typename T>
83 T operator()(T x) {
84 return metal::precise::atanh(x);
85 };
86};
87
// NOTE(review): the struct header (original line 88) is missing from this
// listing; the symbol index only records "Definition unary_ops.h:88" with no
// name. Restore it from the source file.
// Generic element-wise operator returning the bitwise complement ~x converted
// back to T; presumably only instantiated for integer/bool types — confirm.
89 template <typename T>
90 T operator()(T x) {
91 return ~x;
92 };
93};
94
95struct Ceil {
96 template <typename T>
97 T operator()(T x) {
98 return metal::ceil(x);
99 };
100 template <>
101 int8_t operator()(int8_t x) {
102 return x;
103 };
104 template <>
105 int16_t operator()(int16_t x) {
106 return x;
107 };
108 template <>
109 int32_t operator()(int32_t x) {
110 return x;
111 };
112 template <>
113 int64_t operator()(int64_t x) {
114 return x;
115 };
116 template <>
117 uint8_t operator()(uint8_t x) {
118 return x;
119 };
120 template <>
121 uint16_t operator()(uint16_t x) {
122 return x;
123 };
124 template <>
125 uint32_t operator()(uint32_t x) {
126 return x;
127 };
128 template <>
129 uint64_t operator()(uint64_t x) {
130 return x;
131 };
132 template <>
133 bool operator()(bool x) {
134 return x;
135 };
136};
137
// Element-wise cosine. The generic template calls metal::precise::cos.
// NOTE(review): original lines 145-149 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:145. Restore them from the source file.
138struct Cos {
139 template <typename T>
140 T operator()(T x) {
141 return metal::precise::cos(x);
142 };
143
144 template <>
// NOTE(review): complex64_t specialization body (orig. 145-149) missing here.
150};
151
// Element-wise hyperbolic cosine. The generic template calls
// metal::precise::cosh.
// NOTE(review): original lines 159-163 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:159. Restore them from the source file.
152struct Cosh {
153 template <typename T>
154 T operator()(T x) {
155 return metal::precise::cosh(x);
156 };
157
158 template <>
// NOTE(review): complex64_t specialization body (orig. 159-163) missing here.
164};
165
166struct Conjugate {
168 return complex64_t{x.real, -x.imag};
169 }
170};
171
// Element-wise Gauss error function.
//
// The input is widened to float, passed to erf(float) (declared in erf.h),
// and the result is narrowed back to T.
struct Erf {
  template <typename T>
  T operator()(T x) {
    const float as_float = static_cast<float>(x);
    return static_cast<T>(erf(as_float));
  }
};
178
179struct ErfInv {
180 template <typename T>
181 T operator()(T x) {
182 return static_cast<T>(erfinv(static_cast<float>(x)));
183 };
184};
185
// Element-wise exponential. The generic template calls metal::precise::exp.
// NOTE(review): original lines 192-195 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:192. Restore them from the source file.
186struct Exp {
187 template <typename T>
188 T operator()(T x) {
189 return metal::precise::exp(x);
190 };
191 template <>
// NOTE(review): complex64_t specialization (orig. 192-195) missing here.
196};
197
// Element-wise exp(x) - 1.
//
// Evaluated in float through expm1f (declared in expm1f.h), which avoids the
// cancellation of computing exp(x) - 1 directly for small |x|. The result is
// then narrowed back to T.
struct Expm1 {
  template <typename T>
  T operator()(T x) {
    const float as_float = static_cast<float>(x);
    return static_cast<T>(expm1f(as_float));
  }
};
204
205struct Floor {
206 template <typename T>
207 T operator()(T x) {
208 return metal::floor(x);
209 };
210 template <>
211 int8_t operator()(int8_t x) {
212 return x;
213 };
214 template <>
215 int16_t operator()(int16_t x) {
216 return x;
217 };
218 template <>
219 int32_t operator()(int32_t x) {
220 return x;
221 };
222 template <>
223 int64_t operator()(int64_t x) {
224 return x;
225 };
226 template <>
227 uint8_t operator()(uint8_t x) {
228 return x;
229 };
230 template <>
231 uint16_t operator()(uint16_t x) {
232 return x;
233 };
234 template <>
235 uint32_t operator()(uint32_t x) {
236 return x;
237 };
238 template <>
239 uint64_t operator()(uint64_t x) {
240 return x;
241 };
242 template <>
243 bool operator()(bool x) {
244 return x;
245 };
246};
247
// Extracts the imaginary component `x.imag` and converts it back to the input
// type T.
struct Imag {
  template <typename T>
  T operator()(T x) {
    const auto component = x.imag;
    return component;
  }
};
254
255struct Log {
256 template <typename T>
257 T operator()(T x) {
258 return metal::precise::log(x);
259 };
260};
261
262struct Log2 {
263 template <typename T>
264 T operator()(T x) {
265 return metal::precise::log2(x);
266 };
267};
268
269struct Log10 {
270 template <typename T>
271 T operator()(T x) {
272 return metal::precise::log10(x);
273 };
274};
275
// Element-wise log(1 + x).
//
// Delegates to the unqualified log1p helper (the symbol index shows
// `float log1p(float x)` in utils.h). The helper's result is converted back
// to T on return.
struct Log1p {
  template <typename T>
  T operator()(T x) {
    const T result = log1p(x);
    return result;
  }
};
282
// NOTE(review): the struct header (original line 283) is missing from this
// listing; the symbol index only records "Definition unary_ops.h:283" with no
// name. Restore it from the source file.
// Generic element-wise operator returning the logical negation !x converted
// back to T.
284 template <typename T>
285 T operator()(T x) {
286 return !x;
287 };
288};
289
// Element-wise arithmetic negation. The result of unary minus is converted
// back to T, so unsigned inputs wrap (e.g. uint8_t 1 -> 255).
struct Negative {
  template <typename T>
  T operator()(T x) {
    const T negated = -x;
    return negated;
  }
};
296
// Extracts the real component `x.real` and converts it back to the input
// type T.
struct Real {
  template <typename T>
  T operator()(T x) {
    const auto component = x.real;
    return component;
  }
};
303
// Element-wise rounding to an integral value. The generic template calls
// metal::rint.
// NOTE(review): original lines 310-312 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:310. Restore them from the source file.
304struct Round {
305 template <typename T>
306 T operator()(T x) {
307 return metal::rint(x);
308 };
309 template <>
// NOTE(review): complex64_t specialization (orig. 310-312) missing here.
313};
314
315struct Sigmoid {
316 template <typename T>
317 T operator()(T x) {
318 auto y = 1 / (1 + metal::exp(-metal::abs(x)));
319 return (x < 0) ? 1 - y : y;
320 }
321};
322
// Element-wise sign.
// Generic case: (x > 0) - (x < 0), i.e. 1, 0 or -1, converted back to T.
// uint32_t: cannot be negative, so the result is simply (x != 0).
// complex64_t: returns x unchanged when x == 0, otherwise divides x by an
// expression whose text is on missing line 338 (presumably |x| — TODO confirm).
// NOTE(review): original lines 333 (signature, complex64_t per the symbol
// index) and 338 are missing from this listing. Restore them from the source.
323struct Sign {
324 template <typename T>
325 T operator()(T x) {
326 return (x > T(0)) - (x < T(0));
327 };
328 template <>
329 uint32_t operator()(uint32_t x) {
330 return x != 0;
331 };
332 template <>
// NOTE(review): signature line (orig. 333) missing here.
334 if (x == complex64_t(0)) {
335 return x;
336 }
337 return x /
// NOTE(review): continuation of the return expression (orig. 338) missing here.
339 };
340};
341
// Element-wise sine. The generic template calls metal::precise::sin.
// NOTE(review): original lines 349-353 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:349. Restore them from the source file.
342struct Sin {
343 template <typename T>
344 T operator()(T x) {
345 return metal::precise::sin(x);
346 };
347
348 template <>
// NOTE(review): complex64_t specialization body (orig. 349-353) missing here.
354};
355
// Element-wise hyperbolic sine. The generic template calls
// metal::precise::sinh.
// NOTE(review): original lines 363-367 are missing from this listing; per the
// symbol index they hold a complex64_t operator() specialization declared at
// unary_ops.h:363. Restore them from the source file.
356struct Sinh {
357 template <typename T>
358 T operator()(T x) {
359 return metal::precise::sinh(x);
360 };
361
362 template <>
// NOTE(review): complex64_t specialization body (orig. 363-367) missing here.
368};
369
// Element-wise square, x * x. The product is converted back to T, so narrow
// integer types wrap (e.g. uint8_t 16 -> 0).
struct Square {
  template <typename T>
  T operator()(T x) {
    const T squared = x * x;
    return squared;
  }
};
376
377struct Sqrt {
378 template <typename T>
379 T operator()(T x) {
380 return metal::precise::sqrt(x);
381 };
382};
383
384struct Rsqrt {
385 template <typename T>
386 T operator()(T x) {
387 return metal::precise::rsqrt(x);
388 };
389};
390
391struct Tan {
392 template <typename T>
393 T operator()(T x) {
394 return metal::precise::tan(x);
395 };
396
397 template <>
399 float tan_a = metal::precise::tan(x.real);
400 float tanh_b = metal::precise::tanh(x.imag);
401 float t1 = tan_a * tanh_b;
402 float denom = 1. + t1 * t1;
403 return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
404 };
405};
406
407struct Tanh {
408 template <typename T>
409 T operator()(T x) {
410 return metal::precise::tanh(x);
411 };
412
413 template <>
415 float tanh_a = metal::precise::tanh(x.real);
416 float tan_b = metal::precise::tan(x.imag);
417 float t1 = tanh_a * tan_b;
418 float denom = 1. + t1 * t1;
419 return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
420 };
421};
float log1p(float x)
Definition utils.h:307
float erfinv(float a)
Definition erf.h:42
float erf(float a)
Definition erf.h:11
float expm1f(float a)
Definition expm1f.h:80
METAL_FUNC bfloat16_t acosh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t log10(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t log2(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t cosh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t tanh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t tan(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t acos(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t atanh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t asinh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t atan(bfloat16_t y_over_x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t sinh(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t sqrt(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t asin(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t rsqrt(bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t floor(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t rint(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t abs(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t ceil(bfloat16_t x)
Definition bf16_math.h:232
constexpr float inf
Definition math.h:9
Definition unary_ops.h:15
uint8_t operator()(uint8_t x)
Definition unary_ops.h:21
uint32_t operator()(uint32_t x)
Definition unary_ops.h:29
T operator()(T x)
Definition unary_ops.h:17
complex64_t operator()(complex64_t x)
Definition unary_ops.h:41
bool operator()(bool x)
Definition unary_ops.h:37
uint64_t operator()(uint64_t x)
Definition unary_ops.h:33
uint16_t operator()(uint16_t x)
Definition unary_ops.h:25
Definition unary_ops.h:46
T operator()(T x)
Definition unary_ops.h:48
Definition unary_ops.h:53
T operator()(T x)
Definition unary_ops.h:55
Definition unary_ops.h:60
T operator()(T x)
Definition unary_ops.h:62
Definition unary_ops.h:67
T operator()(T x)
Definition unary_ops.h:69
Definition unary_ops.h:74
T operator()(T x)
Definition unary_ops.h:76
Definition unary_ops.h:81
T operator()(T x)
Definition unary_ops.h:83
Definition unary_ops.h:88
T operator()(T x)
Definition unary_ops.h:90
Definition unary_ops.h:95
int16_t operator()(int16_t x)
Definition unary_ops.h:105
bool operator()(bool x)
Definition unary_ops.h:133
uint32_t operator()(uint32_t x)
Definition unary_ops.h:125
int8_t operator()(int8_t x)
Definition unary_ops.h:101
T operator()(T x)
Definition unary_ops.h:97
int64_t operator()(int64_t x)
Definition unary_ops.h:113
uint64_t operator()(uint64_t x)
Definition unary_ops.h:129
uint8_t operator()(uint8_t x)
Definition unary_ops.h:117
uint16_t operator()(uint16_t x)
Definition unary_ops.h:121
int32_t operator()(int32_t x)
Definition unary_ops.h:109
Definition unary_ops.h:166
complex64_t operator()(complex64_t x)
Definition unary_ops.h:167
Definition unary_ops.h:138
complex64_t operator()(complex64_t x)
Definition unary_ops.h:145
T operator()(T x)
Definition unary_ops.h:140
Definition unary_ops.h:152
T operator()(T x)
Definition unary_ops.h:154
complex64_t operator()(complex64_t x)
Definition unary_ops.h:159
Definition unary_ops.h:172
T operator()(T x)
Definition unary_ops.h:174
Definition unary_ops.h:179
T operator()(T x)
Definition unary_ops.h:181
Definition unary_ops.h:186
complex64_t operator()(complex64_t x)
Definition unary_ops.h:192
T operator()(T x)
Definition unary_ops.h:188
Definition unary_ops.h:198
T operator()(T x)
Definition unary_ops.h:200
Definition unary_ops.h:205
int8_t operator()(int8_t x)
Definition unary_ops.h:211
int16_t operator()(int16_t x)
Definition unary_ops.h:215
int32_t operator()(int32_t x)
Definition unary_ops.h:219
uint16_t operator()(uint16_t x)
Definition unary_ops.h:231
uint64_t operator()(uint64_t x)
Definition unary_ops.h:239
uint32_t operator()(uint32_t x)
Definition unary_ops.h:235
int64_t operator()(int64_t x)
Definition unary_ops.h:223
bool operator()(bool x)
Definition unary_ops.h:243
uint8_t operator()(uint8_t x)
Definition unary_ops.h:227
T operator()(T x)
Definition unary_ops.h:207
Definition unary_ops.h:248
T operator()(T x)
Definition unary_ops.h:250
Definition unary_ops.h:269
T operator()(T x)
Definition unary_ops.h:271
Definition unary_ops.h:276
T operator()(T x)
Definition unary_ops.h:278
Definition unary_ops.h:262
T operator()(T x)
Definition unary_ops.h:264
Definition unary_ops.h:255
T operator()(T x)
Definition unary_ops.h:257
Definition unary_ops.h:283
T operator()(T x)
Definition unary_ops.h:285
Definition unary_ops.h:290
T operator()(T x)
Definition unary_ops.h:292
Definition unary_ops.h:297
T operator()(T x)
Definition unary_ops.h:299
Definition unary_ops.h:304
T operator()(T x)
Definition unary_ops.h:306
complex64_t operator()(complex64_t x)
Definition unary_ops.h:310
Definition unary_ops.h:384
T operator()(T x)
Definition unary_ops.h:386
Definition unary_ops.h:315
T operator()(T x)
Definition unary_ops.h:317
Definition unary_ops.h:323
T operator()(T x)
Definition unary_ops.h:325
uint32_t operator()(uint32_t x)
Definition unary_ops.h:329
complex64_t operator()(complex64_t x)
Definition unary_ops.h:333
Definition unary_ops.h:342
T operator()(T x)
Definition unary_ops.h:344
complex64_t operator()(complex64_t x)
Definition unary_ops.h:349
Definition unary_ops.h:356
T operator()(T x)
Definition unary_ops.h:358
complex64_t operator()(complex64_t x)
Definition unary_ops.h:363
Definition unary_ops.h:377
T operator()(T x)
Definition unary_ops.h:379
Definition unary_ops.h:370
T operator()(T x)
Definition unary_ops.h:372
Definition unary_ops.h:391
T operator()(T x)
Definition unary_ops.h:393
complex64_t operator()(complex64_t x)
Definition unary_ops.h:398
Definition unary_ops.h:407
complex64_t operator()(complex64_t x)
Definition unary_ops.h:414
T operator()(T x)
Definition unary_ops.h:409
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21