From 5597fa089c291f2db9a52e95dcc6decffbff6ec3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 25 Jul 2025 11:50:24 -0700 Subject: [PATCH] Fix qvm splitk (#2415) --- mlx/backend/metal/quantized.cpp | 13 +++++++++++-- python/tests/test_quantized.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 53f1c96f3..39c208c03 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -265,9 +265,15 @@ void qvm_split_k( MTL::Size group_dims = MTL::Size(bk, 2, 1); MTL::Size grid_dims = MTL::Size(M, N / bn, B); - int x_batch_ndims = x.ndim() - 2; auto x_shape = x.shape(); auto x_strides = x.strides(); + if (x_shape.size() == 1) { + x_shape.insert(x_shape.begin(), 1); + x_strides.insert(x_strides.begin(), 0); + } + + int x_ndim = x_shape.size(); + int x_batch_ndims = x_ndim - 2; int w_batch_ndims = w.ndim() - 2; auto w_shape = w.shape(); auto w_strides = w.strides(); @@ -278,7 +284,7 @@ void qvm_split_k( x_shape.insert(x_shape.end() - 2, split_k); x_shape.back() /= split_k; x_strides.insert(x_strides.end() - 2, split_D); - x_strides[x.ndim() - 1] = split_D; + x_strides[x_ndim - 1] = split_D; x_batch_ndims += 1; w_shape.insert(w_shape.end() - 2, split_k); @@ -291,6 +297,9 @@ void qvm_split_k( int final_block_size = K - (split_k - 1) * split_D; auto temp_shape = out.shape(); + if (temp_shape.size() == 1) { + temp_shape.insert(temp_shape.begin(), 1); + } temp_shape.insert(temp_shape.end() - 2, split_k); array intermediate(temp_shape, x.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 2c62c6307..de43ec26d 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 2e-3) + # Test with 1D vector + group_size = 32 + bits = 8 + N = 2048 + x = 1e-1 * mx.random.normal(shape=(N,), key=k1) + w = 1e-1 * mx.random.normal(shape=(N, N), key=k2) + w_q, scales, biases = mx.quantize(w, group_size, bits) + w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) + y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 2e-3) + def test_throw(self): x = mx.random.normal(shape=(10, 512)) w = mx.random.normal(shape=(32, 512))