MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstring>
#include <limits>
#include <type_traits>
7
9
// Positive infinity as a float.
// Declared `inline constexpr` rather than inside an unnamed namespace: this is
// a header, and an unnamed namespace would give every translation unit its own
// private copy (and an unused-variable warning wherever it is not referenced).
inline constexpr float inf = std::numeric_limits<float>::infinity();
13
// Union exposing the same storage either as an int (`i`) or as a float (`f`).
typedef union {
  int i;
  float f;
} IntOrFloat;
// Fast single-precision approximation of e^x.
//
// The input is rescaled to a power of two (x * log2(e)), clamped to [-80, 80],
// and split into a rounded integer part and a fractional remainder. 2^fraction
// is evaluated with a degree-6 polynomial and 2^integer is produced by writing
// the biased exponent straight into the bit pattern of a float.
//
// Special cases: -inf returns 0, +inf and NaN are returned unchanged. Because
// of the clamp, finite results saturate at roughly 2^-80 and 2^80.
inline float fast_exp(float x) {
  if (x == -std::numeric_limits<float>::infinity()) {
    return 0.0f;
  } else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
    return x;
  }
  x *= 1.442695; // multiply with log_2(e)
  float ipart, fpart;
  x = std::max(-80.f, std::min(x, 80.f));
  ipart = std::floor(x + 0.5);
  fpart = x - ipart;

  // Horner evaluation of the polynomial approximating 2**fpart.
  x = 1.535336188319500e-4f;
  x = x * fpart + 1.339887440266574e-3f;
  x = x * fpart + 9.618437357674640e-3f;
  x = x * fpart + 5.550332471162809e-2f;
  x = x * fpart + 2.402264791363012e-1f;
  x = x * fpart + 6.931472028550421e-1f;
  x = x * fpart + 1.000000000000000f;

  // Generate 2**ipart in the floating point representation using integer
  // bitshifting: the biased exponent goes into bits 23..30. The bits are
  // copied with memcpy because reading the inactive member of a union is
  // undefined behaviour in C++.
  static_assert(
      sizeof(float) == sizeof(int32_t),
      "fast_exp requires a 32-bit float representation");
  const int32_t exponent_bits = (static_cast<int32_t>(ipart) + 127) << 23;
  float scale;
  std::memcpy(&scale, &exponent_bits, sizeof(scale));

  return scale * x;
}
46
// Single-precision approximation of the error function erf(a).
//
// Two polynomial fits are used, selected by |a|: a direct odd polynomial in a
// for small arguments, and 1 - exp(poly(|a|)) with the sign of a restored for
// larger ones. Inputs that fail the comparison (NaN) take the small branch.
inline float fast_erf(float a) {
  const float abs_a = std::abs(a);
  const float a_sq = a * a;

  if (!(abs_a > 0.927734375f)) {
    // maximum error 0.98929 ulp
    float poly = -5.96761703e-4f; // -0x1.38e000p-11
    poly = std::fma(poly, a_sq, 4.99119423e-3f); // 0x1.471a58p-8
    poly = std::fma(poly, a_sq, -2.67681349e-2f); // -0x1.b691b2p-6
    poly = std::fma(poly, a_sq, 1.12819925e-1f); // 0x1.ce1c44p-4
    poly = std::fma(poly, a_sq, -3.76125336e-1f); // -0x1.812700p-2
    poly = std::fma(poly, a_sq, 1.28379166e-1f); // 0x1.06eba8p-3
    return std::fma(poly, a, a);
  }

  // maximum error 0.99527 ulp
  const float upper = std::fma(
      -1.72853470e-5f, abs_a, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
  const float lower = std::fma(
      -3.88396438e-3f, abs_a, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
  float poly = std::fma(upper, a_sq, lower);
  poly = std::fma(poly, abs_a, -1.06777877e-1f); // -0x1.b55cb8p-4
  poly = std::fma(poly, abs_a, -6.34846687e-1f); // -0x1.450aa0p-1
  poly = std::fma(poly, abs_a, -1.28717512e-1f); // -0x1.079d0cp-3
  poly = std::fma(poly, abs_a, -abs_a);
  // TODO, replace with expm1 when implemented
  const float magnitude = 1.0f - std::exp(poly);
  return std::copysign(magnitude, a);
}
77
// Single-precision approximation of the inverse error function.
//
// Works on t = log(1 - a*a) and evaluates one of two polynomials in t,
// selected by |t| > 6.125 (the tail region), then scales the result by a.
inline float fast_erfinv(float a) {
  const float one_minus_sq = std::fma(a, 0.0f - a, 1.0f);
  const float t = std::log(one_minus_sq);

  float poly;
  if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    poly = 3.03697567e-10f; // 0x1.4deb44p-32
    poly = std::fma(poly, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    poly = std::fma(poly, t, 1.22150334e-6f); // 0x1.47e512p-20
    poly = std::fma(poly, t, 2.84108955e-5f); // 0x1.dca7dep-16
    poly = std::fma(poly, t, 3.93552968e-4f); // 0x1.9cab92p-12
    poly = std::fma(poly, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    poly = std::fma(poly, t, 4.83185798e-3f); // 0x1.3ca920p-8
    poly = std::fma(poly, t, -2.64646143e-1f); // -0x1.0eff66p-2
    poly = std::fma(poly, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    poly = 5.43877832e-9f; // 0x1.75c000p-28
    poly = std::fma(poly, t, 1.43285448e-7f); // 0x1.33b402p-23
    poly = std::fma(poly, t, 1.22774793e-6f); // 0x1.499232p-20
    poly = std::fma(poly, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    poly = std::fma(poly, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    poly = std::fma(poly, t, -1.47697632e-4f); // -0x1.35be90p-13
    poly = std::fma(poly, t, 2.31468678e-3f); // 0x1.2f6400p-9
    poly = std::fma(poly, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    poly = std::fma(poly, t, -2.32015476e-1f); // -0x1.db2aeep-3
    poly = std::fma(poly, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * poly;
}
106
// Absolute-value functor.
struct Abs {
  // Generic case: delegates to std::abs and converts the result back to T.
  template <typename T>
  T operator()(T x) {
    T magnitude = std::abs(x);
    return magnitude;
  }

  // Unsigned integers and bool cannot be negative: returned unchanged.
  uint8_t operator()(uint8_t x) { return x; }
  uint16_t operator()(uint16_t x) { return x; }
  uint32_t operator()(uint32_t x) { return x; }
  uint64_t operator()(uint64_t x) { return x; }
  bool operator()(bool x) { return x; }
};
128
// Inverse cosine functor (wraps std::acos).
struct ArcCos {
  template <typename T>
  T operator()(T x) {
    T angle = std::acos(x);
    return angle;
  }
};
135
// Inverse hyperbolic cosine functor (wraps std::acosh).
struct ArcCosh {
  template <typename T>
  T operator()(T x) {
    T value = std::acosh(x);
    return value;
  }
};
142
// Inverse sine functor (wraps std::asin).
struct ArcSin {
  template <typename T>
  T operator()(T x) {
    T angle = std::asin(x);
    return angle;
  }
};
149
// Inverse hyperbolic sine functor (wraps std::asinh).
struct ArcSinh {
  template <typename T>
  T operator()(T x) {
    T value = std::asinh(x);
    return value;
  }
};
156
// Inverse tangent functor (wraps std::atan).
struct ArcTan {
  template <typename T>
  T operator()(T x) {
    T angle = std::atan(x);
    return angle;
  }
};
163
// Two-argument inverse tangent functor: forwards (y, x) to std::atan2.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    T angle = std::atan2(y, x);
    return angle;
  }
};
170
// Inverse hyperbolic tangent functor (wraps std::atanh).
struct ArcTanh {
  template <typename T>
  T operator()(T x) {
    T value = std::atanh(x);
    return value;
  }
};
177
// Ceiling functor.
struct Ceil {
  // Generic case: delegates to std::ceil and converts the result back to T.
  template <typename T>
  T operator()(T x) {
    T rounded_up = std::ceil(x);
    return rounded_up;
  }

  // Integer and bool inputs are already whole numbers: returned unchanged.
  int8_t operator()(int8_t x) { return x; }
  int16_t operator()(int16_t x) { return x; }
  int32_t operator()(int32_t x) { return x; }
  int64_t operator()(int64_t x) { return x; }
  uint8_t operator()(uint8_t x) { return x; }
  uint16_t operator()(uint16_t x) { return x; }
  uint32_t operator()(uint32_t x) { return x; }
  uint64_t operator()(uint64_t x) { return x; }
  bool operator()(bool x) { return x; }
};
211
212struct Conjugate {
214 return std::conj(x);
215 }
216};
217
// Cosine functor (wraps std::cos).
struct Cos {
  template <typename T>
  T operator()(T x) {
    T value = std::cos(x);
    return value;
  }
};
224
// Hyperbolic cosine functor (wraps std::cosh).
struct Cosh {
  template <typename T>
  T operator()(T x) {
    T value = std::cosh(x);
    return value;
  }
};
231
232struct Erf {
233 template <typename T>
234 T operator()(T x) {
235 return static_cast<T>(fast_erf(static_cast<float>(x)));
236 }
237};
238
239struct ErfInv {
240 template <typename T>
241 T operator()(T x) {
242 return static_cast<T>(fast_erfinv(static_cast<float>(x)));
243 }
244};
245
246struct Exp {
247 template <typename T>
248 T operator()(T x) {
249 return fast_exp(x);
250 }
251
253 return std::exp(x);
254 }
255};
256
// Functor computing exp(x) - 1 via std::expm1.
struct Expm1 {
  template <typename T>
  T operator()(T x) {
    // Qualified with std:: like the other math functors in this file, so name
    // lookup cannot select an unrelated `expm1` overload that happens to be
    // visible from an enclosing namespace.
    return std::expm1(x);
  }
};
263
// Floor functor.
struct Floor {
  // Generic case: delegates to std::floor and converts the result back to T.
  template <typename T>
  T operator()(T x) {
    T rounded_down = std::floor(x);
    return rounded_down;
  }

  // Integer and bool inputs are already whole numbers: returned unchanged.
  int8_t operator()(int8_t x) { return x; }
  int16_t operator()(int16_t x) { return x; }
  int32_t operator()(int32_t x) { return x; }
  int64_t operator()(int64_t x) { return x; }
  uint8_t operator()(uint8_t x) { return x; }
  uint16_t operator()(uint16_t x) { return x; }
  uint32_t operator()(uint32_t x) { return x; }
  uint64_t operator()(uint64_t x) { return x; }
  bool operator()(bool x) { return x; }
};
297
// Imaginary-part functor: the result of std::imag converted back to T.
struct Imag {
  template <typename T>
  T operator()(T x) {
    T imaginary = std::imag(x);
    return imaginary;
  }
};
304
// Natural logarithm functor (wraps std::log).
struct Log {
  template <typename T>
  T operator()(T x) {
    T value = std::log(x);
    return value;
  }
};
311
// Base-2 logarithm functor (wraps std::log2).
struct Log2 {
  template <typename T>
  T operator()(T x) {
    T value = std::log2(x);
    return value;
  }
};
318
// Base-10 logarithm functor (wraps std::log10).
struct Log10 {
  template <typename T>
  T operator()(T x) {
    T value = std::log10(x);
    return value;
  }
};
325
// Functor computing log(1 + x) via std::log1p.
struct Log1p {
  template <typename T>
  T operator()(T x) {
    // Qualified with std:: like the other math functors in this file, so name
    // lookup cannot select an unrelated `log1p` overload that happens to be
    // visible from an enclosing namespace.
    return std::log1p(x);
  }
};
332
334 template <typename T>
335 T operator()(T x) {
336 return !x;
337 }
338};
339
// Arithmetic negation functor.
struct Negative {
  template <typename T>
  T operator()(T x) {
    T negated = -x;
    return negated;
  }
};
346
// Real-part functor: the result of std::real converted back to T.
struct Real {
  template <typename T>
  T operator()(T x) {
    T real_part = std::real(x);
    return real_part;
  }
};
353
354struct Round {
355 template <typename T>
356 T operator()(T x) {
357 return std::rint(x);
358 }
359
361 return {std::rint(x.real()), std::rint(x.imag())};
362 }
363};
364
// Logistic sigmoid functor: 1 / (1 + exp(-x)), with the exponential supplied
// by the fast_exp approximation.
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    const T unit = static_cast<T>(1.0);
    const auto denominator = unit + fast_exp(-x);
    return unit / denominator;
  }
};
372
373struct Sign {
374 template <typename T>
375 T operator()(T x) {
376 return (x > T(0)) - (x < T(0));
377 }
378 uint8_t operator()(uint8_t x) {
379 return x != 0;
380 }
381 uint16_t operator()(uint16_t x) {
382 return x != 0;
383 }
384 uint32_t operator()(uint32_t x) {
385 return x != 0;
386 }
387 uint64_t operator()(uint64_t x) {
388 return x != 0;
389 }
390
392 return x == complex64_t(0) ? x : x / std::abs(x);
393 }
394};
395
// Sine functor (wraps std::sin).
struct Sin {
  template <typename T>
  T operator()(T x) {
    T value = std::sin(x);
    return value;
  }
};
402
// Hyperbolic sine functor (wraps std::sinh).
struct Sinh {
  template <typename T>
  T operator()(T x) {
    T value = std::sinh(x);
    return value;
  }
};
409
// Squaring functor: multiplies the argument by itself.
struct Square {
  template <typename T>
  T operator()(T x) {
    T squared = x * x;
    return squared;
  }
};
416
// Square-root functor (wraps std::sqrt).
struct Sqrt {
  template <typename T>
  T operator()(T x) {
    T root = std::sqrt(x);
    return root;
  }
};
423
// Reciprocal square-root functor: 1 / sqrt(x), with the numerator expressed
// in the argument type.
struct Rsqrt {
  template <typename T>
  T operator()(T x) {
    const T unit = static_cast<T>(1.0);
    return unit / std::sqrt(x);
  }
};
430
// Tangent functor (wraps std::tan).
struct Tan {
  template <typename T>
  T operator()(T x) {
    T value = std::tan(x);
    return value;
  }
};
437
// Hyperbolic tangent functor (wraps std::tanh).
struct Tanh {
  template <typename T>
  T operator()(T x) {
    T value = std::tanh(x);
    return value;
  }
};
444
// Addition functor.
struct Add {
  template <typename T>
  T operator()(T x, T y) {
    T sum = x + y;
    return sum;
  }
};
451
// Division functor; uses the built-in `/` of T, so integer operands divide
// with integer semantics.
struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    T quotient = x / y;
    return quotient;
  }
};
458
459struct Remainder {
460 template <typename T>
461 std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
462 T numerator,
463 T denominator) {
464 return numerator % denominator;
465 }
466
467 template <typename T>
468 std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
469 T numerator,
470 T denominator) {
471 auto r = numerator % denominator;
472 if (r != 0 && (r < 0 != denominator < 0))
473 r += denominator;
474 return r;
475 }
476
477 template <typename T>
478 std::enable_if_t<!std::is_integral_v<T>, T> operator()(
479 T numerator,
480 T denominator) {
481 auto r = std::fmod(numerator, denominator);
482 if (r != 0 && (r < 0 != denominator < 0)) {
483 r += denominator;
484 }
485 return r;
486 }
487
489 return numerator % denominator;
490 }
491};
492
// Equality comparison functor.
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    const bool same = (x == y);
    return same;
  }
};
499
// Equality comparison that additionally treats two NaN operands as equal.
struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    if (x == y) {
      return true;
    }
    return std::isnan(x) && std::isnan(y);
  }
};
506
// Strict greater-than comparison functor.
struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    const bool is_greater = (x > y);
    return is_greater;
  }
};
513
515 template <typename T>
516 bool operator()(T x, T y) {
517 return x >= y;
518 }
519};
520
// Strict less-than comparison functor.
struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    const bool is_less = (x < y);
    return is_less;
  }
};
527
// Less-than-or-equal comparison functor.
struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool at_most = (x <= y);
    return at_most;
  }
};
534
// Element-wise maximum functor.
struct Maximum {
  // Integral types: plain comparison.
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    if (x > y) {
      return x;
    }
    return y;
  }

  // Non-integral types: a NaN in x is returned as-is; otherwise y is returned
  // whenever `x > y` is false.
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x) || x > y) {
      return x;
    }
    return y;
  }
};
549
// Element-wise minimum functor.
struct Minimum {
  // Integral types: plain comparison.
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    if (x < y) {
      return x;
    }
    return y;
  }

  // Non-integral types: a NaN in x is returned as-is; otherwise y is returned
  // whenever `x < y` is false.
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x) || x < y) {
      return x;
    }
    return y;
  }
};
564
565struct LogAddExp {
566 template <typename T>
567 T operator()(T x, T y) {
568 constexpr float inf = std::numeric_limits<float>::infinity();
569 auto maxval = Maximum()(x, y);
570 auto minval = Minimum()(x, y);
571 return (minval == -inf || maxval == inf)
572 ? maxval
573 : static_cast<decltype(x)>(
574 maxval + std::log1p(fast_exp(minval - maxval)));
575 }
576};
577
// Multiplication functor.
struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    T product = x * y;
    return product;
  }
};
584
// Inequality comparison functor.
struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool differs = (x != y);
    return differs;
  }
};
591
// Power functor: base raised to exp.
struct Power {
  // Non-integral types: delegates to std::pow.
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
    return std::pow(base, exp);
  }

  // Integral types: exponentiation by squaring.
  //
  // A negative exponent (signed types only) is handled up front, because
  // right-shifting a negative value never reaches zero and the loop below
  // would not terminate. The result is the integer truncation of the true
  // value: 1 for base 1, +/-1 for base -1 depending on the parity of exp,
  // and 0 for every other base (including base 0, where the true value is
  // undefined).
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
    if constexpr (std::is_signed_v<T>) {
      if (exp < 0) {
        if (base == 1) {
          return 1;
        }
        if (base == -1) {
          return (exp & 1) ? -1 : 1;
        }
        return 0;
      }
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only while bits remain, so the final (unused) squaring cannot
      // overflow.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
};
611
// Subtraction functor.
struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    T difference = x - y;
    return difference;
  }
};
618
620 template <typename T>
621 T operator()(T x, T y) {
622 return x && y;
623 }
624};
625
// Logical OR functor; the boolean result is converted back to T.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    T either = x || y;
    return either;
  }
};
632
// Selection functor: yields x when condition is true, otherwise y.
struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    if (condition) {
      return x;
    }
    return y;
  }
};
639
641 template <typename T>
642 T operator()(T x, T y) {
643 return x & y;
644 }
645};
646
// Bitwise OR functor.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    T combined = x | y;
    return combined;
  }
};
653
655 template <typename T>
656 T operator()(T x, T y) {
657 return x ^ y;
658 }
659};
660
// Left-shift functor: x shifted left by y bit positions.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    T shifted = x << y;
    return shifted;
  }
};
667
669 template <typename T>
670 T operator()(T x, T y) {
671 return x >> y;
672 }
673};
674
675} // namespace mlx::core::detail
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
Definition ops.h:8
float fast_exp(float x)
Definition ops.h:19
float fast_erf(float a)
Definition ops.h:47
float fast_erfinv(float a)
Definition ops.h:78
Definition complex.h:34
Definition ops.h:107
T operator()(T x)
Definition ops.h:109
uint8_t operator()(uint8_t x)
Definition ops.h:112
uint64_t operator()(uint64_t x)
Definition ops.h:121
uint16_t operator()(uint16_t x)
Definition ops.h:115
bool operator()(bool x)
Definition ops.h:124
uint32_t operator()(uint32_t x)
Definition ops.h:118
Definition ops.h:445
T operator()(T x, T y)
Definition ops.h:447
Definition ops.h:129
T operator()(T x)
Definition ops.h:131
Definition ops.h:136
T operator()(T x)
Definition ops.h:138
Definition ops.h:143
T operator()(T x)
Definition ops.h:145
Definition ops.h:150
T operator()(T x)
Definition ops.h:152
Definition ops.h:164
T operator()(T y, T x)
Definition ops.h:166
Definition ops.h:157
T operator()(T x)
Definition ops.h:159
Definition ops.h:171
T operator()(T x)
Definition ops.h:173
Definition ops.h:640
T operator()(T x, T y)
Definition ops.h:642
Definition ops.h:647
T operator()(T x, T y)
Definition ops.h:649
Definition ops.h:654
T operator()(T x, T y)
Definition ops.h:656
Definition ops.h:178
uint8_t operator()(uint8_t x)
Definition ops.h:195
T operator()(T x)
Definition ops.h:180
uint32_t operator()(uint32_t x)
Definition ops.h:201
int8_t operator()(int8_t x)
Definition ops.h:183
int16_t operator()(int16_t x)
Definition ops.h:186
bool operator()(bool x)
Definition ops.h:207
uint16_t operator()(uint16_t x)
Definition ops.h:198
uint64_t operator()(uint64_t x)
Definition ops.h:204
int32_t operator()(int32_t x)
Definition ops.h:189
int64_t operator()(int64_t x)
Definition ops.h:192
Definition ops.h:212
complex64_t operator()(complex64_t x)
Definition ops.h:213
Definition ops.h:218
T operator()(T x)
Definition ops.h:220
Definition ops.h:225
T operator()(T x)
Definition ops.h:227
Definition ops.h:452
T operator()(T x, T y)
Definition ops.h:454
Definition ops.h:493
bool operator()(T x, T y)
Definition ops.h:495
Definition ops.h:232
T operator()(T x)
Definition ops.h:234
Definition ops.h:239
T operator()(T x)
Definition ops.h:241
Definition ops.h:246
T operator()(T x)
Definition ops.h:248
complex64_t operator()(complex64_t x)
Definition ops.h:252
Definition ops.h:257
T operator()(T x)
Definition ops.h:259
Definition ops.h:264
T operator()(T x)
Definition ops.h:266
uint32_t operator()(uint32_t x)
Definition ops.h:287
uint16_t operator()(uint16_t x)
Definition ops.h:284
uint8_t operator()(uint8_t x)
Definition ops.h:281
int32_t operator()(int32_t x)
Definition ops.h:275
int64_t operator()(int64_t x)
Definition ops.h:278
bool operator()(bool x)
Definition ops.h:293
int8_t operator()(int8_t x)
Definition ops.h:269
uint64_t operator()(uint64_t x)
Definition ops.h:290
int16_t operator()(int16_t x)
Definition ops.h:272
bool operator()(T x, T y)
Definition ops.h:516
Definition ops.h:507
bool operator()(T x, T y)
Definition ops.h:509
Definition ops.h:298
T operator()(T x)
Definition ops.h:300
Definition ops.h:661
T operator()(T x, T y)
Definition ops.h:663
Definition ops.h:528
bool operator()(T x, T y)
Definition ops.h:530
Definition ops.h:521
bool operator()(T x, T y)
Definition ops.h:523
Definition ops.h:319
T operator()(T x)
Definition ops.h:321
Definition ops.h:326
T operator()(T x)
Definition ops.h:328
Definition ops.h:312
T operator()(T x)
Definition ops.h:314
Definition ops.h:565
T operator()(T x, T y)
Definition ops.h:567
Definition ops.h:305
T operator()(T x)
Definition ops.h:307
Definition ops.h:619
T operator()(T x, T y)
Definition ops.h:621
Definition ops.h:333
T operator()(T x)
Definition ops.h:335
Definition ops.h:626
T operator()(T x, T y)
Definition ops.h:628
Definition ops.h:535
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:537
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:542
Definition ops.h:550
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:557
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:552
Definition ops.h:578
T operator()(T x, T y)
Definition ops.h:580
Definition ops.h:500
bool operator()(T x, T y)
Definition ops.h:502
Definition ops.h:340
T operator()(T x)
Definition ops.h:342
Definition ops.h:585
bool operator()(T x, T y)
Definition ops.h:587
Definition ops.h:592
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:594
std::enable_if_t< std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:599
Definition ops.h:347
T operator()(T x)
Definition ops.h:349
Definition ops.h:459
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:478
std::enable_if_t< std::is_integral_v< T > &!std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:461
std::enable_if_t< std::is_integral_v< T > &std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:468
complex64_t operator()(complex64_t numerator, complex64_t denominator)
Definition ops.h:488
Definition ops.h:668
T operator()(T x, T y)
Definition ops.h:670
Definition ops.h:354
T operator()(T x)
Definition ops.h:356
complex64_t operator()(complex64_t x)
Definition ops.h:360
Definition ops.h:424
T operator()(T x)
Definition ops.h:426
Definition ops.h:633
T operator()(bool condition, T x, T y)
Definition ops.h:635
Definition ops.h:365
T operator()(T x)
Definition ops.h:367
Definition ops.h:373
uint64_t operator()(uint64_t x)
Definition ops.h:387
T operator()(T x)
Definition ops.h:375
uint8_t operator()(uint8_t x)
Definition ops.h:378
uint16_t operator()(uint16_t x)
Definition ops.h:381
complex64_t operator()(complex64_t x)
Definition ops.h:391
uint32_t operator()(uint32_t x)
Definition ops.h:384
Definition ops.h:396
T operator()(T x)
Definition ops.h:398
Definition ops.h:403
T operator()(T x)
Definition ops.h:405
Definition ops.h:417
T operator()(T x)
Definition ops.h:419
Definition ops.h:410
T operator()(T x)
Definition ops.h:412
Definition ops.h:612
T operator()(T x, T y)
Definition ops.h:614
Definition ops.h:431
T operator()(T x)
Definition ops.h:433
Definition ops.h:438
T operator()(T x)
Definition ops.h:440
uint32_t u
Definition bf16.h:17
float f
Definition ops.h:16
int i
Definition ops.h:15