From 5a043fd79389ee9eedf94fe063b6f3804dbf4c78 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 28 Oct 2025 07:12:35 -0700
Subject: [PATCH] improve quant docs

---
 mlx/ops.cpp                    |  2 +-
 python/src/ops.cpp             | 70 +++++++++++++++++++---------------
 python/tests/test_quantized.py |  2 +-
 3 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 53729bbf9..b7c3bafe2 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4037,7 +4037,7 @@ std::pair<int, int> quantization_params_from_mode(
       break;
     case QuantizationMode::Mxfp8:
       default_group_size = 32;
-      default_bits = 4;
+      default_bits = 8;
       break;
   }
   return {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index d5593b93b..9816837ba 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4216,12 +4216,12 @@ void init_ops(nb::module_& m) {
         transpose (bool, optional): Defines whether to multiply with the
           transposed ``w`` or not, namely whether we are performing
           ``x @ w.T`` or ``x @ w``. Default: ``True``.
-        group_size (int, optional): The size of the group in ``w`` that
-          shares a scale and bias. If unspecified, a default is chosen based
-          on the mode. Default: ``None``.
-        bits (int, optional): The number of bits occupied by each element in
-          ``w``. If unspecified, a default is chosen based on the mode.
-          Default: ``None``.
+        group_size (int, optional): The size of the group in ``w`` that shares a
+          scale and bias. See supported values and defaults in the
+          :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+        bits (int, optional): The number of bits occupied by each element of
+          ``w`` in the quantized array. See supported values and defaults in the
+          :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
         mode (str, optional): The quantization mode. Default: ``"affine"``.
 
     Returns:
@@ -4239,27 +4239,28 @@ void init_ops(nb::module_& m) {
       nb::sig(
           "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
-        Quantize the matrix ``w`` using ``bits`` bits per element.
+        Quantize the array ``w``.
 
         Note, every ``group_size`` elements in a row of ``w`` are quantized
-        together. Hence, number of columns of ``w`` should be divisible by
-        ``group_size``. In particular, the rows of ``w`` are divided into groups of
-        size ``group_size`` which are quantized together.
+        together. Hence, the last dimension of ``w`` should be divisible by
+        ``group_size``.
 
         .. warning::
 
-          ``quantize`` currently only supports 2D inputs with the second
-          dimension divisible by ``group_size``
+          ``quantize`` only supports inputs with two or more dimensions with
+          the last dimension divisible by ``group_size``
 
        The supported quantization modes are ``"affine"``, ``"mxfp4"``,
        ``"mxfp8"``, and ``"nvfp4"``. They are described in more detail below.
 
        Args:
-          w (array): Matrix to be quantized
+          w (array): Array to be quantized
          group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. Default: ``64``.
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
          bits (int, optional): The number of bits occupied by each element of
-            ``w`` in the returned quantized matrix. Default: ``4``.
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
          mode (str, optional): The quantization mode. Default: ``"affine"``.
 
        Returns:
@@ -4270,15 +4271,20 @@ void init_ops(nb::module_& m) {
             * biases (array): The quantization biases (returned for ``mode=="affine"``).
 
         Notes:
 
+          .. _quantize-modes:
-          ====== ====================== ========================== ============= ========
-          mode   group size             bits                       scale type    has bias
-          ====== ====================== ========================== ============= ========
-          affine 32, 64 (default), 128  2, 3, 4 (default), 5, 6, 8 same as input yes
-          mxfp4  32                     4                          e8m0          no
-          mxfp8  32                     4                          e8m0          no
-          nvfp4  16                     4                          e4m3          no
-          ====== ====================== ========================== ============= ========
+          .. table:: Quantization modes
+
+             ====== ====================== ========================== ============= =====
+             mode   group size             bits                       scale type    bias
+             ====== ====================== ========================== ============= =====
+             affine 32, 64\ :sup:`*`, 128  2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes
+             mxfp4  32\ :sup:`*`           4\ :sup:`*`                e8m0          no
+             mxfp8  32\ :sup:`*`           8\ :sup:`*`                e8m0          no
+             nvfp4  16\ :sup:`*`           4\ :sup:`*`                e4m3          no
+             ====== ====================== ========================== ============= =====
+
+          :sup:`*` indicates the default value when unspecified.
 
         The ``"affine"`` mode quantizes groups of :math:`g` consecutive
         elements in a row of ``w``. For each group the quantized
@@ -4339,9 +4345,11 @@ void init_ops(nb::module_& m) {
           biases (array, optional): The biases to use per ``group_size``
             elements of ``w``. Default: ``None``.
           group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           dtype (Dtype, optional): The data type of the dequantized output. If
             ``None`` the return type is inferred from the scales and biases
             when possible and otherwise defaults to ``bfloat16``.
@@ -4403,10 +4411,12 @@ void init_ops(nb::module_& m) {
           transpose (bool, optional): Defines whether to multiply with the
             transposed ``w`` or not, namely whether we are performing
             ``x @ w.T`` or ``x @ w``. Default: ``True``.
-          group_size (int, optional): The size of the group in ``w`` that
-            shares a scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.
           sorted_indices (bool, optional): May allow a faster implementation
             if the passed indices are sorted. Default: ``False``.
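
[Editor's note, not part of the patch: a short usage sketch of the behavior documented above, assuming a build of MLX that includes this change. The defaults follow the quantize-modes table; the return arity follows the Returns section (affine mode also returns biases, the fp modes do not).]

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))

# "affine" mode defaults to group_size=64 and bits=4 and returns a bias
# array alongside the scales.
w_q, scales, biases = mx.quantize(w)  # mode="affine" is the default
w_hat = mx.dequantize(w_q, scales, biases)

# "mxfp8" now defaults to bits=8, matching the mlx/ops.cpp change above.
# The fp modes return only (w_q, scales).
w_q8, scales8 = mx.quantize(w, mode="mxfp8")

# "nvfp4" defaults to group_size=16, so the last dimension of w must be
# divisible by 16.
w_q4, scales4 = mx.quantize(w, mode="nvfp4")
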
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 1c1020c46..ef67d4478 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_q, scales = mx.quantize(w, mode="nvfp4")
 
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, mode="nvfp4")
+            mx.dequantize(w_q, scales, bits=3, mode="nvfp4")
 
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, group_size=32, mode="nvfp4")
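
[Editor's note, not part of the patch: the "affine" scheme the docstring describes can be sketched in a few lines. The helper below is a hypothetical reference, not MLX's implementation; it assumes the per-group bias is the group minimum and the scale is (max - min) / (2**bits - 1), and it omits the packing of the integer codes into uint32 words that mx.quantize performs.]

import mlx.core as mx

def affine_quantize_reference(w, group_size=64, bits=4):
    # Quantize each group of group_size elements with a shared scale s and
    # bias beta so that q = round((w - beta) / s) lies in [0, 2**bits - 1].
    *batch, n = w.shape
    g = w.reshape(*batch, n // group_size, group_size)
    alpha = mx.max(g, axis=-1, keepdims=True)  # per-group maximum
    beta = mx.min(g, axis=-1, keepdims=True)   # per-group minimum (the bias)
    s = (alpha - beta) / (2**bits - 1)         # per-group scale (a constant
                                               # group would need an eps guard)
    q = mx.round((g - beta) / s)               # integer codes
    return q, s, beta

# Dequantization inverts the affine map group by group: w ≈ s * q + beta.
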