Mirror of https://github.com/ml-explore/mlx.git
Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias
* higher tol
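What the change enables, sketched from the Python API (the shapes, variable names, and quantization settings below are illustrative, not taken from the commit): gradients can now flow into the per-group scales and biases of a quantized matmul, while the packed weights themselves remain non-differentiable.

import mlx.core as mx

# Illustrative shapes; group_size must divide the last dimension of w.
x = mx.random.normal((2, 512))
w = mx.random.normal((256, 512))

# Quantize the weight into packed codes plus per-group scales and biases.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

def loss(scales, biases):
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4
    )
    return y.sum()

# With this change, differentiating with respect to scales and biases works;
# asking for the gradient with respect to the packed weights still raises.
d_scales, d_biases = mx.grad(loss, argnums=(0, 1))(scales, biases)
print(d_scales.shape, d_biases.shape)  # same shapes as scales and biases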
@@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
   std::vector<array> vjps;
 
   // We rely on the fact that w is always 2D so transpose is simple
+  std::optional<array> dsb = std::nullopt;
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {
@@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
     }
 
     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto fc = flatten(cotangents[0], 0, -2, stream());
+        auto fx = flatten(primals[0], 0, -2, stream());
+        auto dw = transpose_
+            ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
+            : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
+        dsb = unflatten(dw, -1, {-1, group_size_}, stream());
+      }
+      if (arg == 3) {
+        // biases
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        // scales
+        auto s = stream();
+        auto wq = dequantize(
+            primals[1],
+            ones_like(primals[2], stream()),
+            zeros_like(primals[3], stream()),
+            group_size_,
+            bits_,
+            stream());
+        wq = unflatten(wq, -1, {-1, group_size_}, stream());
+        vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
+      }
     }
   }
   return vjps;
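Why group-wise sums are the right cotangents (notation here is mine, not from the commit): within a quantization group g, dequantization is affine in the scales and biases. Writing \hat q_{ij} for the weight dequantized with unit scales and zero biases (the wq above) and dW for the cotangent contracted with x (the matmul stored in dsb), differentiating the affine map gives

w_{ij} = s_{i,g}\,\hat q_{ij} + b_{i,g}
\;\Longrightarrow\;
\frac{\partial L}{\partial s_{i,g}} = \sum_{j \in g} (dW)_{ij}\,\hat q_{ij},
\qquad
\frac{\partial L}{\partial b_{i,g}} = \sum_{j \in g} (dW)_{ij}

which is exactly sum(multiply(*dsb, wq), -1) for the scales and sum(*dsb, -1) for the biases, once unflatten has split the last axis into chunks of group_size_.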