Batched Quantized Matmul + Fast Small QMV (#1503)
Author: Alex Barron

* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp
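The test diff below extends the quantized matmul tests to batched weights and activations. As a minimal sketch of the batched path those tests exercise, using the same mx.quantize / mx.quantized_matmul APIs that appear in the diff (the B, M, N values here are illustrative, not from the commit):

import mlx.core as mx

B, M, N = 4, 512, 1024  # batch size, output dim, input dim (illustrative)

# One weight matrix per batch element, quantized with the library defaults.
w = mx.random.normal(shape=(B, M, N))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)

# Batched qmv: one row vector per batch element against its quantized weights.
x = mx.random.normal(shape=(B, 1, N))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)

# Reference result computed from the dequantized weights, as in the tests.
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
print(y_q.shape, (y_q - y_hat).abs().max())  # (B, 1, M); error should be small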
@@ -117,19 +117,24 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
-            [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # M
+            [64, 128, 512, 1024],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(M, N), key=k2)
+        for group_size, bits, M, N, B in tests:
+            if group_size > N:
+                continue
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
+                w_shape = (M, N) if B == 0 else (B, M, N)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
                     x, w_q, scales, biases, True, group_size, bits
                 )
-                y_hat = x @ w_hat.T
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)

@@ -140,12 +145,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(N, M), key=k2)
+        for group_size, bits, M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
@@ -172,37 +180,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
         mx.eval(y)

     def test_small_matrix(self):
-        w = mx.random.normal(shape=(8, 256))
-        w_q, scales, biases = mx.quantize(w)
-        w_hat = mx.dequantize(w_q, scales, biases)
+        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
+            with self.subTest(w_shape=w_shape):
+                w = mx.random.normal(shape=(w_shape))
+                w_q, scales, biases = mx.quantize(w)
+                w_hat = mx.dequantize(w_q, scales, biases)

-        # Test qmv
-        x = mx.random.normal(shape=(1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(3, 1, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmm_t
-        x = mx.random.normal(shape=(10, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm_t
+                x = mx.random.normal(shape=(3, 10, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmv
-        x = mx.random.normal(shape=(1, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qvm
+                x = mx.random.normal(shape=(3, 1, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmm
-        x = mx.random.normal(shape=(10, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm
+                x = mx.random.normal(shape=(3, 10, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))
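The B == 0 branch in the updated tests also covers broadcasting: batched activations against a single, unbatched quantized weight matrix. A small usage sketch of that case, mirroring test_small_matrix (values illustrative):

import mlx.core as mx

# Unbatched weights, batched activations: w broadcasts over the leading axis of x.
w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w)

x = mx.random.normal(shape=(3, 1, 256))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (3, 1, 8)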