diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 977e5c62a..1406fd46f 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp( cotangents[0], primals[1], primals[2], - primals[3], + mode_ == QuantizationMode::Affine ? std::optional(primals[3]) + : std::nullopt, !transpose_, group_size_, bits_, @@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { if (mode_ == QuantizationMode::Mxfp4) { - throw std::runtime_error( + throw std::invalid_argument( "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization."); } if (!dsb) { @@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp( tangents[0], primals[1], primals[2], - primals[3], + mode_ == QuantizationMode::Affine ? std::optional(primals[3]) + : std::nullopt, transpose_, group_size_, bits_, @@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp( auto& x = primals[0]; auto& w = primals[1]; auto& scales = primals[2]; - auto& biases = primals[3]; - auto& lhs_indices = primals[4]; - auto& rhs_indices = primals[5]; + auto& lhs_indices = primals[primals.size() - 2]; + auto& rhs_indices = primals[primals.size() - 1]; + auto biases = (mode_ == QuantizationMode::Affine) + ? 
std::optional(primals[3]) + : std::nullopt; int M = cotan.shape(-2); int N = cotan.shape(-1); @@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp( "[GatherQMM::vjp] no gradient wrt the quantized weights."); } else { if (mode_ == QuantizationMode::Mxfp4) { - throw std::runtime_error( + throw std::invalid_argument( "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization."); } @@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp( dequantize( w, ones_like(scales, stream()), - zeros_like(biases, stream()), + zeros_like(*biases, stream()), group_size_, bits_, quantization_mode_to_string(mode_), diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f22c0cae3..3a195ef54 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -842,6 +842,37 @@ class TestQuantized(mlx_tests.MLXTestCase): num_ds = (out_up - out_down) / (2 * eps) self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2) + def test_mxfp4_vjp_scales_throws(self): + mx.random.seed(0) + x = mx.random.normal(shape=(2, 512)) + w = mx.random.normal(shape=(512, 512)) + wq, s = mx.quantize(w, bits=4, group_size=32, mode="mxfp4") + + def mm(s, x, wq): + return mx.quantized_matmul( + x, wq, s, bits=4, group_size=32, mode="mxfp4" + ).sum() + + # Should raise + with self.assertRaises(ValueError): + ds = mx.grad(mm)(s, x, wq) + + rhs_indices = mx.array(0) + with self.assertRaises(ValueError): + + def gmm(s, x, wq): + return mx.gather_qmm( + x, + wq, + s, + rhs_indices=rhs_indices, + bits=4, + group_size=32, + mode="mxfp4", + ).sum() + + ds = mx.grad(gmm)(s, x, wq) + if __name__ == "__main__": mlx_tests.MLXTestRunner()