MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstring>
#include <limits>
#include <type_traits>
7
9
namespace {
// Positive single-precision infinity.
// NOTE(review): an anonymous namespace in a header gives every translation
// unit its own copy of this constant; kept as-is to avoid changing linkage.
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
13
// Overlay of an int and a float sharing the same storage.
typedef union {
  int i;
  float f;
} IntOrFloat;
18
// Fast single-precision approximation of exp(x).
//
// Returns 0 for -infinity and passes +infinity / NaN through unchanged.
// Otherwise x is rescaled to base 2, clamped to [-80, 80], split into a
// rounded integer part and a fractional remainder, and the result is
// 2**ipart times a degree-6 polynomial in the remainder.
inline float fast_exp(float x) {
  constexpr float kInf = std::numeric_limits<float>::infinity();
  if (x == -kInf) {
    return 0.0f;
  }
  if (x == kInf || std::isnan(x)) {
    return x;
  }

  x *= 1.442695; // multiply with log_2(e)
  x = std::max(-80.f, std::min(x, 80.f));
  const float ipart = std::floor(x + 0.5);
  const float fpart = x - ipart;

  // Horner evaluation of the polynomial in the fractional part.
  float poly = 1.535336188319500e-4f;
  poly = poly * fpart + 1.339887440266574e-3f;
  poly = poly * fpart + 9.618437357674640e-3f;
  poly = poly * fpart + 5.550332471162809e-2f;
  poly = poly * fpart + 2.402264791363012e-1f;
  poly = poly * fpart + 6.931472028550421e-1f;
  poly = poly * fpart + 1.000000000000000f;

  // generate 2**ipart in the floating point representation using integer
  // bitshifting. The bits are copied with memcpy because reading the
  // inactive member of a union is undefined behaviour in C++.
  static_assert(sizeof(float) == sizeof(int32_t), "float must be 32 bits");
  const int32_t bits = (static_cast<int32_t>(ipart) + 127) << 23;
  float scale;
  std::memcpy(&scale, &bits, sizeof(scale));

  return scale * poly;
}
46
// Single-precision approximation of erf(a), using one of two polynomial fits
// chosen by |a| and evaluated with fused multiply-adds.
inline float fast_erf(float a) {
  const float t = std::abs(a);
  const float s = a * a;
  float r;
  if (!(t > 0.927734375f)) {
    // maximum error 0.98929 ulp
    r = -5.96761703e-4f; // -0x1.38e000p-11
    r = std::fma(r, s, 4.99119423e-3f); //  0x1.471a58p-8
    r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
    r = std::fma(r, s, 1.12819925e-1f); //  0x1.ce1c44p-4
    r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
    r = std::fma(r, s, 1.28379166e-1f); //  0x1.06eba8p-3
    return std::fma(r, a, a);
  }
  // maximum error 0.99527 ulp
  r = std::fma(
      -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
  const float u = std::fma(
      -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
  r = std::fma(r, s, u);
  r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
  r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
  r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
  r = std::fma(r, t, -t);
  // TODO, replace with expm1 when implemented
  r = 1.0f - std::exp(r);
  // The fit is evaluated on |a|; restore the sign of the input.
  return std::copysign(r, a);
}
77
// Single-precision approximation of the inverse error function.
// Computes t = log(1 - a*a) and evaluates one of two polynomials in t
// (selected by |t|); the result is a times that polynomial.
inline float fast_erfinv(float a) {
  float t = std::fma(a, 0.0f - a, 1.0f);
  t = std::log(t);
  float p;
  if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    p = 3.03697567e-10f; //  0x1.4deb44p-32
    p = std::fma(p, t, 2.93243101e-8f); //  0x1.f7c9aep-26
    p = std::fma(p, t, 1.22150334e-6f); //  0x1.47e512p-20
    p = std::fma(p, t, 2.84108955e-5f); //  0x1.dca7dep-16
    p = std::fma(p, t, 3.93552968e-4f); //  0x1.9cab92p-12
    p = std::fma(p, t, 3.02698812e-3f); //  0x1.8cc0dep-9
    p = std::fma(p, t, 4.83185798e-3f); //  0x1.3ca920p-8
    p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = std::fma(p, t, 8.40016484e-1f); //  0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    p = 5.43877832e-9f; //  0x1.75c000p-28
    p = std::fma(p, t, 1.43285448e-7f); //  0x1.33b402p-23
    p = std::fma(p, t, 1.22774793e-6f); //  0x1.499232p-20
    p = std::fma(p, t, 1.12963626e-7f); //  0x1.e52cd2p-24
    p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = std::fma(p, t, 2.31468678e-3f); //  0x1.2f6400p-9
    p = std::fma(p, t, 1.15392581e-2f); //  0x1.7a1e50p-7
    p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = std::fma(p, t, 8.86226892e-1f); //  0x1.c5bf88p-1
  }
  return a * p;
}
106
// Element-wise absolute value. Unsigned integers and bool are returned
// unchanged; every other type goes through std::abs.
struct Abs {
  template <typename T>
  T operator()(T x) {
    return std::abs(x);
  }
  uint8_t operator()(uint8_t x) {
    return x;
  }
  uint16_t operator()(uint16_t x) {
    return x;
  }
  uint32_t operator()(uint32_t x) {
    return x;
  }
  uint64_t operator()(uint64_t x) {
    return x;
  }
  bool operator()(bool x) {
    return x;
  }
};
128
// Element-wise inverse cosine; forwards to std::acos.
struct ArcCos {
  template <typename T>
  T operator()(T x) {
    return std::acos(x);
  }
};
135
// Element-wise inverse hyperbolic cosine; forwards to std::acosh.
struct ArcCosh {
  template <typename T>
  T operator()(T x) {
    return std::acosh(x);
  }
};
142
// Element-wise inverse sine; forwards to std::asin.
struct ArcSin {
  template <typename T>
  T operator()(T x) {
    return std::asin(x);
  }
};
149
// Element-wise inverse hyperbolic sine; forwards to std::asinh.
struct ArcSinh {
  template <typename T>
  T operator()(T x) {
    return std::asinh(x);
  }
};
156
// Element-wise inverse tangent; forwards to std::atan.
struct ArcTan {
  template <typename T>
  T operator()(T x) {
    return std::atan(x);
  }
};
163
// Element-wise two-argument inverse tangent; forwards to std::atan2(y, x).
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    return std::atan2(y, x);
  }
};
170
// Element-wise inverse hyperbolic tangent; forwards to std::atanh.
struct ArcTanh {
  template <typename T>
  T operator()(T x) {
    return std::atanh(x);
  }
};
177
// Element-wise ceiling. Fixed-width integer types and bool are returned
// unchanged; every other type goes through std::ceil.
struct Ceil {
  template <typename T>
  T operator()(T x) {
    return std::ceil(x);
  }
  int8_t operator()(int8_t x) {
    return x;
  }
  int16_t operator()(int16_t x) {
    return x;
  }
  int32_t operator()(int32_t x) {
    return x;
  }
  int64_t operator()(int64_t x) {
    return x;
  }
  uint8_t operator()(uint8_t x) {
    return x;
  }
  uint16_t operator()(uint16_t x) {
    return x;
  }
  uint32_t operator()(uint32_t x) {
    return x;
  }
  uint64_t operator()(uint64_t x) {
    return x;
  }
  bool operator()(bool x) {
    return x;
  }
};
211
212struct Conjugate {
214 return std::conj(x);
215 }
216};
217
// Element-wise cosine; forwards to std::cos.
struct Cos {
  template <typename T>
  T operator()(T x) {
    return std::cos(x);
  }
};
224
// Element-wise hyperbolic cosine; forwards to std::cosh.
struct Cosh {
  template <typename T>
  T operator()(T x) {
    return std::cosh(x);
  }
};
231
232struct Erf {
233 template <typename T>
234 T operator()(T x) {
235 return static_cast<T>(fast_erf(static_cast<float>(x)));
236 };
237};
238
239struct ErfInv {
240 template <typename T>
241 T operator()(T x) {
242 return static_cast<T>(fast_erfinv(static_cast<float>(x)));
243 };
244};
245
246struct Exp {
247 template <typename T>
248 T operator()(T x) {
249 return fast_exp(x);
250 };
251
253 return std::exp(x);
254 }
255};
256
// Element-wise exp(x) - 1.
// The call to expm1 is deliberately unqualified, as in the original.
// NOTE(review): presumably this lets overloads for project types be found;
// confirm before qualifying it with std::.
struct Expm1 {
  template <typename T>
  T operator()(T x) {
    return expm1(x);
  }
};
263
// Element-wise floor. Fixed-width integer types and bool are returned
// unchanged; every other type goes through std::floor.
struct Floor {
  template <typename T>
  T operator()(T x) {
    return std::floor(x);
  }
  int8_t operator()(int8_t x) {
    return x;
  }
  int16_t operator()(int16_t x) {
    return x;
  }
  int32_t operator()(int32_t x) {
    return x;
  }
  int64_t operator()(int64_t x) {
    return x;
  }
  uint8_t operator()(uint8_t x) {
    return x;
  }
  uint16_t operator()(uint16_t x) {
    return x;
  }
  uint32_t operator()(uint32_t x) {
    return x;
  }
  uint64_t operator()(uint64_t x) {
    return x;
  }
  bool operator()(bool x) {
    return x;
  }
};
297
// Element-wise natural logarithm; forwards to std::log.
struct Log {
  template <typename T>
  T operator()(T x) {
    return std::log(x);
  }
};
304
// Element-wise base-2 logarithm; forwards to std::log2.
struct Log2 {
  template <typename T>
  T operator()(T x) {
    return std::log2(x);
  }
};
311
// Element-wise base-10 logarithm; forwards to std::log10.
struct Log10 {
  template <typename T>
  T operator()(T x) {
    return std::log10(x);
  }
};
318
// Element-wise log(1 + x).
// The call to log1p is deliberately unqualified, as in the original.
// NOTE(review): presumably this lets overloads for project types be found;
// confirm before qualifying it with std::.
struct Log1p {
  template <typename T>
  T operator()(T x) {
    return log1p(x);
  }
};
325
// Element-wise logical negation; the boolean result is converted to T.
struct LogicalNot {
  template <typename T>
  T operator()(T x) {
    return !x;
  }
};
332
// Element-wise arithmetic negation.
struct Negative {
  template <typename T>
  T operator()(T x) {
    return -x;
  }
};
339
340struct Round {
341 template <typename T>
342 T operator()(T x) {
343 return std::rint(x);
344 }
345
347 return {std::rint(x.real()), std::rint(x.imag())};
348 }
349};
350
// Element-wise logistic sigmoid, 1 / (1 + exp(-x)), with the exponential
// computed by fast_exp.
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    const auto one = static_cast<T>(1.0);
    const auto denom = one + fast_exp(-x);
    return one / denom;
  }
};
358
// Element-wise sign: -1, 0 or 1 for generic types. Unsigned integers can
// never be negative, so they yield 1 for non-zero input and 0 otherwise.
struct Sign {
  template <typename T>
  T operator()(T x) {
    return (x > T(0)) - (x < T(0));
  }
  uint8_t operator()(uint8_t x) {
    return x != 0;
  }
  uint16_t operator()(uint16_t x) {
    return x != 0;
  }
  uint32_t operator()(uint32_t x) {
    return x != 0;
  }
  uint64_t operator()(uint64_t x) {
    return x != 0;
  }
};
377
// Element-wise sine; forwards to std::sin.
struct Sin {
  template <typename T>
  T operator()(T x) {
    return std::sin(x);
  }
};
384
// Element-wise hyperbolic sine; forwards to std::sinh.
struct Sinh {
  template <typename T>
  T operator()(T x) {
    return std::sinh(x);
  }
};
391
// Element-wise square, computed as x * x.
struct Square {
  template <typename T>
  T operator()(T x) {
    return x * x;
  }
};
398
// Element-wise square root; forwards to std::sqrt.
struct Sqrt {
  template <typename T>
  T operator()(T x) {
    return std::sqrt(x);
  }
};
405
// Element-wise reciprocal square root, computed as 1 / sqrt(x).
struct Rsqrt {
  template <typename T>
  T operator()(T x) {
    const auto one = static_cast<T>(1.0);
    return one / std::sqrt(x);
  }
};
412
// Element-wise tangent; forwards to std::tan.
struct Tan {
  template <typename T>
  T operator()(T x) {
    return std::tan(x);
  }
};
419
// Element-wise hyperbolic tangent; forwards to std::tanh.
struct Tanh {
  template <typename T>
  T operator()(T x) {
    return std::tanh(x);
  }
};
426
// Element-wise addition.
struct Add {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};
433
// Element-wise division using the type's own operator/ (so integer types
// truncate).
struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
};
440
441struct Remainder {
442 template <typename T>
443 std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
444 T numerator,
445 T denominator) {
446 return numerator % denominator;
447 }
448
449 template <typename T>
450 std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
451 T numerator,
452 T denominator) {
453 auto r = numerator % denominator;
454 if (r != 0 && (r < 0 != denominator < 0))
455 r += denominator;
456 return r;
457 }
458
459 template <typename T>
460 std::enable_if_t<!std::is_integral_v<T>, T> operator()(
461 T numerator,
462 T denominator) {
463 auto r = std::fmod(numerator, denominator);
464 if (r != 0 && (r < 0 != denominator < 0)) {
465 r += denominator;
466 }
467 return r;
468 }
469
471 return numerator % denominator;
472 }
473};
474
// Element-wise equality comparison.
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};
481
// Element-wise equality that additionally treats two NaNs as equal.
struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    if (x == y) {
      return true;
    }
    return std::isnan(x) && std::isnan(y);
  }
};
488
// Element-wise x > y.
struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    return x > y;
  }
};
495
// Element-wise x >= y.
struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};
502
// Element-wise x < y.
struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    return x < y;
  }
};
509
// Element-wise x <= y.
struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x <= y;
  }
};
516
// Element-wise maximum. For non-integral types a NaN in either argument is
// returned: x is checked explicitly, and a NaN y falls out of the comparison
// because x > y is then false.
struct Maximum {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    return (x > y) ? x : y;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x)) {
      return x;
    }
    return (x > y) ? x : y;
  }
};
531
// Element-wise minimum. For non-integral types a NaN in either argument is
// returned: x is checked explicitly, and a NaN y falls out of the comparison
// because x < y is then false.
struct Minimum {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
    return x < y ? x : y;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
    if (std::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }
};
546
547struct LogAddExp {
548 template <typename T>
549 T operator()(T x, T y) {
550 constexpr float inf = std::numeric_limits<float>::infinity();
551 auto maxval = Maximum()(x, y);
552 auto minval = Minimum()(x, y);
553 return (minval == -inf || maxval == inf)
554 ? maxval
555 : static_cast<decltype(x)>(
556 maxval + std::log1p(fast_exp(minval - maxval)));
557 };
558};
559
// Element-wise multiplication.
struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    return x * y;
  }
};
566
// Element-wise inequality comparison.
struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x != y;
  }
};
573
// Element-wise exponentiation.
struct Power {
  // Non-integral types defer to std::pow.
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
    return std::pow(base, exp);
  }

  // Integral types use exponentiation by squaring.
  //
  // A negative exponent previously never terminated: right-shifting a
  // negative value does not reach zero. It now returns the truncated value
  // of 1 / base**|exp|, which is 1 for base 1, +/-1 for base -1 (by the
  // parity of exp) and 0 otherwise. For base 0 the value is mathematically
  // undefined and 0 is returned.
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
    if constexpr (std::is_signed_v<T>) {
      if (exp < 0) {
        if (base == 1) {
          return 1;
        }
        if (base == -1) {
          return (exp & 1) ? -1 : 1;
        }
        return 0;
      }
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only while more bits remain, so the unused final squaring
      // cannot overflow.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
};
593
// Element-wise subtraction.
struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    return x - y;
  }
};
600
// Element-wise logical AND; the boolean result is converted to T.
struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  }
};
607
// Element-wise logical OR; the boolean result is converted to T.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  }
};
614
// Element-wise selection: x where condition is true, y otherwise.
struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    if (condition) {
      return x;
    }
    return y;
  }
};
621
// Element-wise bitwise AND.
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  }
};
628
// Element-wise bitwise OR.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    return x | y;
  }
};
635
// Element-wise bitwise XOR.
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  }
};
642
// Element-wise left shift of x by y bits.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    return x << y;
  }
};
649
// Element-wise right shift of x by y bits.
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  }
};
656
657} // namespace mlx::core::detail
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
Definition ops.h:8
float fast_exp(float x)
Definition ops.h:19
float fast_erf(float a)
Definition ops.h:47
float fast_erfinv(float a)
Definition ops.h:78
Definition complex.h:34
Definition ops.h:107
T operator()(T x)
Definition ops.h:109
uint8_t operator()(uint8_t x)
Definition ops.h:112
uint64_t operator()(uint64_t x)
Definition ops.h:121
uint16_t operator()(uint16_t x)
Definition ops.h:115
bool operator()(bool x)
Definition ops.h:124
uint32_t operator()(uint32_t x)
Definition ops.h:118
Definition ops.h:427
T operator()(T x, T y)
Definition ops.h:429
Definition ops.h:129
T operator()(T x)
Definition ops.h:131
Definition ops.h:136
T operator()(T x)
Definition ops.h:138
Definition ops.h:143
T operator()(T x)
Definition ops.h:145
Definition ops.h:150
T operator()(T x)
Definition ops.h:152
Definition ops.h:164
T operator()(T y, T x)
Definition ops.h:166
Definition ops.h:157
T operator()(T x)
Definition ops.h:159
Definition ops.h:171
T operator()(T x)
Definition ops.h:173
Definition ops.h:622
T operator()(T x, T y)
Definition ops.h:624
Definition ops.h:629
T operator()(T x, T y)
Definition ops.h:631
Definition ops.h:636
T operator()(T x, T y)
Definition ops.h:638
Definition ops.h:178
uint8_t operator()(uint8_t x)
Definition ops.h:195
T operator()(T x)
Definition ops.h:180
uint32_t operator()(uint32_t x)
Definition ops.h:201
int8_t operator()(int8_t x)
Definition ops.h:183
int16_t operator()(int16_t x)
Definition ops.h:186
bool operator()(bool x)
Definition ops.h:207
uint16_t operator()(uint16_t x)
Definition ops.h:198
uint64_t operator()(uint64_t x)
Definition ops.h:204
int32_t operator()(int32_t x)
Definition ops.h:189
int64_t operator()(int64_t x)
Definition ops.h:192
Definition ops.h:212
complex64_t operator()(complex64_t x)
Definition ops.h:213
Definition ops.h:218
T operator()(T x)
Definition ops.h:220
Definition ops.h:225
T operator()(T x)
Definition ops.h:227
Definition ops.h:434
T operator()(T x, T y)
Definition ops.h:436
Definition ops.h:475
bool operator()(T x, T y)
Definition ops.h:477
Definition ops.h:232
T operator()(T x)
Definition ops.h:234
Definition ops.h:239
T operator()(T x)
Definition ops.h:241
Definition ops.h:246
T operator()(T x)
Definition ops.h:248
complex64_t operator()(complex64_t x)
Definition ops.h:252
Definition ops.h:257
T operator()(T x)
Definition ops.h:259
Definition ops.h:264
T operator()(T x)
Definition ops.h:266
uint32_t operator()(uint32_t x)
Definition ops.h:287
uint16_t operator()(uint16_t x)
Definition ops.h:284
uint8_t operator()(uint8_t x)
Definition ops.h:281
int32_t operator()(int32_t x)
Definition ops.h:275
int64_t operator()(int64_t x)
Definition ops.h:278
bool operator()(bool x)
Definition ops.h:293
int8_t operator()(int8_t x)
Definition ops.h:269
uint64_t operator()(uint64_t x)
Definition ops.h:290
int16_t operator()(int16_t x)
Definition ops.h:272
bool operator()(T x, T y)
Definition ops.h:498
Definition ops.h:489
bool operator()(T x, T y)
Definition ops.h:491
Definition ops.h:643
T operator()(T x, T y)
Definition ops.h:645
Definition ops.h:510
bool operator()(T x, T y)
Definition ops.h:512
Definition ops.h:503
bool operator()(T x, T y)
Definition ops.h:505
Definition ops.h:312
T operator()(T x)
Definition ops.h:314
Definition ops.h:319
T operator()(T x)
Definition ops.h:321
Definition ops.h:305
T operator()(T x)
Definition ops.h:307
Definition ops.h:547
T operator()(T x, T y)
Definition ops.h:549
Definition ops.h:298
T operator()(T x)
Definition ops.h:300
Definition ops.h:601
T operator()(T x, T y)
Definition ops.h:603
Definition ops.h:326
T operator()(T x)
Definition ops.h:328
Definition ops.h:608
T operator()(T x, T y)
Definition ops.h:610
Definition ops.h:517
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:519
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:524
Definition ops.h:532
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:539
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:534
Definition ops.h:560
T operator()(T x, T y)
Definition ops.h:562
Definition ops.h:482
bool operator()(T x, T y)
Definition ops.h:484
Definition ops.h:333
T operator()(T x)
Definition ops.h:335
Definition ops.h:567
bool operator()(T x, T y)
Definition ops.h:569
Definition ops.h:574
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:576
std::enable_if_t< std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:581
Definition ops.h:441
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:460
std::enable_if_t< std::is_integral_v< T > &!std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:443
std::enable_if_t< std::is_integral_v< T > &std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:450
complex64_t operator()(complex64_t numerator, complex64_t denominator)
Definition ops.h:470
Definition ops.h:650
T operator()(T x, T y)
Definition ops.h:652
Definition ops.h:340
T operator()(T x)
Definition ops.h:342
complex64_t operator()(complex64_t x)
Definition ops.h:346
Definition ops.h:406
T operator()(T x)
Definition ops.h:408
Definition ops.h:615
T operator()(bool condition, T x, T y)
Definition ops.h:617
Definition ops.h:351
T operator()(T x)
Definition ops.h:353
Definition ops.h:359
uint64_t operator()(uint64_t x)
Definition ops.h:373
T operator()(T x)
Definition ops.h:361
uint8_t operator()(uint8_t x)
Definition ops.h:364
uint16_t operator()(uint16_t x)
Definition ops.h:367
uint32_t operator()(uint32_t x)
Definition ops.h:370
Definition ops.h:378
T operator()(T x)
Definition ops.h:380
Definition ops.h:385
T operator()(T x)
Definition ops.h:387
Definition ops.h:399
T operator()(T x)
Definition ops.h:401
Definition ops.h:392
T operator()(T x)
Definition ops.h:394
Definition ops.h:594
T operator()(T x, T y)
Definition ops.h:596
Definition ops.h:413
T operator()(T x)
Definition ops.h:415
Definition ops.h:420
T operator()(T x)
Definition ops.h:422
uint32_t u
Definition bf16.h:17
float f
Definition ops.h:16
int i
Definition ops.h:15