mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -347,6 +347,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|  | ||||
|         masks = [ | ||||
|             None, | ||||
|             mx.array(True), | ||||
|             mx.array([True] * (L - 10) + [False] * 10), | ||||
|             mx.random.uniform(shape=(Nq, 1, L)) > 0.2, | ||||
| @@ -392,7 +393,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|     def test_fast_sdpa_few_query(self): | ||||
|         D = 64 | ||||
|         L = 43 | ||||
|         Lq = 4 | ||||
|         Lq = 8 | ||||
|         Nq = 8 | ||||
|         Nkv = 1 | ||||
|         scale = 1.0 | ||||
| @@ -403,6 +404,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|         v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) | ||||
|  | ||||
|         masks = [ | ||||
|             None, | ||||
|             mx.array(True), | ||||
|             mx.array([True] * (L - 10) + [False] * 10), | ||||
|             mx.random.uniform(shape=(Nq, 1, L)) > 0.2, | ||||
| @@ -428,6 +430,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|         v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) | ||||
|  | ||||
|         masks = [ | ||||
|             None, | ||||
|             mx.array(True), | ||||
|             mx.array([True] * (L - 10) + [False] * 10), | ||||
|             mx.random.uniform(shape=(Nq, 1, L)) > 0.2, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun