mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	add test
This commit is contained in:
		| @@ -1,5 +1,6 @@ | ||||
| import math | ||||
| import unittest | ||||
| from itertools import product | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| @@ -113,61 +114,63 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|         R = 1 | ||||
|         Dk = 128 | ||||
|         scale = float(1.0 / np.sqrt(128.0)) | ||||
|         q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32) | ||||
|         k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) | ||||
|         v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) | ||||
|         q = mx.random.normal(shape=(1, 32, R, Dk)) | ||||
|         k = mx.random.normal(shape=(1, 32, L, Dk)) | ||||
|         v = mx.random.normal(shape=(1, 32, L, Dk)) | ||||
|  | ||||
|         q_mlx = mx.array(q_npy) | ||||
|         k_mlx = mx.array(k_npy) | ||||
|         v_mlx = mx.array(v_npy) | ||||
|         reference = mlx_primitives_sdpa(q, k, v, scale) | ||||
|  | ||||
|         reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale) | ||||
|         o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) | ||||
|  | ||||
|         o_mlx = mx.fast.scaled_dot_product_attention( | ||||
|             q_mlx, k_mlx, v_mlx, scale=scale, mask=None | ||||
|         ) | ||||
|  | ||||
|         self.assertListEqual(list(reference.shape), list(o_mlx.shape)) | ||||
|         self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4)) | ||||
|         self.assertListEqual(list(reference.shape), list(o.shape)) | ||||
|         self.assertTrue(mx.allclose(o, reference, atol=1e-4)) | ||||
|  | ||||
|         B = 1 | ||||
|         H = 32 | ||||
|         dtypes = [np.float32] | ||||
|         if self.is_apple_silicon: | ||||
|             dtypes.append(np.half) | ||||
|  | ||||
|         for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: | ||||
|             for DO_GQA in [0, 1]: | ||||
|                 for DTYPE in dtypes: | ||||
|                     n_kv_heads = 8 if DO_GQA else 32 | ||||
|                     q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE) | ||||
|                     k_npy = np.random.normal( | ||||
|                         0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) | ||||
|                     ).astype(DTYPE) | ||||
|                     v_npy = np.random.normal( | ||||
|                         0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) | ||||
|                     ).astype(DTYPE) | ||||
|         tests = product( | ||||
|             [1, 7, 9, 32, 63, 67, 129, 2000],  # sequence length | ||||
|             [False, True],  # gqa | ||||
|             [mx.float32, mx.float16], | ||||
|             [4, 8],  # bits | ||||
|         ) | ||||
|         for sequence_length, do_gqa, dtype, bits in tests: | ||||
|             with self.subTest( | ||||
|                 sequence_length=sequence_length, gqa=do_gqa, dtype=dtype, bits=bits | ||||
|             ): | ||||
|                 n_kv_heads = 8 if do_gqa else 32 | ||||
|                 q = mx.random.normal(shape=(B, H, R, Dk), dtype=dtype) | ||||
|                 k = mx.random.normal( | ||||
|                     shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype | ||||
|                 ) | ||||
|                 v = mx.random.normal( | ||||
|                     shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype | ||||
|                 ) | ||||
|  | ||||
|                     q_mlx = mx.array(q_npy) | ||||
|                     k_mlx = mx.array(k_npy) | ||||
|                     v_mlx = mx.array(v_npy) | ||||
|                 k_q = mx.quantize(k, bits=bits) | ||||
|                 v_q = mx.quantize(v, bits=bits) | ||||
|                 k_d = mx.dequantize(*k_q, bits=bits) | ||||
|                 v_d = mx.dequantize(*v_q, bits=bits) | ||||
|  | ||||
|                     reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) | ||||
|                     o_mlx = mx.fast.scaled_dot_product_attention( | ||||
|                         q_mlx, k_mlx, v_mlx, scale=scale | ||||
|                     ) | ||||
|                 reference = mlx_primitives_sdpa_with_gqa(q, k_d, v_d, scale) | ||||
|                 o = mx.fast.scaled_dot_product_attention(q, k_d, v_d, scale=scale) | ||||
|                 o_q = mx.fast.quantized_scaled_dot_product_attention( | ||||
|                     q, *k_q, *v_q, scale=scale, bits=bits | ||||
|                 ) | ||||
|  | ||||
|                     self.assertListEqual(list(reference.shape), list(o_mlx.shape)) | ||||
|                     rtol = 1e-5 | ||||
|                     atol = 1e-1 | ||||
|                 self.assertListEqual(list(reference.shape), list(o.shape)) | ||||
|                 rtol = 1e-5 | ||||
|                 atol = 1e-1 | ||||
|  | ||||
|                     if SEQUENCE_LENGTH > 500: | ||||
|                         rtol = 1e-2 | ||||
|                 if sequence_length > 500: | ||||
|                     rtol = 1e-2 | ||||
|  | ||||
|                     if DTYPE == np.half: | ||||
|                         rtol = 1e-2 | ||||
|                 if dtype == mx.float16: | ||||
|                     rtol = 1e-2 | ||||
|  | ||||
|                     self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) | ||||
|                 # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol) | ||||
|                 self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol)) | ||||
|                 self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol)) | ||||
|  | ||||
|         q = mx.random.normal(shape=(1, 32, 1, Dk)) | ||||
|         k = mx.random.normal(shape=(1, 32, 32, Dk)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron