mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Swap -inf for finite_minimum value (#2029)
This commit is contained in:

committed by
GitHub

parent
90823d2938
commit
ec2854b13a
@@ -572,6 +572,34 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_nan_bug(self):
|
||||
N = 128
|
||||
q_shape = (1, 1, N, 128)
|
||||
kv_shape = (1, 1, N, 128)
|
||||
q = mx.random.uniform(shape=q_shape)
|
||||
k = mx.random.uniform(shape=kv_shape)
|
||||
v = mx.random.uniform(shape=kv_shape)
|
||||
|
||||
# Make boolean window causal mask
|
||||
linds = rinds = mx.arange(N)
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds >= rinds
|
||||
mask = mask & (linds <= rinds + 111)
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
|
||||
self.assertFalse(mx.isnan(out).any().item())
|
||||
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
|
||||
|
||||
# And an additive one
|
||||
mask = mx.log(mask)
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
|
||||
self.assertFalse(mx.isnan(out).any().item())
|
||||
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user