mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 22:34:43 +08:00
Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv * fix flaky test * nit
This commit is contained in:
@@ -171,7 +171,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
rtol = 1e-2
|
||||
|
||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
||||
|
||||
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
||||
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
||||
v = mx.random.normal(shape=(1, 32, 128, Dk))
|
||||
@@ -201,6 +200,38 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||
|
||||
def test_fast_sdpa_vector_kv_transposed_head_seq(self):
|
||||
D = 64
|
||||
Nq = 4
|
||||
Nkv = 1
|
||||
scale = 1.0
|
||||
mx.random.seed(0)
|
||||
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
|
||||
|
||||
lengths = [43, 4096]
|
||||
for L in lengths:
|
||||
k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
|
||||
v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
|
||||
k = k.swapaxes(1, 2)
|
||||
v = v.swapaxes(1, 2)
|
||||
masks = [
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
]
|
||||
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
mask=m,
|
||||
)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_fast_sdpa_vector(self):
|
||||
D = 64
|
||||
L = 43
|
||||
@@ -292,7 +323,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
return
|
||||
L = 4096
|
||||
scale = 1.0
|
||||
mx.random.seed(0)
|
||||
|
Reference in New Issue
Block a user