Add boolean mask support in vector SDPA (#1757)

This commit is contained in:
Awni Hannun
2025-01-07 20:24:53 -08:00
committed by GitHub
parent 516ded618b
commit d1766f2c70
5 changed files with 226 additions and 74 deletions

View File

@@ -10,7 +10,10 @@ import numpy as np
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
p = (q * scale) @ k.transpose(0, 1, 3, 2)
if mask is not None:
p += mask
if mask.dtype == mx.bool_:
p = mx.where(mask, p, mx.finfo(mx.float32).min)
else:
p += mask
scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
return scores @ v
@@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
def test_fast_sdpa_vector(self):
D = 64
L = 43
Nq = 4
Nkv = 1
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=mx.full((Nq, 2, L), False),
)
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
L = 4096
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
if __name__ == "__main__":
unittest.main(failfast=True)