MLX
Loading...
Searching...
No Matches
mlx
backend
metal
kernels
erf.h
Go to the documentation of this file.
1
// Copyright © 2023 Apple Inc.
2
3
#pragma once
4
5
#include <metal_math>
6
7
/*
 * Approximation to the error function (erf).
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 */
12
/*
 * Single-precision approximation to the error function erf(a).
 *
 * The input range is split at |a| = 0.927734375 (= 475/512):
 *   - |a| >  0.927734375: compute r from a polynomial in |a|, then
 *     erf(a) = copysign(1 - exp(r), a). Maximum error 0.99527 ulp.
 *   - |a| <= 0.927734375: erf(a) = a + a * P(a^2), an odd polynomial.
 *     Maximum error 0.98929 ulp.
 * Each multiply-add is a metal::fma, so it is rounded only once.
 *
 * Declared `inline` because the definition lives in a header: a non-inline
 * definition included into several translation units that are linked
 * together would violate the one-definition rule.
 *
 * @param a  Input value.
 * @return   Approximation of erf(a).
 */
inline float erf(float a) {
  const float t = metal::abs(a); // |a|
  const float s = a * a;         // a^2
  float r;
  if (t > 0.927734375f) {
    // maximum error 0.99527 ulp
    // The leading terms are evaluated as (c0*t + c1) * t^2 + (c2*t + c3),
    // which is algebraically a cubic in t; Horner steps in t follow.
    r = metal::fma(
        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    const float u = metal::fma(
        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    r = metal::fma(r, s, u);
    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
    r = metal::fma(r, t, -t);
    // TODO, replace with expm1 when implemented
    r = 1.0f - metal::exp(r);
    // The value above was computed from |a|; erf is odd, so restore a's sign.
    r = metal::copysign(r, a);
  } else {
    // maximum error 0.98929 ulp
    // Horner evaluation of P(a^2), then a + a * P(a^2) in the final fma.
    r = -5.96761703e-4f;                  // -0x1.38e000p-11
    r = metal::fma(r, s, 4.99119423e-3f);  // 0x1.471a58p-8
    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
    r = metal::fma(r, s, 1.12819925e-1f);  // 0x1.ce1c44p-4
    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
    r = metal::fma(r, s, 1.28379166e-1f);  // 0x1.06eba8p-3
    r = metal::fma(r, a, a);
  }
  return r;
}
42
43
/*
 * Single-precision approximation to the inverse error function erfinv(a).
 *
 * Computes t = log(1 - a^2) and evaluates a polynomial P(t) by Horner's
 * scheme. One of two coefficient sets is chosen by |t|. The result is
 * a * P(t).
 *   - |t| >  6.125: maximum ulp error = 2.35793
 *   - |t| <= 6.125: maximum ulp error = 2.35002
 *
 * Declared `inline` because the definition lives in a header: a non-inline
 * definition included into several translation units that are linked
 * together would violate the one-definition rule.
 *
 * @param a  Input value. The inverse error function is real-valued for
 *           -1 < a < 1. For |a| > 1 the log argument is negative, so the
 *           result is expected to be NaN. This assumes IEEE-conforming
 *           metal::log behaviour, which fast-math may not guarantee.
 * @return   Approximation of erfinv(a).
 */
inline float erfinv(float a) {
  // 1 - a*a evaluated with a single rounding via fma.
  float t = metal::fma(a, 0.0f - a, 1.0f);
  t = metal::log(t);
  float p;
  if (metal::abs(t) > 6.125f) {
    // maximum ulp error = 2.35793
    p = 3.03697567e-10f;                   // 0x1.4deb44p-32
    p = metal::fma(p, t, 2.93243101e-8f);  // 0x1.f7c9aep-26
    p = metal::fma(p, t, 1.22150334e-6f);  // 0x1.47e512p-20
    p = metal::fma(p, t, 2.84108955e-5f);  // 0x1.dca7dep-16
    p = metal::fma(p, t, 3.93552968e-4f);  // 0x1.9cab92p-12
    p = metal::fma(p, t, 3.02698812e-3f);  // 0x1.8cc0dep-9
    p = metal::fma(p, t, 4.83185798e-3f);  // 0x1.3ca920p-8
    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = metal::fma(p, t, 8.40016484e-1f);  // 0x1.ae16a4p-1
  } else {
    // maximum ulp error = 2.35002
    p = 5.43877832e-9f;                    // 0x1.75c000p-28
    p = metal::fma(p, t, 1.43285448e-7f);  // 0x1.33b402p-23
    p = metal::fma(p, t, 1.22774793e-6f);  // 0x1.499232p-20
    p = metal::fma(p, t, 1.12963626e-7f);  // 0x1.e52cd2p-24
    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = metal::fma(p, t, 2.31468678e-3f);  // 0x1.2f6400p-9
    p = metal::fma(p, t, 1.15392581e-2f);  // 0x1.7a1e50p-7
    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = metal::fma(p, t, 8.86226892e-1f);  // 0x1.c5bf88p-1
  }
  return a * p;
}
erfinv
float erfinv(float a)
Definition
erf.h:43
erf
float erf(float a)
Definition
erf.h:12
metal::log
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition
bf16_math.h:234
metal::fma
METAL_FUNC bfloat16_t fma(bfloat16_t x, bfloat16_t y, bfloat16_t z)
Definition
bf16_math.h:234
metal::abs
METAL_FUNC bfloat16_t abs(bfloat16_t x)
Definition
bf16_math.h:234
metal::exp
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition
bf16_math.h:234
u
uint32_t u
Definition
bf16.h:17
Generated by
1.10.0