Fix causal mask sdpa vec (#2053)

* fix sdpa vector causal mask

* test
This commit is contained in:
Awni Hannun
2025-04-08 09:11:23 -07:00
committed by GitHub
parent 08a1bf3f10
commit 00794c42bc
2 changed files with 6 additions and 1 deletions

View File

@@ -347,6 +347,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
masks = [
None,
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
@@ -392,7 +393,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa_few_query(self):
D = 64
L = 43
Lq = 4
Lq = 8
Nq = 8
Nkv = 1
scale = 1.0
@@ -403,6 +404,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
None,
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
@@ -428,6 +430,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
None,
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,