MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
7
9
10namespace {
11constexpr float inf = std::numeric_limits<float>::infinity();
12} // namespace
13
// Overlays an int and a float on the same storage.
// NOTE: in C++ only the most recently written member may be read. Using this
// union to reinterpret float bits as an int (or vice versa) relies on
// compiler extensions; std::memcpy is the portable way to reinterpret bits.
typedef union {
  int i;
  float f;
} IntOrFloat;

/// Fast single-precision approximation of exp(x).
///
/// exp(x) is evaluated as 2^(x * log2(e)). The exponent is split into two parts:
/// - an integer part, which is written straight into the float's exponent bits;
/// - a fractional part in [-0.5, 0.5), which is handled by a degree-6 polynomial.
///
/// Special cases: exp(-inf) = 0, exp(+inf) = +inf, exp(NaN) = NaN. Very large
/// results saturate to +inf and very small ones flush to 0.
inline float fast_exp(float x) {
  if (x == -std::numeric_limits<float>::infinity()) {
    return 0.0f;
  } else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
    return x;
  }
  x *= 1.442695; // multiply with log_2(e)

  // Clamp so that (ipart + 127) stays within [0, 255], the full 8-bit
  // biased-exponent range: 0 encodes +0.0f (underflow) and 255 encodes +inf
  // (overflow). A tighter clamp would silently saturate results instead, e.g.
  // clamping to [-80, 80] caps every result at 2^80, which is wrong for |x|
  // beyond ~55.
  x = std::clamp(x, -127.0f, 128.0f);
  const float ipart = std::floor(x + 0.5);
  const float fpart = x - ipart;

  // Polynomial approximation of 2^fpart for fpart in [-0.5, 0.5).
  x = 1.535336188319500e-4f;
  x = x * fpart + 1.339887440266574e-3f;
  x = x * fpart + 9.618437357674640e-3f;
  x = x * fpart + 5.550332471162809e-2f;
  x = x * fpart + 2.402264791363012e-1f;
  x = x * fpart + 6.931472028550421e-1f;
  x = x * fpart + 1.000000000000000f;

  // Build 2^ipart by placing the biased exponent into the float's bit pattern.
  // std::memcpy is the well-defined way to reinterpret bits in C++17; reading
  // the inactive member of a union is undefined behavior.
  static_assert(sizeof(float) == sizeof(std::int32_t), "float must be 32 bits");
  const std::int32_t bits = (static_cast<std::int32_t>(ipart) + 127) << 23;
  float scale;
  std::memcpy(&scale, &bits, sizeof(scale));

  return scale * x;
}
46
/// Single-precision approximation of the error function erf(a).
///
/// The input range is split at |a| = 0.927734375:
/// - |a| > 0.927734375: erf(|a|) = 1 - exp(q(|a|)) for a fitted polynomial q.
///   The sign of `a` is then restored with std::copysign.
/// - Otherwise, including NaN: erf(a) = a + a * p(a^2) for a fitted
///   polynomial p.
/// Both polynomials are evaluated with std::fma in Horner form.
inline float fast_erf(float a) {
  const float abs_a = std::abs(a);
  const float a_sq = a * a;

  if (abs_a > 0.927734375f) {
    // maximum error 0.99527 ulp
    const float q_hi = std::fma(
        -1.72853470e-5f, abs_a, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    const float q_lo = std::fma(
        -3.88396438e-3f, abs_a, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    float q = std::fma(q_hi, a_sq, q_lo);
    q = std::fma(q, abs_a, -1.06777877e-1f); // -0x1.b55cb8p-4
    q = std::fma(q, abs_a, -6.34846687e-1f); // -0x1.450aa0p-1
    q = std::fma(q, abs_a, -1.28717512e-1f); // -0x1.079d0cp-3
    q = std::fma(q, abs_a, -abs_a);
    // TODO, replace with expm1 when implemented
    const float magnitude = 1.0f - std::exp(q);
    return std::copysign(magnitude, a);
  }

  // maximum error 0.98929 ulp
  float p = -5.96761703e-4f; // -0x1.38e000p-11
  p = std::fma(p, a_sq, 4.99119423e-3f); // 0x1.471a58p-8
  p = std::fma(p, a_sq, -2.67681349e-2f); // -0x1.b691b2p-6
  p = std::fma(p, a_sq, 1.12819925e-1f); // 0x1.ce1c44p-4
  p = std::fma(p, a_sq, -3.76125336e-1f); // -0x1.812700p-2
  p = std::fma(p, a_sq, 1.28379166e-1f); // 0x1.06eba8p-3
  return std::fma(p, a, a);
}
77
/// Single-precision approximation of the inverse error function erfinv(a).
///
/// With t = ln(1 - a^2), the result is a * p(t). The polynomial p is chosen
/// by |t|: a tail fit when |t| > 6.125, a central fit otherwise.
/// For a = +/-1, t is -inf and the result is +/-inf.
inline float fast_erfinv(float a) {
  // 1 - a*a with a single rounding, then its natural log.
  const float t = std::log(std::fma(a, 0.0f - a, 1.0f));

  // Horner evaluation, highest-degree coefficient first, using fma.
  auto horner = [t](const float* coeffs, int count) {
    float acc = coeffs[0];
    for (int k = 1; k < count; ++k) {
      acc = std::fma(acc, t, coeffs[k]);
    }
    return acc;
  };

  if (std::abs(t) > 6.125f) {
    // maximum ulp error = 2.35793
    static constexpr float kTail[] = {
        3.03697567e-10f, // 0x1.4deb44p-32
        2.93243101e-8f, // 0x1.f7c9aep-26
        1.22150334e-6f, // 0x1.47e512p-20
        2.84108955e-5f, // 0x1.dca7dep-16
        3.93552968e-4f, // 0x1.9cab92p-12
        3.02698812e-3f, // 0x1.8cc0dep-9
        4.83185798e-3f, // 0x1.3ca920p-8
        -2.64646143e-1f, // -0x1.0eff66p-2
        8.40016484e-1f, // 0x1.ae16a4p-1
    };
    return a *
        horner(kTail, static_cast<int>(sizeof(kTail) / sizeof(kTail[0])));
  }

  // maximum ulp error = 2.35002
  static constexpr float kCentral[] = {
      5.43877832e-9f, // 0x1.75c000p-28
      1.43285448e-7f, // 0x1.33b402p-23
      1.22774793e-6f, // 0x1.499232p-20
      1.12963626e-7f, // 0x1.e52cd2p-24
      -5.61530760e-5f, // -0x1.d70bd0p-15
      -1.47697632e-4f, // -0x1.35be90p-13
      2.31468678e-3f, // 0x1.2f6400p-9
      1.15392581e-2f, // 0x1.7a1e50p-7
      -2.32015476e-1f, // -0x1.db2aeep-3
      8.86226892e-1f, // 0x1.c5bf88p-1
  };
  return a *
      horner(kCentral, static_cast<int>(sizeof(kCentral) / sizeof(kCentral[0])));
}
106
/// Elementwise absolute value.
///
/// bool and unsigned integers cannot be negative, so exact-match non-template
/// overloads (preferred over the template) return them unchanged rather than
/// passing them to std::abs.
struct Abs {
  template <typename T>
  auto operator()(T x) -> T {
    return std::abs(x);
  }
  auto operator()(bool x) -> bool {
    return x;
  }
  auto operator()(uint8_t x) -> uint8_t {
    return x;
  }
  auto operator()(uint16_t x) -> uint16_t {
    return x;
  }
  auto operator()(uint32_t x) -> uint32_t {
    return x;
  }
  auto operator()(uint64_t x) -> uint64_t {
    return x;
  }
};
128
// Inverse trigonometric and inverse hyperbolic functors. Each one forwards its
// argument(s) to the matching std:: overload and returns the result as T.

/// Elementwise arc cosine.
struct ArcCos {
  template <typename T>
  auto operator()(T x) -> T {
    return std::acos(x);
  }
};

/// Elementwise inverse hyperbolic cosine.
struct ArcCosh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::acosh(x);
  }
};

/// Elementwise arc sine.
struct ArcSin {
  template <typename T>
  auto operator()(T x) -> T {
    return std::asin(x);
  }
};

/// Elementwise inverse hyperbolic sine.
struct ArcSinh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::asinh(x);
  }
};

/// Elementwise arc tangent.
struct ArcTan {
  template <typename T>
  auto operator()(T x) -> T {
    return std::atan(x);
  }
};

/// Two-argument arc tangent of y / x, using the signs of both to pick the quadrant.
struct ArcTan2 {
  template <typename T>
  auto operator()(T y, T x) -> T {
    return std::atan2(y, x);
  }
};

/// Elementwise inverse hyperbolic tangent.
struct ArcTanh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::atanh(x);
  }
};
177
/// Elementwise ceiling.
///
/// bool and integer inputs are already integral. Exact-match overloads return
/// them unchanged instead of sending them through std::ceil, which would
/// compute in floating point.
struct Ceil {
  template <typename T>
  auto operator()(T x) -> T {
    return std::ceil(x);
  }
  auto operator()(bool x) -> bool {
    return x;
  }
  auto operator()(uint8_t x) -> uint8_t {
    return x;
  }
  auto operator()(uint16_t x) -> uint16_t {
    return x;
  }
  auto operator()(uint32_t x) -> uint32_t {
    return x;
  }
  auto operator()(uint64_t x) -> uint64_t {
    return x;
  }
  auto operator()(int8_t x) -> int8_t {
    return x;
  }
  auto operator()(int16_t x) -> int16_t {
    return x;
  }
  auto operator()(int32_t x) -> int32_t {
    return x;
  }
  auto operator()(int64_t x) -> int64_t {
    return x;
  }
};
211
212struct Conjugate {
214 return std::conj(x);
215 }
216};
217
/// Elementwise cosine.
struct Cos {
  template <typename T>
  auto operator()(T x) -> T {
    return std::cos(x);
  }
};

/// Elementwise hyperbolic cosine.
struct Cosh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::cosh(x);
  }
};
231
232struct Erf {
233 template <typename T>
234 T operator()(T x) {
235 return static_cast<T>(fast_erf(static_cast<float>(x)));
236 }
237};
238
239struct ErfInv {
240 template <typename T>
241 T operator()(T x) {
242 return static_cast<T>(fast_erfinv(static_cast<float>(x)));
243 }
244};
245
246struct Exp {
247 template <typename T>
248 T operator()(T x) {
249 return fast_exp(x);
250 }
251
253 return std::exp(x);
254 }
255};
256
/// Elementwise exp(x) - 1.
struct Expm1 {
  template <typename T>
  auto operator()(T x) -> T {
    // Deliberately unqualified (no std::). Presumably this lets overloads for
    // custom element types be found by argument-dependent lookup.
    // NOTE(review): confirm which overload set this resolves to in the build.
    return expm1(x);
  }
};
263
/// Elementwise floor.
///
/// bool and integer inputs are already integral. Exact-match overloads return
/// them unchanged instead of sending them through std::floor, which would
/// compute in floating point.
struct Floor {
  template <typename T>
  auto operator()(T x) -> T {
    return std::floor(x);
  }
  auto operator()(bool x) -> bool {
    return x;
  }
  auto operator()(uint8_t x) -> uint8_t {
    return x;
  }
  auto operator()(uint16_t x) -> uint16_t {
    return x;
  }
  auto operator()(uint32_t x) -> uint32_t {
    return x;
  }
  auto operator()(uint64_t x) -> uint64_t {
    return x;
  }
  auto operator()(int8_t x) -> int8_t {
    return x;
  }
  auto operator()(int16_t x) -> int16_t {
    return x;
  }
  auto operator()(int32_t x) -> int32_t {
    return x;
  }
  auto operator()(int64_t x) -> int64_t {
    return x;
  }
};
297
/// Elementwise natural logarithm.
struct Log {
  template <typename T>
  auto operator()(T x) -> T {
    return std::log(x);
  }
};

/// Elementwise base-2 logarithm.
struct Log2 {
  template <typename T>
  auto operator()(T x) -> T {
    return std::log2(x);
  }
};

/// Elementwise base-10 logarithm.
struct Log10 {
  template <typename T>
  auto operator()(T x) -> T {
    return std::log10(x);
  }
};

/// Elementwise log(1 + x).
struct Log1p {
  template <typename T>
  auto operator()(T x) -> T {
    // Deliberately unqualified (no std::), presumably so argument-dependent
    // lookup can find overloads for custom element types.
    // NOTE(review): confirm against the build.
    return log1p(x);
  }
};
325
/// Elementwise logical negation. The bool result of `!x` is converted back
/// to T, so numeric inputs map to 1 (for zero) or 0 (for nonzero).
struct LogicalNot {
  template <typename T>
  T operator()(T x) {
    return !x;
  }
};
332
/// Elementwise arithmetic negation (unary minus), returned as T.
struct Negative {
  template <typename T>
  auto operator()(T x) -> T {
    return -x;
  }
};

339
340struct Round {
341 template <typename T>
342 T operator()(T x) {
343 return std::rint(x);
344 }
345
347 return {std::rint(x.real()), std::rint(x.imag())};
348 }
349};
350
/// Elementwise logistic sigmoid 1 / (1 + exp(-x)), with the exponential
/// computed by fast_exp.
struct Sigmoid {
  template <typename T>
  auto operator()(T x) -> T {
    const auto one = static_cast<T>(1.0);
    const auto denom = one + fast_exp(-x);
    return one / denom;
  }
};
358
359struct Sign {
360 template <typename T>
361 T operator()(T x) {
362 return (x > T(0)) - (x < T(0));
363 }
364 uint8_t operator()(uint8_t x) {
365 return x != 0;
366 }
367 uint16_t operator()(uint16_t x) {
368 return x != 0;
369 }
370 uint32_t operator()(uint32_t x) {
371 return x != 0;
372 }
373 uint64_t operator()(uint64_t x) {
374 return x != 0;
375 }
376
378 return x == complex64_t(0) ? x : x / std::abs(x);
379 }
380};
381
/// Elementwise sine.
struct Sin {
  template <typename T>
  auto operator()(T x) -> T {
    return std::sin(x);
  }
};

/// Elementwise hyperbolic sine.
struct Sinh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::sinh(x);
  }
};

/// Elementwise square, computed as x * x.
struct Square {
  template <typename T>
  auto operator()(T x) -> T {
    return x * x;
  }
};

/// Elementwise square root.
struct Sqrt {
  template <typename T>
  auto operator()(T x) -> T {
    return std::sqrt(x);
  }
};

/// Elementwise reciprocal square root, computed as 1 / sqrt(x).
struct Rsqrt {
  template <typename T>
  auto operator()(T x) -> T {
    const auto one = static_cast<decltype(x)>(1.0);
    return one / std::sqrt(x);
  }
};

/// Elementwise tangent.
struct Tan {
  template <typename T>
  auto operator()(T x) -> T {
    return std::tan(x);
  }
};

/// Elementwise hyperbolic tangent.
struct Tanh {
  template <typename T>
  auto operator()(T x) -> T {
    return std::tanh(x);
  }
};
430
/// Elementwise sum x + y, returned as T.
struct Add {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x + y;
  }
};

/// Elementwise quotient x / y, returned as T. Integer types use C++ `/`,
/// which truncates toward zero.
struct Divide {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x / y;
  }
};
444
445struct Remainder {
446 template <typename T>
447 std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
448 T numerator,
449 T denominator) {
450 return numerator % denominator;
451 }
452
453 template <typename T>
454 std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
455 T numerator,
456 T denominator) {
457 auto r = numerator % denominator;
458 if (r != 0 && (r < 0 != denominator < 0))
459 r += denominator;
460 return r;
461 }
462
463 template <typename T>
464 std::enable_if_t<!std::is_integral_v<T>, T> operator()(
465 T numerator,
466 T denominator) {
467 auto r = std::fmod(numerator, denominator);
468 if (r != 0 && (r < 0 != denominator < 0)) {
469 r += denominator;
470 }
471 return r;
472 }
473
475 return numerator % denominator;
476 }
477};
478
/// Elementwise equality, x == y.
struct Equal {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    return x == y;
  }
};

/// Elementwise equality that also treats two NaNs as equal.
struct NaNEqual {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    if (x == y) {
      return true;
    }
    return std::isnan(x) && std::isnan(y);
  }
};

/// Elementwise x > y.
struct Greater {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    return x > y;
  }
};
499
/// Elementwise x >= y.
struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};
506
/// Elementwise x < y.
struct Less {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    return x < y;
  }
};

/// Elementwise x <= y.
struct LessEqual {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    return x <= y;
  }
};
520
/// Elementwise maximum.
///
/// For non-integral types, a NaN in either operand propagates: a NaN `x` is
/// returned directly, and a NaN `y` is returned because `x > NaN` is false.
struct Maximum {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    if (x > y) {
      return x;
    }
    return y;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x) || x > y) {
      return x;
    }
    return y;
  }
};

/// Elementwise minimum.
///
/// For non-integral types, a NaN in either operand propagates: a NaN `x` is
/// returned directly, and a NaN `y` is returned because `x < NaN` is false.
struct Minimum {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    if (x < y) {
      return x;
    }
    return y;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x) || x < y) {
      return x;
    }
    return y;
  }
};
550
551struct LogAddExp {
552 template <typename T>
553 T operator()(T x, T y) {
554 constexpr float inf = std::numeric_limits<float>::infinity();
555 auto maxval = Maximum()(x, y);
556 auto minval = Minimum()(x, y);
557 return (minval == -inf || maxval == inf)
558 ? maxval
559 : static_cast<decltype(x)>(
560 maxval + std::log1p(fast_exp(minval - maxval)));
561 }
562};
563
/// Elementwise product x * y, returned as T.
struct Multiply {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x * y;
  }
};

/// Elementwise inequality, x != y.
struct NotEqual {
  template <typename T>
  auto operator()(T x, T y) -> bool {
    return x != y;
  }
};
577
/// Elementwise base raised to the power exp.
struct Power {
  // Non-integral types defer to std::pow.
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
    return std::pow(base, exp);
  }

  // Integral types use exponentiation by squaring: O(log exp) multiplies.
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
    if constexpr (std::is_signed_v<T>) {
      // A negative integer exponent has no general integer result, so return 0.
      // Without this guard, `exp >>= 1` on a negative value stays at -1 on
      // arithmetic-shift implementations and the loop below never ends.
      if (exp < 0) {
        return 0;
      }
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only when another bit remains. Squaring after the last bit
      // could overflow (undefined for signed types) even when `res` fits.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
};
597
/// Elementwise difference x - y, returned as T.
struct Subtract {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x - y;
  }
};
604
/// Elementwise logical AND. The bool result of `x && y` is converted back to
/// T, so numeric inputs map to 0 or 1.
struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  }
};
611
/// Elementwise logical OR. The bool result of `x || y` is converted back to
/// T, so numeric inputs map to 0 or 1.
struct LogicalOr {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x || y;
  }
};

/// Elementwise select: returns x where `condition` is true, otherwise y.
struct Select {
  template <typename T>
  auto operator()(bool condition, T x, T y) -> T {
    if (condition) {
      return x;
    }
    return y;
  }
};
625
/// Elementwise bitwise AND, x & y, returned as T.
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  }
};
632
/// Elementwise bitwise OR, x | y, returned as T.
struct BitwiseOr {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x | y;
  }
};

639
/// Elementwise bitwise exclusive OR, x ^ y, returned as T.
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  }
};
646
/// Elementwise left shift, x << y, returned as T.
struct LeftShift {
  template <typename T>
  auto operator()(T x, T y) -> T {
    return x << y;
  }
};

653
/// Elementwise right shift, x >> y, returned as T.
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  }
};
660
661} // namespace mlx::core::detail
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
Definition ops.h:8
float fast_exp(float x)
Definition ops.h:19
float fast_erf(float a)
Definition ops.h:47
float fast_erfinv(float a)
Definition ops.h:78
Definition complex.h:34
Definition ops.h:107
T operator()(T x)
Definition ops.h:109
uint8_t operator()(uint8_t x)
Definition ops.h:112
uint64_t operator()(uint64_t x)
Definition ops.h:121
uint16_t operator()(uint16_t x)
Definition ops.h:115
bool operator()(bool x)
Definition ops.h:124
uint32_t operator()(uint32_t x)
Definition ops.h:118
Definition ops.h:431
T operator()(T x, T y)
Definition ops.h:433
Definition ops.h:129
T operator()(T x)
Definition ops.h:131
Definition ops.h:136
T operator()(T x)
Definition ops.h:138
Definition ops.h:143
T operator()(T x)
Definition ops.h:145
Definition ops.h:150
T operator()(T x)
Definition ops.h:152
Definition ops.h:164
T operator()(T y, T x)
Definition ops.h:166
Definition ops.h:157
T operator()(T x)
Definition ops.h:159
Definition ops.h:171
T operator()(T x)
Definition ops.h:173
Definition ops.h:626
T operator()(T x, T y)
Definition ops.h:628
Definition ops.h:633
T operator()(T x, T y)
Definition ops.h:635
Definition ops.h:640
T operator()(T x, T y)
Definition ops.h:642
Definition ops.h:178
uint8_t operator()(uint8_t x)
Definition ops.h:195
T operator()(T x)
Definition ops.h:180
uint32_t operator()(uint32_t x)
Definition ops.h:201
int8_t operator()(int8_t x)
Definition ops.h:183
int16_t operator()(int16_t x)
Definition ops.h:186
bool operator()(bool x)
Definition ops.h:207
uint16_t operator()(uint16_t x)
Definition ops.h:198
uint64_t operator()(uint64_t x)
Definition ops.h:204
int32_t operator()(int32_t x)
Definition ops.h:189
int64_t operator()(int64_t x)
Definition ops.h:192
Definition ops.h:212
complex64_t operator()(complex64_t x)
Definition ops.h:213
Definition ops.h:218
T operator()(T x)
Definition ops.h:220
Definition ops.h:225
T operator()(T x)
Definition ops.h:227
Definition ops.h:438
T operator()(T x, T y)
Definition ops.h:440
Definition ops.h:479
bool operator()(T x, T y)
Definition ops.h:481
Definition ops.h:232
T operator()(T x)
Definition ops.h:234
Definition ops.h:239
T operator()(T x)
Definition ops.h:241
Definition ops.h:246
T operator()(T x)
Definition ops.h:248
complex64_t operator()(complex64_t x)
Definition ops.h:252
Definition ops.h:257
T operator()(T x)
Definition ops.h:259
Definition ops.h:264
T operator()(T x)
Definition ops.h:266
uint32_t operator()(uint32_t x)
Definition ops.h:287
uint16_t operator()(uint16_t x)
Definition ops.h:284
uint8_t operator()(uint8_t x)
Definition ops.h:281
int32_t operator()(int32_t x)
Definition ops.h:275
int64_t operator()(int64_t x)
Definition ops.h:278
bool operator()(bool x)
Definition ops.h:293
int8_t operator()(int8_t x)
Definition ops.h:269
uint64_t operator()(uint64_t x)
Definition ops.h:290
int16_t operator()(int16_t x)
Definition ops.h:272
bool operator()(T x, T y)
Definition ops.h:502
Definition ops.h:493
bool operator()(T x, T y)
Definition ops.h:495
Definition ops.h:647
T operator()(T x, T y)
Definition ops.h:649
Definition ops.h:514
bool operator()(T x, T y)
Definition ops.h:516
Definition ops.h:507
bool operator()(T x, T y)
Definition ops.h:509
Definition ops.h:312
T operator()(T x)
Definition ops.h:314
Definition ops.h:319
T operator()(T x)
Definition ops.h:321
Definition ops.h:305
T operator()(T x)
Definition ops.h:307
Definition ops.h:551
T operator()(T x, T y)
Definition ops.h:553
Definition ops.h:298
T operator()(T x)
Definition ops.h:300
Definition ops.h:605
T operator()(T x, T y)
Definition ops.h:607
Definition ops.h:326
T operator()(T x)
Definition ops.h:328
Definition ops.h:612
T operator()(T x, T y)
Definition ops.h:614
Definition ops.h:521
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:523
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:528
Definition ops.h:536
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:543
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:538
Definition ops.h:564
T operator()(T x, T y)
Definition ops.h:566
Definition ops.h:486
bool operator()(T x, T y)
Definition ops.h:488
Definition ops.h:333
T operator()(T x)
Definition ops.h:335
Definition ops.h:571
bool operator()(T x, T y)
Definition ops.h:573
Definition ops.h:578
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:580
std::enable_if_t< std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:585
Definition ops.h:445
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:464
std::enable_if_t< std::is_integral_v< T > &!std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:447
std::enable_if_t< std::is_integral_v< T > &std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:454
complex64_t operator()(complex64_t numerator, complex64_t denominator)
Definition ops.h:474
Definition ops.h:654
T operator()(T x, T y)
Definition ops.h:656
Definition ops.h:340
T operator()(T x)
Definition ops.h:342
complex64_t operator()(complex64_t x)
Definition ops.h:346
Definition ops.h:410
T operator()(T x)
Definition ops.h:412
Definition ops.h:619
T operator()(bool condition, T x, T y)
Definition ops.h:621
Definition ops.h:351
T operator()(T x)
Definition ops.h:353
Definition ops.h:359
uint64_t operator()(uint64_t x)
Definition ops.h:373
T operator()(T x)
Definition ops.h:361
uint8_t operator()(uint8_t x)
Definition ops.h:364
uint16_t operator()(uint16_t x)
Definition ops.h:367
complex64_t operator()(complex64_t x)
Definition ops.h:377
uint32_t operator()(uint32_t x)
Definition ops.h:370
Definition ops.h:382
T operator()(T x)
Definition ops.h:384
Definition ops.h:389
T operator()(T x)
Definition ops.h:391
Definition ops.h:403
T operator()(T x)
Definition ops.h:405
Definition ops.h:396
T operator()(T x)
Definition ops.h:398
Definition ops.h:598
T operator()(T x, T y)
Definition ops.h:600
Definition ops.h:417
T operator()(T x)
Definition ops.h:419
Definition ops.h:424
T operator()(T x)
Definition ops.h:426
uint32_t u
Definition bf16.h:17
float f
Definition ops.h:16
int i
Definition ops.h:15