diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp
index 3a2dcd532..8326ba1c3 100644
--- a/mlx/backend/accelerate/softmax.cpp
+++ b/mlx/backend/accelerate/softmax.cpp
@@ -33,8 +33,8 @@ namespace {
  * Note: The implementation below is a general fast exp. There could be faster
  * implementations for numbers strictly < 0.
  */
-inline simd_float16 simd_fast_exp(simd_float16 x) {
-  x *= 1.442695; // multiply with log_2(e)
+inline simd_float16 simd_fast_exp(simd_float16 x_init) {
+  auto x = x_init * 1.442695; // multiply with log_2(e)
   simd_float16 ipart, fpart;
   simd_int16 epart;
   x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
   // bitshifting
   epart = (simd_int(ipart) + 127) << 23;
 
-  return (*(simd_float16*)&epart) * x;
+  // Avoid suppressing NaNs
+  simd_int16 eq = (x_init == x_init);
+  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
 }
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h
index a455dabd2..b36b73bd8 100644
--- a/mlx/backend/metal/kernels/softmax.h
+++ b/mlx/backend/metal/kernels/softmax.h
@@ -32,12 +32,12 @@ template
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
-                                                : Limits<AccT>::finite_min;
+      ld[i] =
+          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
-    local_max[simd_lane_id] = Limits<AccT>::finite_min;
+    local_max[simd_lane_id] = Limits<AccT>::min;
     local_normalizer[simd_lane_id] = 0;
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index bc5f4e8b4..311858670 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.softmax(a, axis=-1, precise=True)
         self.assertTrue(mx.allclose(out_expect, out))
 
+        # All Infs give NaNs
+        for n in [127, 128, 129]:
+            x = mx.full((n,), vals=-float("inf"))
+            self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)
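
Note (illustration only, not part of the diff): a minimal NumPy sketch of the reference behavior the new test asserts. For a row that is entirely -inf, the usual max-subtraction softmax computes -inf - (-inf) = NaN inside the exponential, so every output element should be NaN rather than a finite distribution. The helper name softmax_ref is hypothetical and is not an MLX API.

    # Hypothetical reference, not MLX code.
    import numpy as np

    def softmax_ref(x):
        # Standard max-subtraction softmax: for an all -inf input the shift is
        # -inf - (-inf) = nan, so every output element becomes nan.
        e = np.exp(x - np.max(x))
        return e / np.sum(e)

    print(softmax_ref(np.full(129, -np.inf)))  # all nan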