Add quantize/dequantize slow path for mxfp8 and nvfp4

This commit is contained in:
Awni Hannun
2025-10-20 16:53:03 -07:00
parent 460691a0e8
commit c00ccf7404
5 changed files with 197 additions and 78 deletions

View File

@@ -4018,7 +4018,8 @@ array conv_general(
}
void validate_mode(std::string_view tag, const std::string& mode) {
if (mode != "affine" && mode != "mxfp4") {
if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
mode != "nvfp4") {
std::ostringstream msg;
msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
throw std::invalid_argument(msg.str());
@@ -4249,51 +4250,67 @@ std::vector<array> quantize(
if (mode == "affine") {
return affine_quantize(w, group_size, bits, s);
} else {
if (group_size != 32) {
int expected_gs = (mode[0] == 'm') ? 32 : 16;
int expected_bits = (mode.back() == '8') ? 8 : 4;
if (group_size != expected_gs) {
std::ostringstream msg;
msg << "[quantize] mxfp4 quantization requires group size 32 "
<< "but got " << group_size << ".";
msg << "[quantize] " << mode << " quantization requires group size "
<< expected_gs << " but got " << group_size << ".";
throw std::invalid_argument(msg.str());
}
if (bits != 4) {
if (bits != expected_bits) {
std::ostringstream msg;
msg << "[quantize] mxfp4 quantization requires bits to be 4 "
<< "but got " << bits << ".";
msg << "[quantize] " << mode << " quantization requires bits to be "
<< expected_bits << " but got " << bits << ".";
throw std::invalid_argument(msg.str());
}
auto lut = array({
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f,
});
lut = astype(lut, w.dtype(), s);
float maxval = (bits == 4) ? 6.0f : 448.0f;
auto new_shape = w.shape();
new_shape.back() = -1;
auto wq = reshape(w, {-1, group_size}, s);
auto scales =
divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s);
scales = astype(log2(scales, s), int32, s);
wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
scales = astype(add(scales, array(127, int32), s), uint8, s);
wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
wq = reshape(wq, {-1, group_size / 8, 8}, s);
wq = sum(multiply(wq, shifts, s), -1, false, s);
divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
if (group_size == 16) {
// convert to e4m3
scales = to_fp8(scales, s);
wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
} else {
// convert to e8m0
auto z = array(0, scales.dtype());
scales =
where(equal(scales, z, s), z, astype(log2(scales, s), int32, s), s);
wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
scales = astype(add(scales, array(127, int32), s), uint8, s);
}
if (bits == 4) {
auto lut = array({
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f,
});
lut = astype(lut, w.dtype(), s);
wq = argmin(
abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
wq = reshape(wq, {-1, 4, 8}, s);
wq = sum(multiply(wq, shifts, s), -1, false, s);
} else {
wq = view(to_fp8(wq, s), uint32, s);
}
wq = reshape(wq, new_shape, s);
scales = reshape(scales, new_shape, s);
return {std::move(wq), std::move(scales)};
@@ -4404,6 +4421,7 @@ array dequantize(
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
std::optional<Dtype> dtype /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
validate_mode_with_type("dequantize", scales, biases, mode);
if (bits <= 0) {
@@ -4422,24 +4440,30 @@ array dequantize(
}
if (mode == "affine") {
return affine_dequantize(w, scales, *biases, group_size, bits, s);
auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
if (dtype) {
out = astype(out, *dtype, s);
}
return out;
} else {
if (group_size != 32) {
int expected_gs = (mode[0] == 'm') ? 32 : 16;
int expected_bits = (mode.back() == '8') ? 8 : 4;
if (group_size != expected_gs) {
std::ostringstream msg;
msg << "[dequantize] mxfp4 quantization requires group size 32 "
<< "but got " << group_size << ".";
msg << "[quantize] " << mode << " quantization requires group size "
<< expected_gs << " but got " << group_size << ".";
throw std::invalid_argument(msg.str());
}
if (bits != 4) {
if (bits != expected_bits) {
std::ostringstream msg;
msg << "[dequantize] mxfp4 quantization requires bits to be 4 "
<< "but got " << bits << ".";
msg << "[quantize] " << mode << " quantization requires bits to be "
<< expected_bits << " but got " << bits << ".";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
msg << "[quantize] The matrix to be dequantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
@@ -4470,39 +4494,48 @@ array dequantize(
throw std::invalid_argument(msg.str());
}
auto dtype = bfloat16;
auto lut = array(
{
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f,
},
dtype);
auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s);
auto idx_lo = bitwise_and(what, array(0x0F, int8), s);
auto idx_hi = right_shift(what, array(4, int8), s);
auto lo = gather(lut, idx_lo, 0, {1}, s);
auto hi = gather(lut, idx_hi, 0, {1}, s);
what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s);
auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s);
exponent = reshape(exponent, {-1, 1}, s);
return reshape(
multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s);
auto out_type = dtype.has_value() ? *dtype : bfloat16;
auto out = w;
if (bits == 4) {
auto lut = array(
{
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f,
},
out_type);
out = view(reshape(out, {-1, 4}, s), int8, s);
auto idx_lo = bitwise_and(out, array(0x0F, int8), s);
auto idx_hi = right_shift(out, array(4, int8), s);
auto lo = gather(lut, idx_lo, 0, {1}, s);
auto hi = gather(lut, idx_hi, 0, {1}, s);
out = concatenate({lo, hi}, -1, s);
} else {
out = from_fp8(view(out, uint8, s), out_type, s);
}
out = reshape(out, {-1, group_size}, s);
auto flat_scales = reshape(scales, {-1, 1}, s);
if (group_size == 16) {
flat_scales = from_fp8(flat_scales, out_type, s);
} else {
flat_scales =
subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
flat_scales = power(array(2.0f, out_type), flat_scales, s);
}
return reshape(multiply(out, flat_scales, s), wshape, s);
}
}

View File

@@ -1400,6 +1400,7 @@ array dequantize(
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
std::optional<Dtype> dtype = std::nullopt,
StreamOrDevice s = {});
/** Convert an E4M3 float8 to the given floating point dtype. */

View File

@@ -3404,6 +3404,7 @@ std::vector<array> QuantizedMatmul::vjp(
group_size_,
bits_,
quantization_mode_to_string(mode_),
std::nullopt,
stream());
wq = unflatten(wq, -1, {-1, group_size_}, stream());
vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3558,6 +3559,7 @@ std::vector<array> GatherQMM::vjp(
group_size_,
bits_,
quantization_mode_to_string(mode_),
std::nullopt,
stream()),
-1,
{-1, group_size_},

View File

@@ -4307,10 +4307,11 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using quantization parameters.
@@ -4323,6 +4324,10 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
dtype (Dtype, optional): The data type of the dequantized output. If
``None`` the return type is inferred from the scales and biases
when possible and otherwise defaults to ``bfloat16``.
Default: ``None``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:

View File

@@ -77,6 +77,84 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
self.assertTrue(mx.all(w_hat == 0))
def test_mxfp8_quantize_dequantize(self):
w = 2 * mx.random.uniform(shape=(512, 32)) - 1
w = w.astype(mx.bfloat16)
# Invalid bits / group size
with self.assertRaises(ValueError):
mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
with self.assertRaises(ValueError):
mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
self.assertTrue(mx.all(w_hat == 0))
def test_nvfp4_quantize_dequantize(self):
lut = mx.array(
[
+0.0,
+0.5,
+1.0,
+1.5,
+2.0,
+3.0,
+4.0,
+6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
)
w = lut[mx.random.randint(0, 16, shape=(128, 512))]
w = w.reshape(-1, 16)
w[:, 0] = 6
w = (w + 3e-6).astype(mx.bfloat16)
# Invalid bits / group size
with self.assertRaises(ValueError):
mx.quantize(w, bits=3, group_size=16, mode="nvfp4")
with self.assertRaises(ValueError):
mx.quantize(w, group_size=64, bits=4, mode="nvfp4")
w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, group_size=32, bits=4, mode="nvfp4")
w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4")
w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
self.assertTrue(mx.all(w_hat == 0))
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)