From 630350ad3e4639b2f08462d3f45d0f459daff5a3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 10 Oct 2025 10:05:23 -0700
Subject: [PATCH] Precise sigmoid (#2659)

* bump patch

* Sigmoid matches PyTorch and is more precise on tails
---
 mlx/backend/cpu/unary_ops.h           | 3 ++-
 mlx/backend/cuda/device/unary_ops.cuh | 4 ++--
 mlx/backend/metal/kernels/unary_ops.h | 4 ++--
 python/tests/test_ops.py              | 6 ++++++
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h
index 633230658..255e28a19 100644
--- a/mlx/backend/cpu/unary_ops.h
+++ b/mlx/backend/cpu/unary_ops.h
@@ -77,7 +77,8 @@ struct Real {
 struct Sigmoid {
   template <int N, typename T>
   Simd<T, N> operator()(Simd<T, N> x) {
-    return 1.0f / (1.0f + simd::exp(-x));
+    auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
+    return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
   }
   SINGLE()
 };
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index aebed1e4d..fcd083f2f 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -257,8 +257,8 @@ struct Round {
 struct Sigmoid {
   template <typename T>
   __device__ T operator()(T x) {
-    T y = 1 / (1 + exp(-abs(x)));
-    return (x < 0) ? 1 - y : y;
+    T y = 1 / (1 + exp(abs(x)));
+    return (x < 0) ? y : 1 - y;
   }
 };
 
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index b34bc44ba..44d43cee8 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -309,8 +309,8 @@ struct Round {
 struct Sigmoid {
   template <typename T>
   T operator()(T x) {
-    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
-    return (x < 0) ? 1 - y : y;
+    auto y = 1 / (1 + metal::exp(metal::abs(x)));
+    return (x < 0) ? y : 1 - y;
   }
 };
 
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index e60952aa7..af262ed51 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1041,6 +1041,12 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = 1 / (1 + np.exp(-a, dtype=np.float32))
         self.assertTrue(np.allclose(result, expected))
 
+        # Low precision
+        a = mx.array(-8.0).astype(mx.float16)
+        self.assertNotEqual(mx.sigmoid(a).item(), 0.0)
+        a = mx.array(8.0).astype(mx.float16)
+        self.assertNotEqual(mx.sigmoid(a).item(), 1.0)
+
     def test_allclose(self):
         a = mx.array(1.0)
         b = mx.array(1.0)