mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences
  Updated fast attention: GEMM-ified with Steel primitives
  Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
This commit is contained in:
		| @@ -32,9 +32,80 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale): | ||||
|     return mlx_primitives_sdpa(q, k, v, scale) | ||||
|  | ||||
|  | ||||
| class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
| class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): | ||||
|     # Compares mx.fast.scaled_dot_product_attention against the reference | ||||
|     # helpers mlx_primitives_sdpa and mlx_primitives_sdpa_with_gqa, first on | ||||
|     # one small float32 input and then over several sequence lengths/dtypes. | ||||
|     def test_fast_sdpa(self): | ||||
|  | ||||
|         # Not yet supported: | ||||
|         # * K pre-transposed in kernel, V pre-transposed in kernel | ||||
|         # Seed NumPy's global RNG so the random inputs below are repeatable. | ||||
|         np.random.seed(0) | ||||
|         # Part 1: batch of 1, H=3 heads, equal query/key lengths (R == L == 20). | ||||
|         R = 20 | ||||
|         L = R | ||||
|         Dk = 64 | ||||
|         H = 3 | ||||
|         # 1/sqrt(Dk) scaling; this same value is reused by the loop further down. | ||||
|         scale = float(1.0 / np.sqrt(Dk)) | ||||
|         q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32) | ||||
|         k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32) | ||||
|         v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32) | ||||
|  | ||||
|         q_mlx = mx.array(q_npy) | ||||
|         k_mlx = mx.array(k_npy) | ||||
|         v_mlx = mx.array(v_npy) | ||||
|  | ||||
|         # Expected output from the reference helper (defined elsewhere in the file). | ||||
|         reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale) | ||||
|  | ||||
|         # Output under test, called with an explicit mask=None. | ||||
|         o_mlx = mx.fast.scaled_dot_product_attention( | ||||
|             q_mlx, k_mlx, v_mlx, scale=scale, mask=None | ||||
|         ) | ||||
|  | ||||
|         # Shapes must match exactly; values must agree within atol=1e-4. | ||||
|         self.assertListEqual(list(reference.shape), list(o_mlx.shape)) | ||||
|         self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4)) | ||||
|  | ||||
|         # Part 2: sweep over sequence lengths and dtypes with batch size 2. | ||||
|         dtypes = [np.float32] | ||||
|  | ||||
|         # NOTE(review): Dk is reassigned to the value it already holds (64); | ||||
|         # `scale` above is not recomputed here. | ||||
|         Dk = 64 | ||||
|  | ||||
|         # Half precision is only added to the sweep when is_apple_silicon is set. | ||||
|         if self.is_apple_silicon: | ||||
|             dtypes.append(np.half) | ||||
|  | ||||
|         for SEQUENCE_LENGTH in [63, 129, 400]: | ||||
|             for DTYPE in dtypes: | ||||
|                 B = 2 | ||||
|                 H = 24 | ||||
|                 # K and V use the same head count as Q in this test. | ||||
|                 n_kv_heads = H | ||||
|                 q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype( | ||||
|                     DTYPE | ||||
|                 ) | ||||
|                 k_npy = np.random.normal( | ||||
|                     0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) | ||||
|                 ).astype(DTYPE) | ||||
|                 v_npy = np.random.normal( | ||||
|                     0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) | ||||
|                 ).astype(DTYPE) | ||||
|  | ||||
|                 q_mlx = mx.array(q_npy) | ||||
|                 k_mlx = mx.array(k_npy) | ||||
|                 v_mlx = mx.array(v_npy) | ||||
|  | ||||
|                 # Reference here is the *_with_gqa helper; no mask argument is passed | ||||
|                 # to the call under test. | ||||
|                 reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) | ||||
|                 o_mlx = mx.fast.scaled_dot_product_attention( | ||||
|                     q_mlx, k_mlx, v_mlx, scale=scale | ||||
|                 ) | ||||
|  | ||||
|                 self.assertListEqual(list(reference.shape), list(o_mlx.shape)) | ||||
|                 # Default tolerances; rtol is loosened below for specific cases. | ||||
|                 rtol = 1e-3 | ||||
|                 atol = 1e-2 | ||||
|  | ||||
|                 # NOTE(review): no length in [63, 129, 400] exceeds 500, so this | ||||
|                 # branch is not taken with the current sweep. | ||||
|                 if SEQUENCE_LENGTH > 500: | ||||
|                     rtol = 1e-2 | ||||
|  | ||||
|                 # Half-precision runs are compared with the looser rtol of 1e-2. | ||||
|                 if DTYPE == np.half: | ||||
|                     rtol = 1e-2 | ||||
|  | ||||
|                 self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) | ||||
|  | ||||
|  | ||||
| class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|     def test_fast_sdpa(self): | ||||
|         # Not yet supported: | ||||
|         # * K pre-transposed in kernel, V pre-transposed in kernel | ||||
|         np.random.seed(0) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Brian Keene
					Brian Keene