mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add float mask to sdpa vector (#2068)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							68d1b3256b
						
					
				
				
					commit
					c4189a38e4
				
			| @@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             mx.array([True] * (L - 10) + [False] * 10), | ||||
|             mx.random.uniform(shape=(Nq, 1, L)) > 0.2, | ||||
|             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, | ||||
|             mx.random.uniform(shape=(Nq, 1, L)), | ||||
|             mx.random.uniform(shape=(L, 1, Nq)).T, | ||||
|             mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2), | ||||
|             mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2), | ||||
|             "causal", | ||||
|         ] | ||||
|         for m in masks: | ||||
| @@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             mx.array([True] * (L - 10) + [False] * 10), | ||||
|             mx.random.uniform(shape=(Nq, 1, L)) > 0.2, | ||||
|             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, | ||||
|             mx.random.uniform(shape=(Nq, 1, L)), | ||||
|             mx.random.uniform(shape=(L, 1, Nq)).T, | ||||
|             mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2), | ||||
|             mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2), | ||||
|             "causal", | ||||
|         ] | ||||
|         for m in masks: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user