From c00ccf7404f468095a30958112e33e18e54a3bed Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 20 Oct 2025 16:53:03 -0700
Subject: [PATCH] Add quantize/dequantize slow path for mxfp8 and nvfp4

- mxfp8: group size 32, 8 bits, e8m0 (biased power-of-two) scales
- nvfp4: group size 16, 4 bits, e4m3 (fp8) scales
- dequantize takes an optional output dtype (bfloat16 by default for the
  fp modes)

---
 mlx/ops.cpp                    | 187 +++++++++++++++++++--------------
 mlx/ops.h                      |   1 +
 mlx/primitives.cpp             |   2 +
 python/src/ops.cpp             |   7 +-
 python/tests/test_quantized.py |  78 ++++++++++++++
 5 files changed, 197 insertions(+), 78 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 879ef4fd5..d99bf2f46 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4018,7 +4018,8 @@ array conv_general(
 }
 
 void validate_mode(std::string_view tag, const std::string& mode) {
-  if (mode != "affine" && mode != "mxfp4") {
+  if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
+      mode != "nvfp4") {
     std::ostringstream msg;
     msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
     throw std::invalid_argument(msg.str());
@@ -4249,51 +4250,67 @@ std::vector<array> quantize(
   if (mode == "affine") {
     return affine_quantize(w, group_size, bits, s);
   } else {
-    if (group_size != 32) {
+    int expected_gs = (mode[0] == 'm') ? 32 : 16;
+    int expected_bits = (mode.back() == '8') ? 8 : 4;
+    if (group_size != expected_gs) {
       std::ostringstream msg;
-      msg << "[quantize] mxfp4 quantization requires group size 32 "
-          << "but got " << group_size << ".";
+      msg << "[quantize] " << mode << " quantization requires group size "
+          << expected_gs << " but got " << group_size << ".";
       throw std::invalid_argument(msg.str());
     }
-    if (bits != 4) {
+    if (bits != expected_bits) {
       std::ostringstream msg;
-      msg << "[quantize] mxfp4 quantization requires bits to be 4 "
-          << "but got " << bits << ".";
+      msg << "[quantize] " << mode << " quantization requires bits to be "
+          << expected_bits << " but got " << bits << ".";
       throw std::invalid_argument(msg.str());
     }
-
-    auto lut = array({
-        +0.0f,
-        +0.5f,
-        +1.0f,
-        +1.5f,
-        +2.0f,
-        +3.0f,
-        +4.0f,
-        +6.0f,
-        -0.0f,
-        -0.5f,
-        -1.0f,
-        -1.5f,
-        -2.0f,
-        -3.0f,
-        -4.0f,
-        -6.0f,
-    });
-    lut = astype(lut, w.dtype(), s);
-
+    float maxval = (bits == 4) ? 6.0f : 448.0f;
     auto new_shape = w.shape();
     new_shape.back() = -1;
     auto wq = reshape(w, {-1, group_size}, s);
     auto scales =
-        divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s);
-    scales = astype(log2(scales, s), int32, s);
-    wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
-    scales = astype(add(scales, array(127, int32), s), uint8, s);
-    wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
-    auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
-    wq = reshape(wq, {-1, group_size / 8, 8}, s);
-    wq = sum(multiply(wq, shifts, s), -1, false, s);
+        divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
+    if (group_size == 16) {
+      // nvfp4: store the per-group scale as an e4m3 fp8 value
+      scales = to_fp8(scales, s);
+      wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
+    } else {
+      // mxfp: scale is a power of two stored as a biased exponent (e8m0)
+      auto z = array(0, scales.dtype());
+      scales =
+          where(equal(scales, z, s), z, astype(log2(scales, s), int32, s), s);
+
+      wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
+      scales = astype(add(scales, array(127, int32), s), uint8, s);
+    }
+    if (bits == 4) {
+      auto lut = array({
+          +0.0f,
+          +0.5f,
+          +1.0f,
+          +1.5f,
+          +2.0f,
+          +3.0f,
+          +4.0f,
+          +6.0f,
+          -0.0f,
+          -0.5f,
+          -1.0f,
+          -1.5f,
+          -2.0f,
+          -3.0f,
+          -4.0f,
+          -6.0f,
+      });
+      lut = astype(lut, w.dtype(), s);
+      wq = argmin(
+          abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
+      auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
+      wq = reshape(wq, {-1, group_size / 8, 8}, s);
+      wq = sum(multiply(wq, shifts, s), -1, false, s);
+    } else {
+      wq = view(to_fp8(wq, s), uint32, s);
+    }
     wq = reshape(wq, new_shape, s);
     scales = reshape(scales, new_shape, s);
     return {std::move(wq), std::move(scales)};
@@ -4404,6 +4421,7 @@ array dequantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     const std::string& mode /* = "affine" */,
+    std::optional<Dtype> dtype /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   validate_mode_with_type("dequantize", scales, biases, mode);
   if (bits <= 0) {
@@ -4422,24 +4440,30 @@ array dequantize(
   }
 
   if (mode == "affine") {
-    return affine_dequantize(w, scales, *biases, group_size, bits, s);
+    auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
+    if (dtype) {
+      out = astype(out, *dtype, s);
+    }
+    return out;
   } else {
-    if (group_size != 32) {
+    int expected_gs = (mode[0] == 'm') ? 32 : 16;
+    int expected_bits = (mode.back() == '8') ? 8 : 4;
+    if (group_size != expected_gs) {
       std::ostringstream msg;
-      msg << "[dequantize] mxfp4 quantization requires group size 32 "
-          << "but got " << group_size << ".";
+      msg << "[dequantize] " << mode << " quantization requires group size "
+          << expected_gs << " but got " << group_size << ".";
       throw std::invalid_argument(msg.str());
     }
-    if (bits != 4) {
+    if (bits != expected_bits) {
       std::ostringstream msg;
-      msg << "[dequantize] mxfp4 quantization requires bits to be 4 "
-          << "but got " << bits << ".";
+      msg << "[dequantize] " << mode << " quantization requires bits to be "
+          << expected_bits << " but got " << bits << ".";
       throw std::invalid_argument(msg.str());
     }
 
     if (w.ndim() < 2 || scales.ndim() < 2) {
       std::ostringstream msg;
-      msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
+      msg << "[dequantize] The matrix to be dequantized must have at least 2 dimension "
           << "but it has only " << w.ndim() << ".";
       throw std::invalid_argument(msg.str());
     }
@@ -4470,39 +4494,48 @@ array dequantize(
       throw std::invalid_argument(msg.str());
     }
 
-    auto dtype = bfloat16;
-    auto lut = array(
-        {
-            +0.0f,
-            +0.5f,
-            +1.0f,
-            +1.5f,
-            +2.0f,
-            +3.0f,
-            +4.0f,
-            +6.0f,
-            -0.0f,
-            -0.5f,
-            -1.0f,
-            -1.5f,
-            -2.0f,
-            -3.0f,
-            -4.0f,
-            -6.0f,
-        },
-        dtype);
-
-    auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s);
-
-    auto idx_lo = bitwise_and(what, array(0x0F, int8), s);
-    auto idx_hi = right_shift(what, array(4, int8), s);
-    auto lo = gather(lut, idx_lo, 0, {1}, s);
-    auto hi = gather(lut, idx_hi, 0, {1}, s);
-    what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s);
-    auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s);
-    exponent = reshape(exponent, {-1, 1}, s);
-    return reshape(
-        multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s);
+    auto out_type = dtype.has_value() ? *dtype : bfloat16;
+    auto out = w;
+    if (bits == 4) {
+      auto lut = array(
+          {
+              +0.0f,
+              +0.5f,
+              +1.0f,
+              +1.5f,
+              +2.0f,
+              +3.0f,
+              +4.0f,
+              +6.0f,
+              -0.0f,
+              -0.5f,
+              -1.0f,
+              -1.5f,
+              -2.0f,
+              -3.0f,
+              -4.0f,
+              -6.0f,
+          },
+          out_type);
+      out = view(reshape(out, {-1, group_size / 8}, s), int8, s);
+      auto idx_lo = bitwise_and(out, array(0x0F, int8), s);
+      auto idx_hi = right_shift(out, array(4, int8), s);
+      auto lo = gather(lut, idx_lo, 0, {1}, s);
+      auto hi = gather(lut, idx_hi, 0, {1}, s);
+      out = concatenate({lo, hi}, -1, s);
+    } else {
+      out = from_fp8(view(out, uint8, s), out_type, s);
+    }
+    out = reshape(out, {-1, group_size}, s);
+    auto flat_scales = reshape(scales, {-1, 1}, s);
+    if (group_size == 16) {
+      flat_scales = from_fp8(flat_scales, out_type, s);
+    } else {
+      flat_scales =
+          subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
+      flat_scales = power(array(2.0f, out_type), flat_scales, s);
+    }
+    return reshape(multiply(out, flat_scales, s), wshape, s);
   }
 }
 
diff --git a/mlx/ops.h b/mlx/ops.h
index 312caac6d..b86df59fa 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1400,6 +1400,7 @@ array dequantize(
     int group_size = 64,
     int bits = 4,
     const std::string& mode = "affine",
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
 
 /** Convert an E4M3 float8 to the given floating point dtype. */
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 0b335e765..afd9dd5b9 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3404,6 +3404,7 @@ std::vector<array> QuantizedMatmul::vjp(
             group_size_,
             bits_,
             quantization_mode_to_string(mode_),
+            std::nullopt,
             stream());
         wq = unflatten(wq, -1, {-1, group_size_}, stream());
         vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3558,6 +3559,7 @@ std::vector<array> GatherQMM::vjp(
                               group_size_,
                               bits_,
                               quantization_mode_to_string(mode_),
+                              std::nullopt,
                               stream()),
                           -1,
                           {-1, group_size_},
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 2e364db76..16b9f50e0 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4307,10 +4307,11 @@ void init_ops(nb::module_& m) {
       "group_size"_a = 64,
       "bits"_a = 4,
       "mode"_a = "affine",
+      "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
 
@@ -4323,6 +4324,10 @@ void init_ops(nb::module_& m) {
               scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
             mode (str, optional): The quantization mode. Default: ``"affine"``.
+            dtype (Dtype, optional): The data type of the dequantized output. If
+              ``None`` the return type is inferred from the scales and biases
+              when possible and otherwise defaults to ``bfloat16``.
+              Default: ``None``.
 
         Returns:
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 3a195ef54..5fe867b37 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -77,6 +77,84 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))
 
+    def test_mxfp8_quantize_dequantize(self):
+        w = 2 * mx.random.uniform(shape=(512, 32)) - 1
+        w = w.astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=16, bits=8, mode="mxfp8")
+
+        w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.all(w_hat == 0))
+
+    def test_nvfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 16)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=16, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, bits=4, mode="nvfp4")
+
+        w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=16, bits=8, mode="nvfp4")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.all(w_hat == 0))
+
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)