Enable vjp for quantized scale and bias (#2129)

* Enable vjp for quantized scale and bias

* higher tol
Awni Hannun, 2025-04-29 13:03:09 -07:00 (committed by GitHub)
parent b36dd472bb
commit 7bb063bcb3
2 changed files with 53 additions and 2 deletions
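
For context, a minimal sketch of the usage this change enables, based on the mx.quantize / mx.quantized_matmul / mx.grad calls exercised in the new test below (illustrative only, not part of the commit):

import mlx.core as mx

x = mx.random.normal(shape=(4, 512))
w = mx.random.normal(shape=(512, 512))
wq, scales, biases = mx.quantize(w, bits=4, group_size=64)

def loss(s, b):
    # Quantized matmul against fixed x and wq; sum to get a scalar loss.
    y = mx.quantized_matmul(x, wq, s, b, bits=4, group_size=64)
    return y.sum()

# Gradients with respect to the quantization scales and biases now work;
# previously only the gradient with respect to x was supported.
dscales, dbiases = mx.grad(loss, argnums=[0, 1])(scales, biases)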


@@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
   std::vector<array> vjps;

   // We rely on the fact that w is always 2D so transpose is simple
+  std::optional<array> dsb = std::nullopt;
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {
@@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
     }

     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto fc = flatten(cotangents[0], 0, -2, stream());
+        auto fx = flatten(primals[0], 0, -2, stream());
+        auto dw = transpose_
+            ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
+            : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
+        dsb = unflatten(dw, -1, {-1, group_size_}, stream());
+      }
+      if (arg == 3) {
+        // biases
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        // scales
+        auto s = stream();
+        auto wq = dequantize(
+            primals[1],
+            ones_like(primals[2], stream()),
+            zeros_like(primals[3], stream()),
+            group_size_,
+            bits_,
+            stream());
+        wq = unflatten(wq, -1, {-1, group_size_}, stream());
+        vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
+      }
     }
   }
   return vjps;

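For reference, my gloss on the math in the new branch (not part of the commit): within each group, dequantization is elementwise w_hat = q * s + b, so a weight cotangent dW contributes dW to the bias gradient and dW * q to the scale gradient, summed over the group; dequantizing with unit scales and zero biases recovers the codes q. A rough Python mirror of the computation, assuming the transpose_ == true path and the shapes from the test below:

import mlx.core as mx

x = mx.random.normal(shape=(2, 2, 512))
w = mx.random.normal(shape=(512, 512))
wq, s, b = mx.quantize(w, bits=4, group_size=64)

# Cotangent of a .sum() loss: all ones with the output's shape.
cot = mx.ones_like(mx.quantized_matmul(x, wq, s, b, bits=4, group_size=64))

fc = mx.flatten(cot, 0, -2)                  # (batch, out)
fx = mx.flatten(x, 0, -2)                    # (batch, in)
dw = mx.matmul(mx.swapaxes(fc, -1, -2), fx)  # (out, in): cotangent of the dequantized weights
dsb = dw.reshape(512, -1, 64)                # group the reduction axis, as unflatten does in C++

# Integer codes: dequantize with scales of one and biases of zero.
q = mx.dequantize(wq, mx.ones_like(s), mx.zeros_like(b), group_size=64, bits=4)
q = q.reshape(512, -1, 64)

db = dsb.sum(-1)        # matches the bias vjp: sum(*dsb, -1)
ds = (dsb * q).sum(-1)  # matches the scale vjp: sum(multiply(*dsb, wq), -1)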

@@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))

+    def test_vjp_scales_biases(self):
+        mx.random.seed(0)
+        x = mx.random.normal(shape=(2, 2, 512))
+        w = mx.random.normal(shape=(512, 512))
+        wq, s, b = mx.quantize(w, bits=4, group_size=64)
+
+        def mm(sb, x, wq):
+            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()
+
+        params = (s, b)
+        dparams = mx.grad(mm)((s, b), x, wq)
+
+        eps = 8e-3
+        # numerical grad check with a few indices
+        indices = [(0, 0), (11, 4), (22, 7)]
+        for idx in indices:
+            for p in [0, 1]:
+                params[p][idx] += eps
+                out_up = mm(params, x, wq)
+                params[p][idx] -= 2 * eps
+                out_down = mm(params, x, wq)
+                params[p][idx] += eps
+                num_ds = (out_up - out_down) / (2 * eps)
+                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
+

 if __name__ == "__main__":
     unittest.main()
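
A note on the check above (my gloss, not part of the commit): each analytic gradient entry is compared against a central difference, num_ds ≈ (L(p + eps) - L(p - eps)) / (2 * eps), whose truncation error is O(eps^2); the relatively loose eps = 8e-3 and delta = 2e-2 (the "higher tol" in the commit message) presumably absorb the numerical noise of the 4-bit forward pass and the large summed loss.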