fix batched vector sdpa (#2152)

This commit is contained in:
Awni Hannun
2025-05-05 13:13:03 -07:00
committed by GitHub
parent 825124af8f
commit af705590ac
3 changed files with 105 additions and 50 deletions

View File

@@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_vector_batched(self):
D = 64
q = mx.random.normal(shape=(2, 1, 3, D))
k = mx.random.normal(shape=(2, 1, 3, D))
v = mx.random.normal(shape=(2, 1, 3, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
ref = mlx_ref_attn(q, k, v)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
q = mx.random.normal(shape=(2, 4, 3, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
ref = mlx_ref_attn(q, k, v)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
ref = mlx_ref_attn(q, k, v)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
ref = mlx_ref_attn(q, k, v)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
q = mx.random.normal(shape=(2, 4, 3, D))
k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
v = mx.random.normal(shape=(2, 2, 3, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
ref = mlx_ref_attn(q, k, v)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
q = mx.random.normal(shape=(2, 4, 3, D))
k = mx.random.normal(shape=(2, 1, 3, D))
v = mx.random.normal(shape=(2, 1, 3, D))
mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
ref = mlx_ref_attn(q, k, v, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
class TestSDPA(mlx_tests.MLXTestCase):
@property