mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias * higher tol
This commit is contained in:
parent
b36dd472bb
commit
7bb063bcb3
@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
std::vector<array> vjps;
|
||||
|
||||
// We rely on the fact that w is always 2D so transpose is simple
|
||||
std::optional<array> dsb = std::nullopt;
|
||||
for (auto arg : argnums) {
|
||||
// gradient wrt to x
|
||||
if (arg == 0) {
|
||||
@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
}
|
||||
|
||||
// gradient wrt to w_q, scales or biases
|
||||
else {
|
||||
else if (arg == 1) {
|
||||
throw std::runtime_error(
|
||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
|
||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
||||
} else {
|
||||
if (!dsb) {
|
||||
auto fc = flatten(cotangents[0], 0, -2, stream());
|
||||
auto fx = flatten(primals[0], 0, -2, stream());
|
||||
auto dw = transpose_
|
||||
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
|
||||
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
|
||||
dsb = unflatten(dw, -1, {-1, group_size_}, stream());
|
||||
}
|
||||
if (arg == 3) {
|
||||
// biases
|
||||
vjps.push_back(sum(*dsb, -1, false, stream()));
|
||||
} else {
|
||||
// scales
|
||||
auto s = stream();
|
||||
auto wq = dequantize(
|
||||
primals[1],
|
||||
ones_like(primals[2], stream()),
|
||||
zeros_like(primals[3], stream()),
|
||||
group_size_,
|
||||
bits_,
|
||||
stream());
|
||||
wq = unflatten(wq, -1, {-1, group_size_}, stream());
|
||||
vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
|
||||
}
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
|
@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||
|
||||
def test_vjp_scales_biases(self):
|
||||
mx.random.seed(0)
|
||||
x = mx.random.normal(shape=(2, 2, 512))
|
||||
w = mx.random.normal(shape=(512, 512))
|
||||
wq, s, b = mx.quantize(w, bits=4, group_size=64)
|
||||
|
||||
def mm(sb, x, wq):
|
||||
return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()
|
||||
|
||||
params = (s, b)
|
||||
dparams = mx.grad(mm)((s, b), x, wq)
|
||||
|
||||
eps = 8e-3
|
||||
# numerical grad check with a few indices
|
||||
indices = [(0, 0), (11, 4), (22, 7)]
|
||||
for idx in indices:
|
||||
for p in [0, 1]:
|
||||
params[p][idx] += eps
|
||||
out_up = mm(params, x, wq)
|
||||
params[p][idx] -= 2 * eps
|
||||
out_down = mm(params, x, wq)
|
||||
params[p][idx] += eps
|
||||
num_ds = (out_up - out_down) / (2 * eps)
|
||||
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user