Add quantize/dequantize slow path for mxfp8 and nvfp4

Awni Hannun
2025-10-20 16:53:03 -07:00
parent 5d7efafe92
commit 8afc36cb87
5 changed files with 197 additions and 78 deletions


@@ -4268,10 +4268,11 @@ void init_ops(nb::module_& m) {
       "group_size"_a = 64,
       "bits"_a = 4,
       "mode"_a = "affine",
+      "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
@@ -4284,6 +4285,10 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          dtype (Dtype, optional): The data type of the dequantized output. If
+            ``None``, the return type is inferred from the scales and biases
+            when possible, and otherwise defaults to ``bfloat16``.
+            Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.
 
       Returns:
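
A minimal usage sketch of the new `dtype` argument, assuming the default affine mode (array shapes here are illustrative):

```python
import mlx.core as mx

w = mx.random.normal(shape=(128, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)  # affine mode
# Without dtype, the output type is inferred from scales/biases;
# passing dtype forces the type of the dequantized output.
w_hat = mx.dequantize(
    w_q, scales, biases, group_size=64, bits=4, mode="affine", dtype=mx.float16
)
assert w_hat.dtype == mx.float16
```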


@@ -77,6 +77,84 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))
 
+    def test_mxfp8_quantize_dequantize(self):
+        w = 2 * mx.random.uniform(shape=(512, 32)) - 1
+        w = w.astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
+
+        w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.all(w_hat == 0))
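
To see what the tolerances above reflect: mxfp8 stores each 32-element group with one shared power-of-two scale and FP8 (E4M3) elements. A rough NumPy sketch of that round trip, using a crude 3-mantissa-bit rounding in place of real E4M3 handling (an illustration, not MLX's slow path):

```python
import numpy as np

E4M3_MAX = 448.0  # largest normal FP8 E4M3 value

def round_to_e4m3_grid(x):
    # Crude E4M3 emulation: clamp to the max normal, keep 3 mantissa bits
    # (subnormal handling is ignored in this sketch).
    sign = np.sign(x)
    a = np.minimum(np.abs(x), E4M3_MAX)
    e = np.floor(np.log2(np.where(a > 0, a, 1.0)))
    step = 2.0 ** (e - 3)  # grid spacing with 3 mantissa bits at exponent e
    return sign * np.round(a / step) * step

def mxfp8_roundtrip(w, group_size=32):
    g = w.reshape(-1, group_size)
    amax = np.abs(g).max(axis=1, keepdims=True)
    # One shared power-of-two scale per group, sized so the group max fits E4M3
    scale = 2.0 ** np.ceil(np.log2(np.where(amax > 0, amax, 1.0) / E4M3_MAX))
    return (round_to_e4m3_grid(g / scale) * scale).reshape(w.shape)

w = np.random.uniform(-1, 1, size=(512, 32)).astype(np.float32)
# Half-step rounding error with 3 mantissa bits is at most ~1/16 relative,
# which is why the test above can use rtol=1e-1, atol=1e-2.
np.testing.assert_allclose(mxfp8_roundtrip(w), w, rtol=1e-1, atol=1e-2)
```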
+
+    def test_nvfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 16)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=16, mode="nvfp4")
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, bits=4, mode="nvfp4")
+
+        w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4")
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="nvfp4")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.all(w_hat == 0))
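
For reference, the `lut` above enumerates exactly the 16 signed FP4 (E2M1) code points an nvfp4 element can take. A tiny decoder sketch that reproduces it (illustrative only, not MLX's implementation):

```python
def e2m1_value(code):
    # Decode a 4-bit nvfp4 element: 1 sign, 2 exponent, 1 mantissa bit.
    s, e, m = code >> 3, (code >> 1) & 0x3, code & 0x1
    mag = 0.5 * m if e == 0 else 2.0 ** (e - 1) * (1 + 0.5 * m)
    return -mag if s else mag

# Matches the lut used in the test above, in 4-bit code order 0..15
assert [e2m1_value(c) for c in range(16)] == [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]
```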
+
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)