diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
index b672c34e6..2e27ea06f 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -229,7 +229,7 @@ template <
   // Init to -Inf
   STEEL_PRAGMA_UNROLL
   for (short i = 0; i < kRowsPT; ++i) {
-    max_score[i] = Limits<AccumType>::min;
+    max_score[i] = Limits<AccumType>::finite_min;
   }
 
   int kb_lim = params->NK;
@@ -273,7 +273,7 @@ template <
     if (!align_K && kb == (params->NK_aligned)) {
       using stile_t = decltype(Stile);
       using selem_t = typename stile_t::elem_type;
-      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+      constexpr auto neg_inf = Limits<selem_t>::finite_min;
 
       STEEL_PRAGMA_UNROLL
       for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -294,7 +294,7 @@ template <
     if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
       using stile_t = decltype(Stile);
       using selem_t = typename stile_t::elem_type;
-      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+      constexpr auto neg_inf = Limits<selem_t>::finite_min;
 
       STEEL_PRAGMA_UNROLL
       for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -317,7 +317,7 @@ template <
     if (has_mask) {
       using stile_t = decltype(Stile);
       using selem_t = typename stile_t::elem_type;
-      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+      constexpr auto neg_inf = Limits<selem_t>::finite_min;
 
       constexpr bool is_bool = is_same_v<MaskType, bool>;
       using melem_t = typename metal::conditional_t<is_bool, selem_t, MaskType>;
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index f3b6656f8..b767fbc8f 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -572,6 +572,34 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_nan_bug(self):
+        N = 128
+        q_shape = (1, 1, N, 128)
+        kv_shape = (1, 1, N, 128)
+        q = mx.random.uniform(shape=q_shape)
+        k = mx.random.uniform(shape=kv_shape)
+        v = mx.random.uniform(shape=kv_shape)
+
+        # Make a boolean windowed causal mask
+        linds = rinds = mx.arange(N)
+        linds = linds[:, None]
+        rinds = rinds[None]
+        mask = linds >= rinds
+        mask = mask & (linds <= rinds + 111)
+
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
+        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
+        self.assertFalse(mx.isnan(out).any().item())
+        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
+
+        # And an additive one
+        mask = mx.log(mask)
+
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
+        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
+        self.assertFalse(mx.isnan(out).any().item())
+        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)
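
Why finite_min instead of -inf: when every score in a row is masked with an
infinite value, the running row maximum is itself -inf, and the streaming
softmax's rescaling step computes exp(-inf - (-inf)) = exp(nan), which poisons
the whole row. A large finite sentinel keeps that subtraction well defined.
Below is a minimal Python sketch of the failure mode, outside the patch; the
4-element row and the hardcoded finite_min literal are illustrative
assumptions, not the kernel's actual tile arithmetic:

    import mlx.core as mx

    # A score row in which every position is masked out with -inf.
    scores = mx.full((4,), float("-inf"))

    # A max-subtracting softmax shifts by the row max. Here the max is -inf,
    # so the shift computes (-inf) - (-inf) = nan and the row becomes NaN.
    print(mx.softmax(scores, axis=-1))

    # With a large finite sentinel (what Limits<T>::finite_min provides),
    # every shifted score is 0 and the row softmaxes to a harmless uniform
    # distribution instead of NaN.
    finite_min = -3.4028235e38  # lowest finite float32 value
    print(mx.softmax(mx.full((4,), finite_min), axis=-1))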