SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa

* batch sdpa for query
This commit is contained in:
Awni Hannun
2025-03-04 10:59:04 -08:00
committed by GitHub
parent 6bcd6bcf70
commit e613d0eaf0
5 changed files with 159 additions and 45 deletions

View File

@@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fast_sdpa_few_query(self):
    """Compare fused SDPA against the reference for a few query positions.

    Runs ``mx.fast.scaled_dot_product_attention`` and
    ``mlx_primitives_sdpa`` on the same inputs with ``Lq = 4`` query
    positions, ``Nq = 8`` query heads and ``Nkv = 1`` key/value head, and
    asserts the outputs agree within ``atol = rtol = 1e-4``.

    Two key/value sequence lengths are covered: ``L = 43`` and
    ``L = 4096``. Previously a bare ``return`` after the first
    configuration left the ``L = 4096`` case unreachable, so it was never
    exercised; both configurations now run.
    """
    D = 64
    Lq = 4
    Nq = 8
    Nkv = 1
    scale = 1.0

    # Each entry is (key/value length, whether q is drawn as
    # (1, Lq, Nq, D) and then has axes 1 and 2 swapped). The short case
    # builds q through swapaxes, the long case draws it directly as
    # (1, Nq, Lq, D) -- presumably to cover both layouts of q; confirm.
    configs = [
        (43, True),
        (4096, False),
    ]

    for L, swap_q in configs:
        # Re-seed per configuration so each one draws the same inputs
        # regardless of which configurations ran before it.
        mx.random.seed(0)

        if swap_q:
            q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
            q = q.swapaxes(1, 2)
        else:
            q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        masks = [
            # Scalar mask.
            mx.array(True),
            # 1-D mask over the key axis with the last 10 entries False.
            mx.array([True] * (L - 10) + [False] * 10),
            # Random boolean mask drawn with shape (Nq, 1, L).
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            # Random boolean mask drawn as (L, 1, Nq) and transposed.
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
        ]

        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
@unittest.skip("Different head and value dims is not enabled")
def test_fast_sdpa_vector_value_dims(self):
D = 192