diff --git a/mlx/backend/metal/kernels/expm1f.h b/mlx/backend/metal/kernels/expm1f.h index b649dd99a..68224e179 100644 --- a/mlx/backend/metal/kernels/expm1f.h +++ b/mlx/backend/metal/kernels/expm1f.h @@ -83,6 +83,7 @@ float expm1f(float a) { r = expm1f_scaled_unchecked(a, 1.0f); /* handle severe overflow and underflow */ if (abs(a - 1.0f) > 88.0f) { + r = pow(2, a); r = fma(r, r, -1.0f); } return r; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4b0abebdb..e8b1cfc38 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -845,7 +845,7 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) def test_expm1(self): - a = mx.array([0, 0.5, -0.5, 5]) + a = mx.array([-88, -87, 0, 0.5, -0.5, 5, 87, 88, 89, 90]) result = mx.expm1(a) expected = np.expm1(a, dtype=np.float32)