mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
fix quantized vjp for mxfp4 (#2555)
This commit is contained in:
@@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp(
|
|||||||
cotangents[0],
|
cotangents[0],
|
||||||
primals[1],
|
primals[1],
|
||||||
primals[2],
|
primals[2],
|
||||||
primals[3],
|
mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
|
||||||
|
: std::nullopt,
|
||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
@@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
|||||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
||||||
} else {
|
} else {
|
||||||
if (mode_ == QuantizationMode::Mxfp4) {
|
if (mode_ == QuantizationMode::Mxfp4) {
|
||||||
throw std::runtime_error(
|
throw std::invalid_argument(
|
||||||
"[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
|
"[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
|
||||||
}
|
}
|
||||||
if (!dsb) {
|
if (!dsb) {
|
||||||
@@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp(
|
|||||||
tangents[0],
|
tangents[0],
|
||||||
primals[1],
|
primals[1],
|
||||||
primals[2],
|
primals[2],
|
||||||
primals[3],
|
mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
|
||||||
|
: std::nullopt,
|
||||||
transpose_,
|
transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
@@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
auto& x = primals[0];
|
auto& x = primals[0];
|
||||||
auto& w = primals[1];
|
auto& w = primals[1];
|
||||||
auto& scales = primals[2];
|
auto& scales = primals[2];
|
||||||
auto& biases = primals[3];
|
auto& lhs_indices = primals[primals.size() - 2];
|
||||||
auto& lhs_indices = primals[4];
|
auto& rhs_indices = primals[primals.size() - 1];
|
||||||
auto& rhs_indices = primals[5];
|
auto biases = (mode_ == QuantizationMode::Affine)
|
||||||
|
? std::optional<array>(primals[3])
|
||||||
|
: std::nullopt;
|
||||||
|
|
||||||
int M = cotan.shape(-2);
|
int M = cotan.shape(-2);
|
||||||
int N = cotan.shape(-1);
|
int N = cotan.shape(-1);
|
||||||
@@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
"[GatherQMM::vjp] no gradient wrt the quantized weights.");
|
"[GatherQMM::vjp] no gradient wrt the quantized weights.");
|
||||||
} else {
|
} else {
|
||||||
if (mode_ == QuantizationMode::Mxfp4) {
|
if (mode_ == QuantizationMode::Mxfp4) {
|
||||||
throw std::runtime_error(
|
throw std::invalid_argument(
|
||||||
"[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
|
"[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
dequantize(
|
dequantize(
|
||||||
w,
|
w,
|
||||||
ones_like(scales, stream()),
|
ones_like(scales, stream()),
|
||||||
zeros_like(biases, stream()),
|
zeros_like(*biases, stream()),
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
quantization_mode_to_string(mode_),
|
quantization_mode_to_string(mode_),
|
||||||
|
@@ -842,6 +842,37 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
num_ds = (out_up - out_down) / (2 * eps)
|
num_ds = (out_up - out_down) / (2 * eps)
|
||||||
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
|
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
|
||||||
|
|
||||||
|
def test_mxfp4_vjp_scales_throws(self):
|
||||||
|
mx.random.seed(0)
|
||||||
|
x = mx.random.normal(shape=(2, 512))
|
||||||
|
w = mx.random.normal(shape=(512, 512))
|
||||||
|
wq, s = mx.quantize(w, bits=4, group_size=32, mode="mxfp4")
|
||||||
|
|
||||||
|
def mm(s, x, wq):
|
||||||
|
return mx.quantized_matmul(
|
||||||
|
x, wq, s, bits=4, group_size=32, mode="mxfp4"
|
||||||
|
).sum()
|
||||||
|
|
||||||
|
# Should raise
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ds = mx.grad(mm)(s, x, wq)
|
||||||
|
|
||||||
|
rhs_indices = mx.array(0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
|
||||||
|
def gmm(s, x, wq):
|
||||||
|
return mx.gather_qmm(
|
||||||
|
x,
|
||||||
|
wq,
|
||||||
|
s,
|
||||||
|
rhs_indices=rhs_indices,
|
||||||
|
bits=4,
|
||||||
|
group_size=32,
|
||||||
|
mode="mxfp4",
|
||||||
|
).sum()
|
||||||
|
|
||||||
|
ds = mx.grad(gmm)(s, x, wq)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user