mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Add boolean mask support in vector SDPA (#1757)
This commit is contained in:
@@ -10,7 +10,10 @@ import numpy as np
|
||||
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
|
||||
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
p += mask
|
||||
if mask.dtype == mx.bool_:
|
||||
p = mx.where(mask, p, mx.finfo(mx.float32).min)
|
||||
else:
|
||||
p += mask
|
||||
scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
|
||||
return scores @ v
|
||||
|
||||
@@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||
|
||||
def test_fast_sdpa_vector(self):
|
||||
D = 64
|
||||
L = 43
|
||||
Nq = 4
|
||||
Nkv = 1
|
||||
scale = 1.0
|
||||
mx.random.seed(0)
|
||||
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
|
||||
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.fast.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
mask=mx.full((Nq, 2, L), False),
|
||||
)
|
||||
|
||||
masks = [
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
mask=m,
|
||||
)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
L = 4096
|
||||
scale = 1.0
|
||||
mx.random.seed(0)
|
||||
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
|
||||
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
|
||||
masks = [
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
mask=m,
|
||||
)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user