mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa * batch sdpa for query
This commit is contained in:
		| @@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
def test_fast_sdpa_few_query(self):
    """Compare ``mx.fast.scaled_dot_product_attention`` with
    ``mlx_primitives_sdpa`` when there are only a few query positions.

    Two scenarios are exercised with ``Lq = 4`` queries, ``Nq = 8`` query
    heads, ``Nkv = 1`` key/value head and head dimension ``D = 64``:

    * a short key/value sequence (``L = 43``) where the queries are drawn
      as ``(1, Lq, Nq, D)`` and then axis-swapped to ``(1, Nq, Lq, D)``;
    * a long key/value sequence (``L = 4096``) where the queries are drawn
      directly as ``(1, Nq, Lq, D)``.

    Each scenario is checked against four boolean masks: a scalar ``True``,
    a length-``L`` vector that is ``False`` for the last 10 positions, a
    random ``(Nq, 1, L)`` mask, and a random mask drawn as ``(L, 1, Nq)``
    and transposed.

    NOTE(review): the long-sequence scenario used to sit behind a bare
    ``return`` and therefore never ran. It is enabled here; confirm it
    passes on every supported backend.
    """
    D = 64
    Lq = 4
    Nq = 8
    Nkv = 1
    scale = 1.0

    def check_against_reference(q, k, v, L):
        # The masks are built after q/k/v so that the sequence of random
        # draws following the seed matches the original test.
        masks = [
            mx.array(True),
            mx.array([True] * (L - 10) + [False] * 10),
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
        ]
        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(
                mx.allclose(ref, out, atol=1e-4, rtol=1e-4),
                msg=f"SDPA mismatch for L={L}, mask shape={m.shape}",
            )

    # Scenario 1: short key/value sequence, queries axis-swapped after
    # creation (presumably to exercise a non-contiguous query layout --
    # TODO confirm).
    L = 43
    mx.random.seed(0)
    q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
    q = q.swapaxes(1, 2)
    k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
    v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
    check_against_reference(q, k, v, L)

    # Scenario 2: long key/value sequence, queries created directly in
    # (1, Nq, Lq, D) order.
    L = 4096
    mx.random.seed(0)
    q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
    k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
    v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
    check_against_reference(q, k, v, L)
|  | ||||
|     @unittest.skip("Different head and value dims is not enabled") | ||||
|     def test_fast_sdpa_vector_value_dims(self): | ||||
|         D = 192 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun