From 890fdd1ef09c59028158de8c8d874b7bd8468708 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Mon, 2 Dec 2024 16:19:29 -0800
Subject: [PATCH] start

---
 mlx/backend/metal/kernels/quantized.h | 12 +++--
 mlx/backend/metal/quantized.cpp       |  7 +--
 mlx/fast.cpp                          |  3 +-
 python/tests/test_quantized.py        | 63 +++++++++++++++++++--------
 4 files changed, 59 insertions(+), 26 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index ad53f08232..b2d5e92143 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2006,7 +2006,8 @@ template
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   constexpr T eps = T(1e-7);
-  constexpr int simd_size = 32;
+  constexpr bool use_quads = group_size <= 16;
+  constexpr int simd_size = use_quads ? 4 : 32;
   constexpr T n_bins = (1 << bits) - 1;
   constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int values_per_reduce = group_size / simd_size;
@@ -2038,8 +2039,13 @@ template
     w_max = max(w_max, val);
   }
 
-  w_min = simd_min(w_min);
-  w_max = simd_max(w_max);
+  if (use_quads) {
+    w_min = quad_min(w_min);
+    w_max = quad_max(w_max);
+  } else {
+    w_min = simd_min(w_min);
+    w_max = simd_max(w_max);
+  }
 
   T scale = max((w_max - w_min) / n_bins, eps);
   bool side = abs(w_min) > abs(w_max);
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 4454476c99..08f6d5ac6d 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -308,13 +308,13 @@ void qmm_op(
       group_dims = MTL::Size(simdgroup_size, 1, 1);
       grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
       quad = true;
-    } else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
+    } else if (B > 0 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       name += "qmv_fast";
       int bo = 8;
       int bd = 32;
       group_dims = MTL::Size(bd, 2, 1);
       grid_dims = MTL::Size(O / bo, B, N);
-    } else if (B < 6) {
+    } else if (B > 0) {
       name += "qmv";
       int bo = 8;
       int bd = 32;
@@ -445,7 +445,8 @@ void fast::AffineQuantize::eval_gpu(
 
   // Treat uint32 as uint8 in kernel
   constexpr int uint8_per_uint32 = 4;
-  constexpr int simd_size = 32;
+  // Use quads for small group sizes
+  int simd_size = group_size_ <= 16 ? 4 : 32;
   int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
   int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
   size_t nthreads =
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 731912d699..77b3fed99e 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -729,7 +729,8 @@ std::tuple<array, array, array>
 affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   auto s = to_stream(s_);
-  if (group_size != 32 && group_size != 64 && group_size != 128) {
+  if (group_size != 16 && group_size != 32 && group_size != 64 &&
+      group_size != 128) {
     std::ostringstream msg;
     msg << "[quantize] The requested group size " << group_size
         << " is not supported. The supported group sizes are 64 and 128.";
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 7d4ba99493..bdba254d8b 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -10,7 +10,7 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
-        for gs in [32, 64, 128]:
+        for gs in [16, 32, 64, 128]:
             for b in [2, 3, 6, 4, 8]:
                 with self.subTest(gs=gs, b=b):
                     w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
@@ -115,7 +115,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64, 32],  # group_size
+            [128, 64, 32, 16],  # group_size
             [2, 3, 4, 6, 8],  # bits
             [512, 1024, 67],  # M
             [64, 128, 512, 1024],  # N
@@ -205,39 +205,64 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.eval(y)
 
     def test_small_matrix(self):
-        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
+        # We are going to need some way of doing this when we're loading a block scale / bias
+        # For 6 bit scales/biases we'll have to load them from uint16s I guess?
+        bits = 8
+        group_size = 16
+        # for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
+        for w_shape in [(32, 4096)]:
             with self.subTest(w_shape=w_shape):
                 w = mx.random.normal(shape=(w_shape))
-                w_q, scales, biases = mx.quantize(w)
-                w_hat = mx.dequantize(w_q, scales, biases)
+                w_q, scales, biases = mx.quantize(w, bits=bits, group_size=group_size)
+                w_hat = mx.dequantize(
+                    w_q, scales, biases, bits=bits, group_size=group_size
+                )
 
                 # Test qmv
-                x = mx.random.normal(shape=(3, 1, 256))
-                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                x = mx.random.normal(shape=(3, 1, 4096))
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    biases,
+                    transpose=True,
+                    bits=bits,
+                    group_size=group_size,
+                )
                 y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
                 # Test qmm_t
-                x = mx.random.normal(shape=(3, 10, 256))
-                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                x = mx.random.normal(shape=(3, 10, 4096))
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    biases,
+                    transpose=True,
+                    bits=bits,
+                    group_size=group_size,
+                )
+                print("y_q", y_q)
                 y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                print("y_hat", y_hat)
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
                 # Test qvm
-                x = mx.random.normal(shape=(3, 1, 8))
-                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-                y_hat = x @ w_hat
-                self.assertEqual(y_q.shape, y_hat.shape)
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # x = mx.random.normal(shape=(3, 1, 8))
+                # y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False, bits=bits, group_size=group_size)
+                # y_hat = x @ w_hat
+                # self.assertEqual(y_q.shape, y_hat.shape)
+                # self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
                 # Test qmm
-                x = mx.random.normal(shape=(3, 10, 8))
-                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-                y_hat = x @ w_hat
-                self.assertEqual(y_q.shape, y_hat.shape)
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # x = mx.random.normal(shape=(3, 10, 8))
+                # y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False, bits=bits, group_size=group_size)
+                # y_hat = x @ w_hat
+                # self.assertEqual(y_q.shape, y_hat.shape)
+                # self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))
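
As a quick sanity check of what this patch enables, here is a minimal sketch (assuming the patch is applied) that mirrors the mx.quantize / mx.dequantize / mx.quantized_matmul calls exercised in test_small_matrix above; the shapes and the 1e-3 tolerance are illustrative choices, not part of the patch:

import mlx.core as mx

# Quantize with the newly supported group size of 16 (8-bit elements).
w = mx.random.normal(shape=(32, 4096))
w_q, scales, biases = mx.quantize(w, group_size=16, bits=8)
w_hat = mx.dequantize(w_q, scales, biases, group_size=16, bits=8)

# Quantized matrix-vector product checked against the dequantized reference.
x = mx.random.normal(shape=(1, 4096))
y_q = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=16, bits=8
)
y_ref = x @ mx.swapaxes(w_hat, -1, -2)
print((y_q - y_ref).abs().max())  # expected to be small, on the order of 1e-3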