Add float mask to sdpa vector (#2068)

This commit is contained in:
Angelos Katharopoulos
2025-04-11 17:29:40 -07:00
committed by GitHub
parent 68d1b3256b
commit c4189a38e4
5 changed files with 94 additions and 50 deletions

View File

@@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
mx.random.uniform(shape=(Nq, 1, L)),
mx.random.uniform(shape=(L, 1, Nq)).T,
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
"causal",
]
for m in masks:
@@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
mx.random.uniform(shape=(Nq, 1, L)),
mx.random.uniform(shape=(L, 1, Nq)).T,
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
"causal",
]
for m in masks: