full row mask in sdpa consistently gives nan (#2406)

This commit is contained in:
Awni Hannun 2025-07-23 16:37:03 -07:00 committed by GitHub
parent 0f5ce173da
commit e1840853ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 1 deletion

View File

@ -708,7 +708,10 @@ array scaled_dot_product_attention(
}
if (mask.dtype() == bool_) {
scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
mask,
scores,
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
s);
} else {
scores = add(scores, mask, s);
}

View File

@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fully_masked(self):
    """Check that attention with an all-False boolean mask returns only NaNs."""
    kv_len = 8
    # A single False value as the mask, i.e. the "fully masked" case.
    all_false = mx.array(False)
    # Same (head dim, query length) combinations, visited in the same order.
    shape_grid = [(dim, q_len) for dim in (4, 128) for q_len in (1, 8)]
    for dim, q_len in shape_grid:
        queries = mx.random.normal(shape=(1, 4, q_len, dim))
        keys = mx.random.normal(shape=(1, 4, kv_len, dim))
        values = mx.random.normal(shape=(1, 4, kv_len, dim))
        result = mx.fast.scaled_dot_product_attention(
            queries, keys, values, mask=all_false, scale=1
        )
        # Every element of the output is expected to be NaN.
        self.assertTrue(mx.all(mx.isnan(result)))
def test_fast_sdpa_few_query(self):
D = 64
L = 43