Enable vjp for quantized scale and bias (#2129)

* Enable vjp for quantized scale and bias

* higher tol
Author: Awni Hannun
Date: 2025-04-29 13:03:09 -07:00
Committed by: GitHub
Parent: b36dd472bb
Commit: 7bb063bcb3
2 changed files with 53 additions and 2 deletions
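
What this change enables, sketched against the public mlx::core C++ API (this example is not part of the commit; the shapes, group size, bit width, and main() scaffold are illustrative assumptions):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Full-precision weights and an input batch (shapes illustrative).
  array w = random::normal({64, 64});
  array x = random::normal({8, 64});

  // Pack w into quantized weights plus per-group scales and biases.
  int group_size = 64;
  int bits = 4;
  auto quants = quantize(w, group_size, bits);
  array w_q = std::get<0>(quants);
  array scales = std::get<1>(quants);
  array biases = std::get<2>(quants);

  // A scalar loss as a function of (x, scales, biases). w_q is captured,
  // since the VJP wrt the packed quantized weights is still unsupported.
  auto fun = [&](const std::vector<array>& inputs) -> std::vector<array> {
    auto out = quantized_matmul(
        inputs[0], w_q, inputs[1], inputs[2], true, group_size, bits);
    return {sum(out)};
  };

  // Before this commit, requesting cotangents wrt scales or biases threw.
  // Now grads[1] is dL/d(scales) and grads[2] is dL/d(biases).
  auto [outs, grads] = vjp(fun, {x, scales, biases}, {array(1.0f)});
  eval(grads);
  return 0;
}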


@@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
   std::vector<array> vjps;
 
   // We rely on the fact that w is always 2D so transpose is simple
+  std::optional<array> dsb = std::nullopt;
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {
@@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
     }
     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto fc = flatten(cotangents[0], 0, -2, stream());
+        auto fx = flatten(primals[0], 0, -2, stream());
+        auto dw = transpose_
+            ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
+            : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
+        dsb = unflatten(dw, -1, {-1, group_size_}, stream());
+      }
+      if (arg == 3) {
+        // biases
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        // scales
+        auto s = stream();
+        auto wq = dequantize(
+            primals[1],
+            ones_like(primals[2], stream()),
+            zeros_like(primals[3], stream()),
+            group_size_,
+            bits_,
+            stream());
+        wq = unflatten(wq, -1, {-1, group_size_}, stream());
+        vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
+      }
     }
   }
   return vjps;
 }
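
For reference, the new branch is the chain rule through MLX's affine dequantization. Within each group g of group_size elements, a weight is reconstructed as \hat{w}_i = s_g q_i + b_g, where q_i is the stored quantized value, s_g the per-group scale, and b_g the per-group bias. With \partial L / \partial \hat{w} given by the matmul that fills dsb, the per-group gradients are

\frac{\partial L}{\partial b_g} = \sum_{i \in g} \frac{\partial L}{\partial \hat{w}_i},
\qquad
\frac{\partial L}{\partial s_g} = \sum_{i \in g} q_i \, \frac{\partial L}{\partial \hat{w}_i}.

This matches the code above: dsb holds \partial L / \partial \hat{w} unflattened into groups of group_size; the bias VJP is its sum over the group axis, and the scale VJP first multiplies by q, which is recovered by dequantizing primals[1] with unit scales (ones_like) and zero biases (zeros_like).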