mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 23:24:45 +08:00
@@ -347,6 +347,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
masks = [
|
||||
None,
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
@@ -392,7 +393,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
def test_fast_sdpa_few_query(self):
|
||||
D = 64
|
||||
L = 43
|
||||
Lq = 4
|
||||
Lq = 8
|
||||
Nq = 8
|
||||
Nkv = 1
|
||||
scale = 1.0
|
||||
@@ -403,6 +404,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
|
||||
masks = [
|
||||
None,
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
@@ -428,6 +430,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
|
||||
|
||||
masks = [
|
||||
None,
|
||||
mx.array(True),
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
|
Reference in New Issue
Block a user