From 40c108766b146453fea2d1662382658334d069c6 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 12 Feb 2024 18:54:21 -0800
Subject: [PATCH] Quantized matmul fix (#677)

* Fix qmv for small or unaligned matrices

* Fix qmm
---
 mlx/backend/metal/kernels/quantized.metal | 24 ++++++---
 mlx/backend/metal/quantized.cpp           |  2 +-
 python/tests/test_quantized.py            | 64 +++++++++++++++++++++++
 3 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 0de84093d..c2bfba9f9 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -39,11 +39,12 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   typedef typename AccT<T>::acc_t U;
   threadgroup U scales_block[BM * groups_per_block];
@@ -66,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
+  if (out_row >= out_vec_size) {
+    return;
+  }
+
   // Loop over in_vec in blocks of colgroup
   for (int i=0; i<in_vec_size; i+=colgroup) {
@@ -305,7 +313,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
   using loader_x_t = mlx::steel::BlockLoader<...>;
-
   threadgroup T scales_block[BN * groups_per_block];
   threadgroup T biases_block[BN * groups_per_block];
   threadgroup T Xs[BM * BK];
@@ -313,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
-    if (k + BK >= K) {
+    if (num_k < BK) {
      for (int wo=0; wo<...
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ ... @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bo = std::min(32, O);
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
     set_array_buffer(compute_encoder, w, 0);
     set_array_buffer(compute_encoder, scales, 1);
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index b068aa6ee..fad2ba51c 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_non_multiples(self):
+        w = mx.random.normal(shape=(33, 256))
+        w_q, scales, biases = mx.quantize(w)
+        w_hat = mx.dequantize(w_q, scales, biases)
+
+        # Test qmv
+        x = mx.random.normal(shape=(1, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm_t
+        x = mx.random.normal(shape=(10, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qvm
+        x = mx.random.normal(shape=(1, 33))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm
+        x = mx.random.normal(shape=(10, 33))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Smaller than 8
+        w = mx.random.normal(shape=(3, 256))
+        w_q, scales, biases = mx.quantize(w)
+        w_hat = mx.dequantize(w_q, scales, biases)
+
+        # Test qmv
+        x = mx.random.normal(shape=(1, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm_t
+        x = mx.random.normal(shape=(10, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qvm
+        x = mx.random.normal(shape=(1, 3))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm
+        x = mx.random.normal(shape=(10, 3))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
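
A note on the dispatch-side half of the fix, with a repro sketch distilled from
the new test_non_multiples case (this sketch is not part of the patch, and it
assumes mx.quantize's default group_size of 64 and 4 bits): qmv covers bo = 32
output rows per threadgroup, so with O = 33 the old grid of O / bo == 1
threadgroups left the last row unwritten, while (O + bo - 1) / bo == 2 rounds
the grid up and the kernel's new out-of-range row guard turns the overshoot
into a no-op.

    import mlx.core as mx

    # 33 output rows: not a multiple of the 32 rows one qmv threadgroup covers.
    w = mx.random.normal(shape=(33, 256))
    w_q, scales, biases = mx.quantize(w)  # defaults: group_size=64, bits=4
    w_hat = mx.dequantize(w_q, scales, biases)

    x = mx.random.normal(shape=(1, 256))
    y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
    y_hat = x @ w_hat.T

    # Holds after this patch; before it, the 33rd row of y_q was never computed.
    assert (y_q - y_hat).abs().max().item() < 1e-3

The qmm change reads as the same idea applied along K: the guarded load now
keys off num_k < BK, i.e. it takes the slow path exactly when the current
K-tile is partial rather than only on the final loop iteration.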