promote mask when needed (#1998)

This commit is contained in:
Awni Hannun
2025-03-23 19:58:28 -07:00
committed by GitHub
parent f018e248cd
commit a84cc0123f
2 changed files with 18 additions and 0 deletions

View File

@@ -543,6 +543,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
    """Check fused SDPA against the reference when given a bfloat16 mask.

    The mask is a scalar ``bfloat16`` array, while ``q``, ``k`` and ``v``
    are drawn with ``mx.random.normal`` using its default dtype
    (presumably float32 -- confirm), so the mask dtype differs from the
    inputs. The fused kernel's output must match the reference
    ``mlx_primitives_sdpa`` result within ``1e-4`` absolute/relative
    tolerance.
    """
    n_q_heads, n_kv_heads = 4, 1
    seq_len, head_dim = 256, 64
    softmax_scale = 1.0
    bf16_mask = mx.array(2.0, mx.bfloat16)

    # Seed first, then draw q, k, v in this fixed order so the inputs are
    # reproducible from run to run.
    mx.random.seed(0)
    q_shape = (1, n_q_heads, seq_len, head_dim)
    kv_shape = (1, n_kv_heads, seq_len, head_dim)
    q = 0.5 * mx.random.normal(shape=q_shape)
    k = 0.5 * mx.random.normal(shape=kv_shape)
    v = 0.5 * mx.random.normal(shape=kv_shape)

    expected = mlx_primitives_sdpa(q, k, v, softmax_scale, mask=bf16_mask)
    actual = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=softmax_scale, mask=bf16_mask
    )
    self.assertTrue(mx.allclose(expected, actual, atol=1e-4, rtol=1e-4))
# Script entry point: when this file is executed directly, run its tests
# and stop the run at the first error or failure (failfast).
if __name__ == "__main__":
    unittest.main(failfast=True)