Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

	Add boolean mask support in vector SDPA (#1757)
@@ -10,7 +10,10 @@ import numpy as np
 def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
     if mask is not None:
-        p += mask
+        if mask.dtype == mx.bool_:
+            p = mx.where(mask, p, mx.finfo(mx.float32).min)
+        else:
+            p += mask
     scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
     return scores @ v
 
@@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))
 
+    def test_fast_sdpa_vector(self):
+        D = 64
+        L = 43
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=mx.full((Nq, 2, L), False),
+            )
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        L = 4096
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+
 if __name__ == "__main__":
     unittest.main(failfast=True)
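For reference, a minimal usage sketch (not part of the diff) of how a boolean mask is intended to be used, assuming the shapes exercised in the test above. Positions where the mask is False are excluded from attention, matching the reference implementation's mx.where(mask, p, mx.finfo(mx.float32).min) applied before the softmax:

import mlx.core as mx

D, L, Nq, Nkv = 64, 43, 4, 1
q = mx.random.normal(shape=(1, Nq, 1, D))    # single query token (the "vector" SDPA path)
k = mx.random.normal(shape=(1, Nkv, L, D))
v = mx.random.normal(shape=(1, Nkv, L, D))

# Boolean mask broadcast over the attention scores: True = attend, False = masked out.
mask = mx.random.uniform(shape=(Nq, 1, L)) > 0.2

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
print(out.shape)  # (1, Nq, 1, D)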
Awni Hannun