MLX
 
Loading...
Searching...
No Matches
expm1f.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
6
7// Original license copied below:
8// Copyright (c) 2015-2023 Norbert Juffa
9// All rights reserved.
10//
11// Redistribution and use in source and binary forms, with or without
12// modification, are permitted provided that the following conditions
13// are met:
14//
15// 1. Redistributions of source code must retain the above copyright
16// notice, this list of conditions and the following disclaimer.
17//
18// 2. Redistributions in binary form must reproduce the above copyright
19// notice, this list of conditions and the following disclaimer in the
20// documentation and/or other materials provided with the distribution.
21//
22// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 (bound
   stated by the original author of this algorithm).

   i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
   Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
   With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
   when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.

   The reduction f = a - i*log(2) uses a two-constant (Cody-Waite) split of
   log(2) into log2_hi = 0x1.62e400p-1 and log2_lo = 0x1.7f7d1cp-20. Using
   log2_hi alone leaves an absolute error of about |i| * 1.43e-6 in f, which
   shows up as many ulps of error in the result whenever i != 0.

   Scale factor b (should be a power of 2) is only applied if i < 0 or i > 1:
   the result is then b * expm1(a). For i == 0 or i == 1 it is expm1(a).
   A zero argument is passed through with its sign (-0 -> -0).

   "Unchecked": no overflow/underflow handling is done here. Callers must
   handle arguments whose result overflows or underflows (large |a|).
*/
float expm1f_scaled_unchecked(float a, float b) {
  float f, j, r, s, t, u, v, x, y;
  int i;

  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
  // Adding and then subtracting 1.5 * 2**23 rounds to the nearest integer.
  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
  j = j - 12582912.0f; // 0x1.8p23
  i = (int)j;
  // f = a - j * log(2), with log(2) split into head + tail (Cody-Waite).
  f = fma(j, -6.93145752e-1f, a); // -0x1.62e400p-1  (log2_hi)
  f = fma(j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 (log2_lo)

  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
  s = f * f;
  if (a == 0.0f)
    s = a; // ensure -0 is passed through
  // err = 0.997458  ulp1 = 11081805
  r = 1.97350979e-4f; // 0x1.9de000p-13
  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
  // v = f + f^2 * poly(f) = expm1(f); for i == 1 the 0.5 offset is folded in
  // here so that the i == 1 result below is simply v + v.
  u = (j == 1) ? (f + 0.5f) : f;
  v = fma(r, s, u);
  // General case: result = 2 * (v * t + (t - s)) with s = 0.5 * b and
  // t = s * 2**i, i.e. b * expm1(a). (t - s) is carried as y plus a
  // correction term x to keep the subtraction accurate.
  s = 0.5f * b;
  t = ldexp(s, i);
  y = t - s;
  x = (t - y) - s; // double-float canonicalization of difference
  r = fma(v, t, x) + y;
  r = r + r;
  // i == 0 and i == 1 bypass the scaling step for best accuracy (b unused).
  if (j == 0)
    r = v;
  if (j == 1)
    r = v + v;
  return r;
}
78
/* Compute exponential base e minus 1, i.e. exp(a) - 1. max ulp err = 0.99746

   In-range arguments go to expm1f_scaled_unchecked with a unit scale factor.
   Arguments with |a - 1| > 88 (a > 89 or a < -87) take a saturating path:
   with p = 2**a, fma(p, p, -1) evaluates to +inf when p*p overflows (large
   positive a) and to -1 when p*p underflows to zero (large negative a).
   A NaN argument fails the range comparison and takes the in-range path.
*/
float expm1f(float a) {
  if (abs(a - 1.0f) > 88.0f) {
    // Severe overflow / underflow: saturate instead of evaluating the kernel.
    const float p = pow(2, a);
    return fma(p, p, -1.0f);
  }
  return expm1f_scaled_unchecked(a, 1.0f);
}
float expm1f(float a)
Definition expm1f.h:80
float expm1f_scaled_unchecked(float a, float b)
Definition expm1f.h:43