mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa * batch sdpa for query
This commit is contained in:
@@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_fast_sdpa_few_query(self):
    """Check fused SDPA against the reference for a few query positions.

    Two scenarios are exercised, both with ``Lq = 4`` query positions,
    ``Nq = 8`` query heads and a single key/value head:

    * ``L = 43`` keys, with ``q`` created as ``(1, Lq, Nq, D)`` and then
      moved to ``(1, Nq, Lq, D)`` via ``swapaxes(1, 2)``.
    * ``L = 4096`` keys, with ``q`` created directly as ``(1, Nq, Lq, D)``.

    For each scenario the output of ``mx.fast.scaled_dot_product_attention``
    is compared with ``mlx_primitives_sdpa`` under four boolean masks.
    """
    D = 64
    Lq = 4
    Nq = 8
    Nkv = 1
    scale = 1.0

    # (number of keys, build q as (1, Lq, Nq, D) and swap axes 1 and 2?)
    # NOTE(review): the swapped variant presumably exists to feed the
    # kernel a query that is a transposed view -- confirm with the kernel
    # authors before changing it.
    scenarios = [
        (43, True),
        (4096, False),
    ]

    for L, swap_query_axes in scenarios:
        # Re-seed per scenario so each one draws the same values no matter
        # which scenarios ran before it.
        mx.random.seed(0)

        # Draw order (q, k, v, then masks) is kept fixed on purpose: the
        # random stream determines the exact inputs.
        if swap_query_axes:
            q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
            q = q.swapaxes(1, 2)
        else:
            q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        masks = [
            # Scalar mask that keeps everything.
            mx.array(True),
            # Length-L vector that masks out the last 10 keys.
            mx.array([True] * (L - 10) + [False] * 10),
            # Random mask created directly with shape (Nq, 1, L).
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            # Random mask created as (L, 1, Nq) and transposed.
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
        ]

        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
@unittest.skip("Different head and value dims is not enabled")
|
||||
def test_fast_sdpa_vector_value_dims(self):
|
||||
D = 192
|
||||
|
Reference in New Issue
Block a user