From e1840853ce051b2621e82ca461f74f62bfdd9b9c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 23 Jul 2025 16:37:03 -0700 Subject: [PATCH] full row mask in sdpa consistently gives nan (#2406) --- mlx/fast.cpp | 5 ++++- python/tests/test_fast_sdpa.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 210c7f729..b8d622253 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -708,7 +708,10 @@ array scaled_dot_product_attention( } if (mask.dtype() == bool_) { scores = where( - mask, scores, array(finfo(scores.dtype()).min, scores.dtype())); + mask, + scores, + array(-std::numeric_limits<float>::infinity(), scores.dtype()), + s); } else { scores = add(scores, mask, s); } diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index a929e91cf..8f9cb34cf 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_fully_masked(self): + Lkv = 8 + mask = mx.array(False) + for D in [4, 128]: + for Lq in [1, 8]: + q = mx.random.normal(shape=(1, 4, Lq, D)) + k = mx.random.normal(shape=(1, 4, Lkv, D)) + v = mx.random.normal(shape=(1, 4, Lkv, D)) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1) + self.assertTrue(mx.all(mx.isnan(out))) + def test_fast_sdpa_few_query(self): D = 64 L = 43