Swap -inf for finite_minimum value (#2029)

This commit is contained in:
Angelos Katharopoulos
2025-03-30 21:55:04 -07:00
committed by GitHub
parent 90823d2938
commit ec2854b13a
2 changed files with 32 additions and 4 deletions

View File

@@ -572,6 +572,34 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_nan_bug(self):
    """Regression test: fused SDPA must stay NaN-free under a windowed causal mask.

    The same window is applied twice: first as a boolean mask, then as an
    additive mask obtained by taking ``mx.log`` of the boolean mask. In both
    cases the fused kernel output must contain no NaNs and must agree with
    the ``mlx_ref_attn`` reference to within 1e-4 (max absolute error).
    """
    seq_len = 128
    # Queries, keys and values all share one shape in this test.
    qkv_shape = (1, 1, seq_len, 128)
    q = mx.random.uniform(shape=qkv_shape)
    k = mx.random.uniform(shape=qkv_shape)
    v = mx.random.uniform(shape=qkv_shape)

    def check_against_reference(mask):
        # Compare the fused kernel with the reference for the given mask.
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
        self.assertFalse(mx.isnan(out).any().item())
        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)

    # Boolean window causal mask: a query position may see key positions
    # that are not ahead of it and at most 111 positions behind it.
    positions = mx.arange(seq_len)
    query_idx = positions[:, None]
    key_idx = positions[None]
    window_mask = (query_idx >= key_idx) & (query_idx <= key_idx + 111)
    check_against_reference(window_mask)

    # Same window expressed as an additive mask.
    check_against_reference(mx.log(window_mask))
if __name__ == "__main__":
    # Run the suite when executed as a script; failfast stops the run at
    # the first failing or erroring test instead of continuing.
    unittest.main(failfast=True)