From c15fe3e61bd6e171aa45ab9ef4e920f09bd8faff Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 16 Jan 2024 00:46:21 -0800 Subject: [PATCH] Allow arbitrary first dimension in quantization kernels. (#458) * Allow arbitrary first dim on qmm_t and qmv * Allow arbitrary first dim on qmm and qvm * Specialized aligned vs unaligned case * Add more checks for valid quantizations --- benchmarks/python/comparative/bench_mlx.py | 44 +++++-- mlx/backend/metal/kernels/quantized.metal | 135 +++++++++++++++------ mlx/backend/metal/quantized.cpp | 10 +- mlx/ops.cpp | 40 ++++-- python/tests/test_quantized.py | 35 +++++- tests/ops_tests.cpp | 8 +- 6 files changed, 206 insertions(+), 66 deletions(-) diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 15137a8a7..8b96840f7 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -60,20 +60,48 @@ def matmul(x, y): mx.eval(ys) -def _quant_matmul(x, w, s, b, group_size, bits): +def _quant_matmul(x, w, s, b, transpose, group_size, bits): ys = [] for i in range(10): - ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits)) + ys.append( + mx.quantized_matmul( + x, w, s, b, transpose=transpose, group_size=group_size, bits=bits + ) + ) mx.eval(ys) quant_matmul = { - "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2), - "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4), - "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8), - "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2), - "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4), - "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8), + "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2), + "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4), + "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8), + "quant_matmul_128_2": partial( + _quant_matmul, transpose=False, group_size=128, bits=2 + ), + "quant_matmul_128_4": partial( + _quant_matmul, transpose=False, group_size=128, bits=4 + ), + "quant_matmul_128_8": partial( + _quant_matmul, transpose=False, group_size=128, bits=8 + ), + "quant_matmul_t_64_2": partial( + _quant_matmul, transpose=True, group_size=64, bits=2 + ), + "quant_matmul_t_64_4": partial( + _quant_matmul, transpose=True, group_size=64, bits=4 + ), + "quant_matmul_t_64_8": partial( + _quant_matmul, transpose=True, group_size=64, bits=8 + ), + "quant_matmul_t_128_2": partial( + _quant_matmul, transpose=True, group_size=128, bits=2 + ), + "quant_matmul_t_128_4": partial( + _quant_matmul, transpose=True, group_size=128, bits=4 + ), + "quant_matmul_t_128_8": partial( + _quant_matmul, transpose=True, group_size=128, bits=8 + ), } diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 9627dc3c0..9cb54e0f8 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -154,10 +154,13 @@ template +template [[kernel]] void qmm_t( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], @@ -257,6 +260,7 @@ template (wi & bitmask) + bias; - wi >>= bits; + #pragma clang loop unroll(full) + for (int t=0; t(wi & bitmask) + bias; + wi >>= bits; + } + } else { + #pragma clang loop unroll(full) + for (int t=0; t(wi & bitmask) + bias; + wi >>= bits; + } } } } @@ -324,8 +355,8 @@ template 
= K) { + for (int wo=0; wo(wi & bitmask) + bias; - wi >>= bits; + #pragma clang loop unroll(full) + for (int t=0; t(wi & bitmask) + bias; + wi >>= bits; + } + } else { + #pragma clang loop unroll(full) + for (int t=0; t(wi & bitmask) + bias; + wi >>= bits; + } } } } @@ -511,9 +569,9 @@ instantiate_qvm_types( 64, 2) instantiate_qvm_types( 64, 4) instantiate_qvm_types( 64, 8) -#define instantiate_qmm_t(name, itype, group_size, bits) \ - template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \ - [[kernel]] void qmm_t( \ +#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \ + template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \ + [[kernel]] void qmm_t( \ const device itype* x [[buffer(0)]], \ const device uint32_t* w [[buffer(1)]], \ const device itype* scales [[buffer(2)]], \ @@ -528,9 +586,12 @@ instantiate_qvm_types( 64, 8) uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_qmm_t_types(group_size, bits) \ - instantiate_qmm_t(float32, float, group_size, bits) \ - instantiate_qmm_t(float16, half, group_size, bits) \ - instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits) + instantiate_qmm_t(float32, float, group_size, bits, false) \ + instantiate_qmm_t(float16, half, group_size, bits, false) \ + instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \ + instantiate_qmm_t(float32, float, group_size, bits, true) \ + instantiate_qmm_t(float16, half, group_size, bits, true) \ + instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) instantiate_qmm_t_types(128, 2) instantiate_qmm_t_types(128, 4) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 3997037e5..34c3e6d4f 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); - int bo = 32; + int bo = std::min(32, O); int bd = 32; MTL::Size group_dims = MTL::Size(bd, bo, 1); MTL::Size grid_dims = MTL::Size(1, O / bo, B); @@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { else { std::ostringstream kname; kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" - << bits_; + << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0); // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); @@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int bn = 32; int bk = 64; MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); + MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1); set_array_buffer(compute_encoder, x, 0); set_array_buffer(compute_encoder, w, 1); @@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); - int bo = 32; + int bo = std::min(32, O); int bd = 32; MTL::Size group_dims = MTL::Size(bd, bo, 1); - MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B); + MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); set_array_buffer(compute_encoder, x, 0); set_array_buffer(compute_encoder, w, 1); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c185568ff..6840afb62 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2801,16 +2801,25 @@ std::tuple quantize( int group_size /* = 64 */, 
int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
+  if (group_size != 64 && group_size != 128) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested group size " << group_size
+        << " is not supported. The supported group sizes are 64 and 128.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (bits != 2 && bits != 4 && bits != 8) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested number of bits " << bits
+        << " is not supported. The supported bits are 2, 4 and 8.";
+    throw std::invalid_argument(msg.str());
+  }
+
   if (w.ndim() != 2) {
     throw std::invalid_argument("[quantize] Only matrices supported for now");
   }
 
-  if ((w.shape(0) % 32) != 0) {
-    throw std::invalid_argument(
-        "[quantize] All dimensions should be divisible by 32 for now");
-  }
-
-  if ((w.shape(-1) % group_size) != 0) {
+  if ((w.shape(1) % group_size) != 0) {
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
         << "the quantization group size " << group_size
@@ -2825,6 +2834,20 @@ std::tuple<array, array, array> quantize(
   array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
   shifts = reshape(shifts, {1, 1, -1}, s);
 
+  // Check that the w matrix will fill up a whole SIMD.
+  // This is an implementation detail which should be removed in the future but
+  // at least we bail out early which will result in a nice readable error.
+  //
+  // Hopefully nobody is quantizing matrices that small anyway.
+  if (w.shape(1) < 32 * el_per_int) {
+    std::ostringstream msg;
+    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
+        << "too small for quantization. We support >=512 for 2 bits, "
+        << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
+        << "shape " << w.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Compute scales and biases
   array packed_w =
       reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
@@ -2855,11 +2878,6 @@ array dequantize(
     throw std::invalid_argument("[dequantize] Only matrices supported for now");
   }
 
-  if ((w.shape(0) % 32) != 0) {
-    throw std::invalid_argument(
-        "[dequantize] All dimensions should be divisible by 32 for now");
-  }
-
   if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
     throw std::invalid_argument(
         "[dequantize] Shape of scales and biases does not match the matrix");
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 5f038057d..23200a9fa 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -9,7 +9,7 @@ import mlx_tests
 
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
-        w = mx.random.normal(shape=(128, 128))
+        w = mx.random.normal(shape=(128, 512))
         for b in [2, 4, 8]:
             w_q, scales, biases = mx.quantize(w, 64, b)
             w_hat = mx.dequantize(w_q, scales, biases, 64, b)
@@ -131,6 +131,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
             y = mx.quantized_matmul(x, w_q, scales, biases, True)
             mx.eval(y)
 
+    def test_small_matrix(self):
+        w = mx.random.normal(shape=(8, 256))
+        w_q, scales, biases = mx.quantize(w)
+        w_hat = mx.dequantize(w_q, scales, biases)
+
+        # Test qmv
+        x = mx.random.normal(shape=(1, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm_t
+        x = mx.random.normal(shape=(10, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qvm
+        x = mx.random.normal(shape=(1, 8))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm
+        x = mx.random.normal(shape=(10, 8))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index b2be73517..1c00f4541 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2308,15 +2308,15 @@ TEST_CASE("test linspace") {
 
 TEST_CASE("test quantize dequantize") {
   auto x1 = ones({128, 1});
-  auto x2 = expand_dims(arange(0, 128, float32), 0);
+  auto x2 = expand_dims(arange(0, 512, float32), 0);
   auto x = x1 * x2;
 
   for (int i = 2; i <= 8; i *= 2) {
     int el_per_int = 32 / i;
     auto [x_q, scales, biases] = quantize(x, 128, i);
-    CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
-    CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
-    CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
+    CHECK_EQ(x_q.shape(), std::vector<int>{128, 512 / el_per_int});
+    CHECK_EQ(scales.shape(), std::vector<int>{128, 4});
+    CHECK_EQ(biases.shape(), std::vector<int>{128, 4});
 
     auto x_hat = dequantize(x_q, scales, biases, 128, i);
     auto max_diff = max(abs(x - x_hat)).item<float>();
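
A minimal usage sketch of what this patch enables, assuming only the public Python API exercised by the tests above (mx.quantize, mx.dequantize, mx.quantized_matmul); the shapes and the 1e-3 tolerance are illustrative choices, not values taken from the patch. The first dimension of w (8) is deliberately not a multiple of 32, which the kernels now accept, while unsupported group sizes, bit widths, and too-small feature dimensions are rejected early by the new checks in mlx/ops.cpp.

    import mlx.core as mx

    # Weight whose first dimension (8) is not a multiple of 32; previously
    # rejected by quantize(), now handled by the unaligned kernel paths.
    w = mx.random.normal(shape=(8, 256))
    w_q, scales, biases = mx.quantize(w, 64, 4)
    w_hat = mx.dequantize(w_q, scales, biases, 64, 4)

    # Transposed path (qmv / qmm_t): y = x @ w.T
    x = mx.random.normal(shape=(10, 256))
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4)
    print((y - x @ w_hat.T).abs().max())  # expected to be small, e.g. < 1e-3

    # Non-transposed path (qvm / qmm): y = x @ w
    x = mx.random.normal(shape=(10, 8))
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=False, group_size=64, bits=4)
    print((y - x @ w_hat).abs().max())

    # quantize() now rejects unsupported configurations up front:
    # group_size must be 64 or 128, bits must be 2, 4 or 8, and the feature
    # dimension must be at least 32 * (32 // bits), i.e. 512 / 256 / 128 elements.

On the GPU side, QuantizedMatmul::eval_gpu selects the aligned qmm_t specialization when the output dimension is divisible by 32 and the unaligned variant otherwise; that is what the new aligned_N template parameter controls.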