From 7bb063bcb3000cf9f57078c114fd385577074c57 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 29 Apr 2025 13:03:09 -0700
Subject: [PATCH] Enable vjp for quantized scale and bias (#2129)

* Enable vjp for quantized scale and bias

* higher tol
---
 mlx/primitives.cpp             | 30 ++++++++++++++++++++++++++++--
 python/tests/test_quantized.py | 25 +++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 3d36f0881..7288a4885 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
   std::vector<array> vjps;
 
   // We rely on the fact that w is always 2D so transpose is simple
+  std::optional<array> dsb = std::nullopt;
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {
@@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
     }
 
     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto fc = flatten(cotangents[0], 0, -2, stream());
+        auto fx = flatten(primals[0], 0, -2, stream());
+        auto dw = transpose_
+            ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
+            : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
+        dsb = unflatten(dw, -1, {-1, group_size_}, stream());
+      }
+      if (arg == 3) {
+        // biases
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        // scales
+        auto s = stream();
+        auto wq = dequantize(
+            primals[1],
+            ones_like(primals[2], stream()),
+            zeros_like(primals[3], stream()),
+            group_size_,
+            bits_,
+            stream());
+        wq = unflatten(wq, -1, {-1, group_size_}, stream());
+        vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
+      }
     }
   }
   return vjps;
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index eeefcd94f..60ab421c6 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
 
+    def test_vjp_scales_biases(self):
+        mx.random.seed(0)
+        x = mx.random.normal(shape=(2, 2, 512))
+        w = mx.random.normal(shape=(512, 512))
+        wq, s, b = mx.quantize(w, bits=4, group_size=64)
+
+        def mm(sb, x, wq):
+            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()
+
+        params = (s, b)
+        dparams = mx.grad(mm)((s, b), x, wq)
+
+        eps = 8e-3
+        # numerical grad check with a few indices
+        indices = [(0, 0), (11, 4), (22, 7)]
+        for idx in indices:
+            for p in [0, 1]:
+                params[p][idx] += eps
+                out_up = mm(params, x, wq)
+                params[p][idx] -= 2 * eps
+                out_down = mm(params, x, wq)
+                params[p][idx] += eps
+                num_ds = (out_up - out_down) / (2 * eps)
+                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
+
 
 if __name__ == "__main__":
     unittest.main()
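
Usage sketch of what this patch enables, separate from the diff above: mx.grad can
differentiate through mx.quantized_matmul with respect to the quantization scales
and biases, while a gradient with respect to the packed weights still raises the
error kept in QuantizedMatmul::vjp. The function name `loss` is illustrative; the
calls mirror the added test_vjp_scales_biases.

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 2, 512))
    w = mx.random.normal(shape=(512, 512))
    wq, scales, biases = mx.quantize(w, bits=4, group_size=64)

    def loss(scales, biases):
        # Scalar objective so the gradients have the same shapes as scales/biases.
        return mx.quantized_matmul(
            x, wq, scales, biases, bits=4, group_size=64
        ).sum()

    # Gradients w.r.t. both quantization parameters (positional args 0 and 1).
    dscales, dbiases = mx.grad(loss, argnums=(0, 1))(scales, biases)
    print(dscales.shape, dbiases.shape)  # each matches scales / biases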