diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h index 3730aac5e..f9fc8317a 100644 --- a/mlx/backend/cpu/simd/math.h +++ b/mlx/backend/cpu/simd/math.h @@ -186,7 +186,7 @@ Simd erfinv(Simd a_) { return a * rhs(t); } } else { - return a * select(t > thresh, lhs(t), rhs(t)); + return a * select(abs(t) > thresh, lhs(t), rhs(t)); } } diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index ff9c7fa30..2951188e9 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -4,6 +4,7 @@ set(BASE_HEADERS bf16_math.h complex.h defines.h + erf.h expm1f.h utils.h) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 43f1b3335..f5078afc0 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -898,6 +898,10 @@ class TestOps(mlx_tests.MLXTestCase): ).astype(np.float32) self.assertTrue(np.allclose(mx.erfinv(x), expected, equal_nan=True)) + result = mx.erfinv(mx.array([0.9999999403953552] * 8)) + expected = mx.array([3.8325066566467285] * 8) + self.assertTrue(mx.allclose(result, expected)) + def test_sin(self): a = mx.array( [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi]