improve quant docs

Awni Hannun
2025-10-28 07:12:35 -07:00
parent 94fe5114fa
commit 5a043fd793
3 changed files with 42 additions and 32 deletions


@@ -4037,7 +4037,7 @@ std::pair<int, int> quantization_params_from_mode(
break;
case QuantizationMode::Mxfp8:
default_group_size = 32;
-default_bits = 4;
+default_bits = 8;
break;
}
return {
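
The corrected default matters at the API surface: with ``mode="mxfp8"`` the bindings now resolve to an 8-bit element width rather than 4. A minimal sketch of the resolved defaults, assuming the two-array return that the tests use for the floating-point modes (the shape is illustrative):

    import mlx.core as mx

    w = mx.random.normal((64, 64))

    # mxfp8 now resolves to group_size=32 and bits=8, matching the
    # 8-bit floating-point format the mode name implies.
    w_q, scales = mx.quantize(w, mode="mxfp8")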


@@ -4216,12 +4216,12 @@ void init_ops(nb::module_& m) {
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
-group_size (int, optional): The size of the group in ``w`` that
-shares a scale and bias. If unspecified, a default is chosen based
-on the mode. Default: ``None``.
-bits (int, optional): The number of bits occupied by each element in
-``w``. If unspecified, a default is chosen based on the mode.
-Default: ``None``.
+group_size (int, optional): The size of the group in ``w`` that shares a
+scale and bias. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+bits (int, optional): The number of bits occupied by each element of
+``w`` in the quantized array. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:
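
These parameters read like ``quantized_matmul``'s, so a short usage sketch may help; it leaves ``group_size`` and ``bits`` as ``None`` so the mode's defaults apply (the shapes are illustrative, not from the source):

    import mlx.core as mx

    x = mx.random.normal((2, 512))
    w = mx.random.normal((256, 512))

    # Affine quantization returns weights, scales, and biases.
    w_q, scales, biases = mx.quantize(w, mode="affine")

    # group_size and bits stay None, so the affine defaults (64 and 4)
    # are used; transpose=True computes x @ w.T -> a (2, 256) result.
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)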
@@ -4239,27 +4239,28 @@ void init_ops(nb::module_& m) {
nb::sig(
"def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc(
-Quantize the matrix ``w`` using ``bits`` bits per element.
+Quantize the array ``w``.
Note, every ``group_size`` elements in a row of ``w`` are quantized
-together. Hence, number of columns of ``w`` should be divisible by
-``group_size``. In particular, the rows of ``w`` are divided into groups of
-size ``group_size`` which are quantized together.
+together. Hence, the last dimension of ``w`` should be divisible by
+``group_size``.
.. warning::
-``quantize`` currently only supports 2D inputs with the second
-dimension divisible by ``group_size``
+``quantize`` only supports inputs with two or more dimensions with
+the last dimension divisible by ``group_size``
The supported quantization modes are ``"affine"``, ``"mxfp4"``,
``"mxfp8"``, and ``"nvfp4"``. They are described in more detail below.
Args:
-w (array): Matrix to be quantized
+w (array): Array to be quantized
group_size (int, optional): The size of the group in ``w`` that shares a
-scale and bias. Default: ``64``.
+scale and bias. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
bits (int, optional): The number of bits occupied by each element of
-``w`` in the returned quantized matrix. Default: ``4``.
+``w`` in the quantized array. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:
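
The reworked argument docs distinguish the affine mode, which produces biases, from the floating-point modes, which do not. A hedged sketch of both call patterns (shapes illustrative):

    import mlx.core as mx

    w = mx.random.normal((128, 256))

    # Affine mode: quantized weights, scales, and biases come back.
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

    # The floating-point modes carry no bias, so only two arrays return.
    w_q4, scales4 = mx.quantize(w, mode="mxfp4")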
@@ -4270,15 +4271,20 @@ void init_ops(nb::module_& m) {
* biases (array): The quantization biases (returned for ``mode=="affine"``).
Notes:
+.. _quantize-modes:
-====== ====================== ========================== ============= ========
-mode   group size             bits                       scale type    has bias
-====== ====================== ========================== ============= ========
-affine 32, 64 (default), 128  2, 3, 4 (default), 5, 6, 8 same as input yes
-mxfp4  32                     4                          e8m0          no
-mxfp8  32                     4                          e8m0          no
-nvfp4  16                     4                          e4m3          no
-====== ====================== ========================== ============= ========
+.. table:: Quantization modes
+====== ====================== ========================== ============= =====
+mode   group size             bits                       scale type    bias
+====== ====================== ========================== ============= =====
+affine 32, 64\ :sup:`*`, 128  2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes
+mxfp4  32\ :sup:`*`           4\ :sup:`*`                e8m0          no
+mxfp8  32\ :sup:`*`           8\ :sup:`*`                e8m0          no
+nvfp4  16\ :sup:`*`           4\ :sup:`*`                e4m3          no
+====== ====================== ========================== ============= =====
+:sup:`*` indicates the default value when unspecified.
The ``"affine"`` mode quantizes groups of :math:`g` consecutive
elements in a row of ``w``. For each group the quantized
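
The affine scheme being introduced here can be worked by hand. A small numeric sketch, assuming the usual min/max scheme (scale s = (max - min) / (2^b - 1), bias beta = min) that the docstring goes on to define:

    # One group of four values quantized to b = 2 bits.
    group = [0.0, 0.5, 1.0, 1.5]
    b = 2
    alpha, beta = max(group), min(group)
    s = (alpha - beta) / (2**b - 1)             # (1.5 - 0.0) / 3 = 0.5

    q = [round((w - beta) / s) for w in group]  # [0, 1, 2, 3]
    w_hat = [s * qi + beta for qi in q]         # [0.0, 0.5, 1.0, 1.5]

The round trip is exact here because the values are evenly spaced; in general the per-element quantization error is bounded by s / 2.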
@@ -4339,9 +4345,11 @@ void init_ops(nb::module_& m) {
biases (array, optional): The biases to use per ``group_size``
elements of ``w``. Default: ``None``.
group_size (int, optional): The size of the group in ``w`` that shares a
-scale and bias. Default: ``64``.
-bits (int, optional): The number of bits occupied by each element in
-``w``. Default: ``4``.
+scale and bias. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+bits (int, optional): The number of bits occupied by each element of
+``w`` in the quantized array. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
dtype (Dtype, optional): The data type of the dequantized output. If
``None`` the return type is inferred from the scales and biases
when possible and otherwise defaults to ``bfloat16``.
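
The ``dtype`` inference described above can be exercised directly. A sketch, assuming float16 inputs yield float16 scales and biases in affine mode, while the e8m0 scales of ``mxfp4`` force the ``bfloat16`` fallback:

    import mlx.core as mx

    w = mx.random.normal((16, 64), dtype=mx.float16)

    # Affine: dtype=None is inferred from the float16 scales/biases.
    w_q, scales, biases = mx.quantize(w)
    print(mx.dequantize(w_q, scales, biases).dtype)  # float16

    # mxfp4: e8m0 scales carry no output dtype, so bfloat16 is used.
    w_q4, scales4 = mx.quantize(w, mode="mxfp4")
    print(mx.dequantize(w_q4, scales4, mode="mxfp4").dtype)  # bfloat16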
@@ -4403,10 +4411,12 @@ void init_ops(nb::module_& m) {
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
-group_size (int, optional): The size of the group in ``w`` that
-shares a scale and bias. Default: ``64``.
-bits (int, optional): The number of bits occupied by each element in
-``w``. Default: ``4``.
+group_size (int, optional): The size of the group in ``w`` that shares a
+scale and bias. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+bits (int, optional): The number of bits occupied by each element of
+``w`` in the quantized array. See supported values and defaults in the
+:ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
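
The ``sorted_indices`` flag suggests this hunk documents ``gather_qmm``. A rough sketch of a call, assuming the gather-then-matmul semantics of the ``gather_mm`` family; the expert-routing shapes and index values are illustrative only:

    import mlx.core as mx

    # Four "expert" weight matrices, quantized along the last axis.
    x = mx.random.normal((4, 2, 512))
    w = mx.random.normal((4, 256, 512))
    w_q, scales, biases = mx.quantize(w, mode="affine")

    # Route batch 0 of x to expert 2 and batch 1 to expert 3; the
    # ascending index order is what sorted_indices=True promises.
    lhs_indices = mx.array([0, 1])
    rhs_indices = mx.array([2, 3])
    y = mx.gather_qmm(x, w_q, scales, biases,
                      lhs_indices=lhs_indices, rhs_indices=rhs_indices,
                      transpose=True, sorted_indices=True)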


@@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_q, scales = mx.quantize(w, mode="nvfp4")
with self.assertRaises(ValueError):
-    mx.dequantize(w_q, scales, mode="nvfp4")
+    mx.dequantize(w_q, scales, bits=3, mode="nvfp4")
with self.assertRaises(ValueError):
    mx.dequantize(w_q, scales, group_size=32, mode="nvfp4")
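
The updated test pins down the new default handling: with ``mode="nvfp4"`` the group size and bit width are inferred from the mode, and only values that contradict it raise. A quick sketch of both sides (shapes illustrative):

    import mlx.core as mx

    w = mx.random.normal((8, 32))
    w_q, scales = mx.quantize(w, mode="nvfp4")

    # group_size/bits default to None and are inferred (16 and 4),
    # so this call succeeds where it previously raised.
    w_hat = mx.dequantize(w_q, scales, mode="nvfp4")

    # Values that contradict the mode still raise.
    try:
        mx.dequantize(w_q, scales, bits=3, mode="nvfp4")
    except ValueError:
        print("bits=3 is invalid for nvfp4")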