diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index ecbb65cb4..51cd0cfb1 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):


 quant_matmul = {
+    "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
+    "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
+    "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
     "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
     "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
     "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -84,6 +87,15 @@ quant_matmul = {
     "quant_matmul_128_8": partial(
         _quant_matmul, transpose=False, group_size=128, bits=8
     ),
+    "quant_matmul_t_32_2": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=2
+    ),
+    "quant_matmul_t_32_4": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=4
+    ),
+    "quant_matmul_t_32_8": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=8
+    ),
     "quant_matmul_t_64_2": partial(
         _quant_matmul, transpose=True, group_size=64, bits=2
     ),
diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp
index 0ac2c2b61..8482468a0 100644
--- a/mlx/backend/common/quantized.cpp
+++ b/mlx/backend/common/quantized.cpp
@@ -119,6 +119,12 @@ void _qmm_dispatch_typed(
   switch (bits) {
     case 2: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -135,6 +141,12 @@
     }
     case 4: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -151,6 +163,12 @@
     }
     case 8: {
       switch (group_size) {
+        case 32:
+          if (transposed_w) {
+            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
+          }
         case 64:
           if (transposed_w) {
             return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 294dbab5c..5bf3142d4 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -142,10 +142,11 @@ template
 quantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  if (group_size != 64 && group_size != 128) {
+  if (group_size != 32 && group_size != 64 && group_size != 128) {
     std::ostringstream msg;
     msg << "[quantize] The requested group size " << group_size
-        << " is not supported. The supported group sizes are 64 and 128.";
+        << " is not supported. The supported group sizes are 32, 64 and 128.";
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index a23f26454..2682cbadc 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
             probs, targets, with_logits=False, reduction="none"
         )
         expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
-        print(losses_none, expected_none)
         self.assertTrue(mx.allclose(losses_none, expected_none))

         # Test with reduction 'mean'
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 23200a9fa..b068aa6ee 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -10,18 +10,19 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
-        for b in [2, 4, 8]:
-            w_q, scales, biases = mx.quantize(w, 64, b)
-            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
-            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-            eps = 1e-6
-            self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+        for gs in [32, 64, 128]:
+            for b in [2, 4, 8]:
+                w_q, scales, biases = mx.quantize(w, gs, b)
+                w_hat = mx.dequantize(w_q, scales, biases, gs, b)
+                errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+                eps = 1e-6
+                self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())

     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
             [512, 1024],  # N
@@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
@@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
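
For reference, a minimal sketch (not part of the diff) of exercising the new group_size=32 path from Python, using the same mx.quantize / mx.dequantize / mx.quantized_matmul calls that the updated benchmarks and tests rely on; the shapes and bit width below are illustrative only:

    import mlx.core as mx

    w = mx.random.normal(shape=(128, 512))

    # Quantize and dequantize with the newly supported group size of 32.
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=32, bits=4)

    # Quantized matmul against the packed, transposed weights:
    # (8, 512) x (512, 128) -> (8, 128).
    x = mx.random.normal(shape=(8, 512))
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=32, bits=4
    )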