From 039da779d141cb92a673f0081519e13eb4b50d82 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 4 Apr 2024 11:52:12 -0700
Subject: [PATCH] No quant reshape (#957)

* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
---
 mlx/ops.cpp                    | 22 ++++------------------
 python/tests/test_quantized.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 1c9e930fc..a06bb4637 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3057,7 +3057,6 @@ array quantized_matmul(
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
   array x = in_x;
-
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3074,12 +3073,6 @@ array quantized_matmul(
   // Keep x's batch dimensions to reshape it back after the matmul
   auto original_shape = x.shape();
   int x_inner_dims = original_shape.back();
-  original_shape.pop_back();
-
-  // Reshape x into a matrix if it isn't already one
-  if (x.ndim() != 2) {
-    x = reshape(x, {-1, x_inner_dims}, s);
-  }
 
   if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
     std::ostringstream msg;
@@ -3122,9 +3115,10 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
-
-  auto out = array(
-      {x.shape(0), w_outer_dims},
+  std::vector<array> inputs;
+  original_shape.back() = w_outer_dims;
+  return array(
+      std::move(original_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
@@ -3132,14 +3126,6 @@ array quantized_matmul(
       w,
       astype(scales, dtype, s),
       astype(biases, dtype, s)});
-
-  // If needed reshape x to the original batch shape
-  if (original_shape.size() != 1) {
-    original_shape.push_back(w_outer_dims);
-    out = reshape(out, std::move(original_shape), s);
-  }
-
-  return out;
 }
 
 std::tuple<array, array, array> quantize(
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index f623a0dca..57e369aa3 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_qmm_vjp(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+
+        bits = 8
+        group_size = 64
+        M = 64
+        N = 1024
+        K = 512
+
+        x = mx.random.normal(shape=(2, M, K), key=k1)
+        c = mx.ones(shape=(2, M, N))
+
+        transposes = [True, False]
+        for transposed in transposes:
+            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+            w_q, scales, biases = mx.quantize(w, group_size, bits)
+
+            def fn(x):
+                return mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+
+            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))
+
+            expected_out = mx.quantized_matmul(
+                c, w_q, scales, biases, not transposed, group_size, bits
+            )
+            self.assertTrue(mx.allclose(vjp_out[0], expected_out))
+
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
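
Reviewer note, not part of the patch: a minimal usage sketch of what this
change enables, assuming an mlx build that includes it. quantized_matmul now
builds the batched output shape directly, replacing only the last axis of x's
shape with the weight's outer dimension, rather than flattening x to 2-D
before the matmul and reshaping the result back afterwards:

    import mlx.core as mx

    group_size, bits = 64, 4
    K, N = 512, 1024

    # Quantize a (K, N) weight; transpose=False multiplies by w as stored.
    w = mx.random.normal(shape=(K, N))
    w_q, scales, biases = mx.quantize(w, group_size, bits)

    # Batched input: no manual 2-D reshape is needed on either side now.
    x = mx.random.normal(shape=(2, 64, K))
    y = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
    print(y.shape)  # (2, 64, 1024): x's batch shape with the last axis -> N

The added test_qmm_vjp exercises the backward pass that relies on this shape
handling: the vjp of the matmul with respect to x multiplies the cotangent by
the same quantized weight with the transpose flag flipped, which the test
checks against a direct quantized_matmul call with `not transposed`.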