mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix batched vector sdpa (#2152)
This commit is contained in:
		| @@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) | ||||
|             self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
|     def test_sdpa_vector_batched(self): | ||||
          # Compares mx.fast.scaled_dot_product_attention against the
          # mlx_ref_attn reference over a series of q/k/v shape combinations,
          # all with a leading dimension of 2. Every case asserts agreement
          # within atol = rtol = 1e-4.
          # NOTE(review): shapes are presumably (batch, heads, seq_len,
          # head_dim), following the layout documented for the fused op —
          # confirm.
          # NOTE(review): the fused call passes scale=1.0 while mlx_ref_attn
          # is called without a scale; this assumes the reference defaults to
          # scale=1.0 — confirm against its definition.
|         D = 64 | ||||
          # Case 1: q, k and v all have shape (2, 1, 3, D).
|         q = mx.random.normal(shape=(2, 1, 3, D)) | ||||
|         k = mx.random.normal(shape=(2, 1, 3, D)) | ||||
|         v = mx.random.normal(shape=(2, 1, 3, D)) | ||||
|  | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
          # Case 2: only q is re-drawn, now (2, 4, 3, D); k and v are reused
          # from case 1, so axis 1 differs between q (4) and k/v (1).
|         q = mx.random.normal(shape=(2, 4, 3, D)) | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
          # Case 3: q is drawn as (2, 3, 4, D) and axes 1 and 2 are swapped,
          # giving the same logical shape (2, 4, 3, D) as case 2 but through a
          # transposed array (presumably non-contiguous in memory — confirm).
|         q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2) | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
          # Case 4: k is also built by swapping axes 1 and 2, here of a
          # (2, 3, 1, D) draw, so its logical shape is (2, 1, 3, D). q and v
          # carry over from the previous cases.
|         k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2) | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
          # Case 5: q is (2, 4, 3, D); k is a (2, 3, 2, D) draw with axes 1
          # and 2 swapped, i.e. (2, 2, 3, D); v is drawn directly as
          # (2, 2, 3, D). Axis 1 is now 4 for q and 2 for k/v.
|         q = mx.random.normal(shape=(2, 4, 3, D)) | ||||
|         k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2) | ||||
|         v = mx.random.normal(shape=(2, 2, 3, D)) | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
          # Case 6: fresh q (2, 4, 3, D) and k/v (2, 1, 3, D), this time with
          # an explicit mask passed to both implementations. The mask is 10x
          # a normal draw of shape (1, 2, 3, 3) with axes 0 and 1 swapped,
          # i.e. logical shape (2, 1, 3, 3).
          # NOTE(review): presumably an additive float mask that broadcasts
          # over axis 1 — confirm.
|         q = mx.random.normal(shape=(2, 4, 3, D)) | ||||
|         k = mx.random.normal(shape=(2, 1, 3, D)) | ||||
|         v = mx.random.normal(shape=(2, 1, 3, D)) | ||||
|         mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1) | ||||
|         out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) | ||||
|         ref = mlx_ref_attn(q, k, v, mask=mask) | ||||
|         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
|  | ||||
| class TestSDPA(mlx_tests.MLXTestCase): | ||||
|     @property | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun