Support transposed head/seq for kv (#1950)

* support transposed head/seq for kv

* fix flaky test

* nit
This commit is contained in:
Awni Hannun
2025-03-10 10:53:45 -07:00
committed by GitHub
parent cffceda6ee
commit 3c3e558c60
4 changed files with 84 additions and 45 deletions

View File

@@ -171,7 +171,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
rtol = 1e-2
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
q = mx.random.normal(shape=(1, 32, 1, Dk))
k = mx.random.normal(shape=(1, 32, 32, Dk))
v = mx.random.normal(shape=(1, 32, 128, Dk))
@@ -201,6 +200,38 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
def test_fast_sdpa_vector_kv_transposed_head_seq(self):
    """Compare fast SDPA with the reference when K/V have head/seq axes swapped.

    Keys and values are generated with shape ``(1, L, Nkv, D)`` and then
    passed through ``swapaxes(1, 2)``, so the attention call receives
    arrays of logical shape ``(1, Nkv, L, D)``. For each sequence length
    and each mask variant, the output of
    ``mx.fast.scaled_dot_product_attention`` must match
    ``mlx_primitives_sdpa`` within ``atol = rtol = 1e-4``.
    """
    head_dim = 64
    n_q_heads = 4
    n_kv_heads = 1
    scale = 1.0
    tol = 1e-4

    # Fixed seed so the random inputs (and masks) are reproducible.
    mx.random.seed(0)
    q = 0.5 * mx.random.normal(shape=(1, n_q_heads, 1, head_dim))

    for seq_len in (43, 4096):
        kv_shape = (1, seq_len, n_kv_heads, head_dim)
        # Draw k before v to keep the random stream order unchanged, then
        # swap the sequence and head axes.
        k = (0.5 * mx.random.normal(shape=kv_shape)).swapaxes(1, 2)
        v = (0.5 * mx.random.normal(shape=kv_shape)).swapaxes(1, 2)

        # Mask variants, built in this order because the last two consume
        # random numbers.
        scalar_mask = mx.array(True)
        # 1-D mask over the sequence with the final 10 positions disabled.
        tail_off_mask = mx.array([True] * (seq_len - 10) + [False] * 10)
        dense_mask = mx.random.uniform(shape=(n_q_heads, 1, seq_len)) > 0.2
        # Generated as (L, 1, Nq) and flipped with ``.T``
        # (presumably yielding (Nq, 1, L) -- TODO confirm).
        transposed_mask = (
            mx.random.uniform(shape=(seq_len, 1, n_q_heads)).T > 0.2
        )

        for mask in (scalar_mask, tail_off_mask, dense_mask, transposed_mask):
            expected = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
            actual = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=scale, mask=mask
            )
            self.assertTrue(mx.allclose(expected, actual, atol=tol, rtol=tol))
def test_fast_sdpa_vector(self):
D = 64
L = 43
@@ -292,7 +323,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
return
L = 4096
scale = 1.0
mx.random.seed(0)