mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Add float mask to sdpa vector (#2068)
This commit is contained in:

committed by
GitHub

parent
68d1b3256b
commit
c4189a38e4
@@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
mx.random.uniform(shape=(Nq, 1, L)),
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T,
|
||||
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
|
||||
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
@@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
mx.random.uniform(shape=(Nq, 1, L)),
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T,
|
||||
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
|
||||
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
|
Reference in New Issue
Block a user