fix quantized vjp for mxfp4 (#2555)

Awni Hannun
2025-08-29 10:06:15 -07:00
committed by GitHub
parent 9c68b50853
commit 8ce49cd39e
2 changed files with 43 additions and 8 deletions
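Background for the change: mxfp4 quantization has per-group scales but no per-group biases, so the bias argument to the quantized matmul primitives only exists in affine mode. A minimal sketch of that asymmetry using the Python API exercised by the new test (shapes and group size chosen to match it):

    import mlx.core as mx

    w = mx.random.normal(shape=(512, 512))

    # Affine mode quantizes to (weights, scales, biases).
    wq, scales, biases = mx.quantize(w, bits=4, group_size=32, mode="affine")

    # mxfp4 mode has no biases: quantize returns only (weights, scales),
    # which is why the vjp/jvp below pass the biases as an optional.
    wq, scales = mx.quantize(w, bits=4, group_size=32, mode="mxfp4")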

mlx/primitives.cpp

@@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp(
           cotangents[0],
           primals[1],
           primals[2],
-          primals[3],
+          mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
+                                            : std::nullopt,
           !transpose_,
           group_size_,
           bits_,
@@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp(
             "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
       } else {
         if (mode_ == QuantizationMode::Mxfp4) {
-          throw std::runtime_error(
+          throw std::invalid_argument(
               "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
         }
         if (!dsb) {
@@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp(
           tangents[0],
           primals[1],
           primals[2],
-          primals[3],
+          mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
+                                            : std::nullopt,
           transpose_,
           group_size_,
           bits_,
@@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp(
   auto& x = primals[0];
   auto& w = primals[1];
   auto& scales = primals[2];
-  auto& biases = primals[3];
-  auto& lhs_indices = primals[4];
-  auto& rhs_indices = primals[5];
+  auto& lhs_indices = primals[primals.size() - 2];
+  auto& rhs_indices = primals[primals.size() - 1];
+  auto biases = (mode_ == QuantizationMode::Affine)
+      ? std::optional<array>(primals[3])
+      : std::nullopt;

   int M = cotan.shape(-2);
   int N = cotan.shape(-1);
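Because the biases are absent in mxfp4 mode, GatherQMM's primal list no longer has a fixed length: it is [x, w, scales, biases, lhs_indices, rhs_indices] for affine but [x, w, scales, lhs_indices, rhs_indices] for mxfp4, so the index arrays are now taken from the end of primals rather than at the fixed positions 4 and 5.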
@@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp(
             "[GatherQMM::vjp] no gradient wrt the quantized weights.");
       } else {
         if (mode_ == QuantizationMode::Mxfp4) {
-          throw std::runtime_error(
+          throw std::invalid_argument(
               "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
         }
@@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp(
             dequantize(
                 w,
                 ones_like(scales, stream()),
-                zeros_like(biases, stream()),
+                zeros_like(*biases, stream()),
                 group_size_,
                 bits_,
                 quantization_mode_to_string(mode_),
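The effect on the Python API, as a sketch: affine-mode gradients with respect to scales and biases still work, while asking for a scale gradient in mxfp4 mode now surfaces as a ValueError (the binding of std::invalid_argument) rather than a RuntimeError, which is what the new test below checks.

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 512))
    w = mx.random.normal(shape=(512, 512))

    # Affine mode: the vjp wrt the scales is defined and differentiable.
    wq, scales, biases = mx.quantize(w, bits=4, group_size=32, mode="affine")

    def mm(scales):
        return mx.quantized_matmul(
            x, wq, scales, biases, bits=4, group_size=32, mode="affine"
        ).sum()

    ds = mx.grad(mm)(scales)  # fine in affine mode; raises in mxfp4 mode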

python/tests/test_quantized.py

@@ -842,6 +842,37 @@ class TestQuantized(mlx_tests.MLXTestCase):
             num_ds = (out_up - out_down) / (2 * eps)
             self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)

+    def test_mxfp4_vjp_scales_throws(self):
+        mx.random.seed(0)
+        x = mx.random.normal(shape=(2, 512))
+        w = mx.random.normal(shape=(512, 512))
+        wq, s = mx.quantize(w, bits=4, group_size=32, mode="mxfp4")
+
+        def mm(s, x, wq):
+            return mx.quantized_matmul(
+                x, wq, s, bits=4, group_size=32, mode="mxfp4"
+            ).sum()
+
+        # Should raise
+        with self.assertRaises(ValueError):
+            ds = mx.grad(mm)(s, x, wq)
+
+        rhs_indices = mx.array(0)
+
+        with self.assertRaises(ValueError):
+
+            def gmm(s, x, wq):
+                return mx.gather_qmm(
+                    x,
+                    wq,
+                    s,
+                    rhs_indices=rhs_indices,
+                    bits=4,
+                    group_size=32,
+                    mode="mxfp4",
+                ).sum()
+
+            ds = mx.grad(gmm)(s, x, wq)
+

 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
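To run just the new test (assuming the file lives at python/tests/test_quantized.py as in the upstream tree, and that MLXTestRunner forwards unittest's standard command-line arguments):

    python python/tests/test_quantized.py TestQuantized.test_mxfp4_vjp_scales_throws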