Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
fix quantized vjp for mxfp4 (#2555)
This commit is contained in:
@@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp(
           cotangents[0],
           primals[1],
           primals[2],
-          primals[3],
+          mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
+                                            : std::nullopt,
           !transpose_,
           group_size_,
           bits_,
@@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp(
           "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
     } else {
       if (mode_ == QuantizationMode::Mxfp4) {
-        throw std::runtime_error(
+        throw std::invalid_argument(
             "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
       }
       if (!dsb) {
@@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp(
       tangents[0],
       primals[1],
       primals[2],
-      primals[3],
+      mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
+                                        : std::nullopt,
       transpose_,
       group_size_,
       bits_,
@@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp(
   auto& x = primals[0];
   auto& w = primals[1];
   auto& scales = primals[2];
-  auto& biases = primals[3];
-  auto& lhs_indices = primals[4];
-  auto& rhs_indices = primals[5];
+  auto& lhs_indices = primals[primals.size() - 2];
+  auto& rhs_indices = primals[primals.size() - 1];
+  auto biases = (mode_ == QuantizationMode::Affine)
+      ? std::optional<array>(primals[3])
+      : std::nullopt;
 
   int M = cotan.shape(-2);
   int N = cotan.shape(-1);
@@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp(
           "[GatherQMM::vjp] no gradient wrt the quantized weights.");
     } else {
       if (mode_ == QuantizationMode::Mxfp4) {
-        throw std::runtime_error(
+        throw std::invalid_argument(
             "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
       }
 
@@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp(
               dequantize(
                   w,
                   ones_like(scales, stream()),
-                  zeros_like(biases, stream()),
+                  zeros_like(*biases, stream()),
                   group_size_,
                   bits_,
                   quantization_mode_to_string(mode_),
Reference in New Issue
Block a user