Mirror of https://github.com/ml-explore/mlx.git, synced 2025-11-05 03:18:12 +08:00
Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4
* fast cuda kernel for mx/nv quantization
* fallback for cuda < 12.8 (#2697)
* format (#2700)
* fix (#2701)
* metal kernels
* docs
* fix jit
* add default bits and group sizes
* improve quant docs
* fix output type of mxfp4 matmuls
@@ -4194,13 +4194,13 @@ void init_ops(nb::module_& m) {
       "scales"_a,
       "biases"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4216,10 +4216,12 @@ void init_ops(nb::module_& m) {
           transpose (bool, optional): Defines whether to multiply with the
             transposed ``w`` or not, namely whether we are performing
             ``x @ w.T`` or ``x @ w``. Default: ``True``.
-          group_size (int, optional): The size of the group in ``w`` that
-            shares a scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
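For orientation (not part of the diff), a minimal sketch of how the updated binding is called: ``group_size`` and ``bits`` may now be omitted and are filled in from ``mode``. Shapes are arbitrary.

    import mlx.core as mx

    x = mx.random.normal((8, 256))
    w = mx.random.normal((128, 256))

    # affine: omitting group_size/bits picks the defaults (64 and 4)
    wq, scales, biases = mx.quantize(w, mode="affine")
    y = mx.quantized_matmul(x, wq, scales, biases, mode="affine")  # x @ w.T -> (8, 128)

    # mxfp8: group_size=32 and bits=8 are inferred from the mode; no biases
    wq8, s8 = mx.quantize(w, mode="mxfp8")
    y8 = mx.quantized_matmul(x, wq8, s8, mode="mxfp8")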
@@ -4229,35 +4231,36 @@ void init_ops(nb::module_& m) {
       "quantize",
       &mx::quantize,
       nb::arg(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
-        Quantize the matrix ``w`` using ``bits`` bits per element.
+        Quantize the array ``w``.

         Note, every ``group_size`` elements in a row of ``w`` are quantized
-        together. Hence, number of columns of ``w`` should be divisible by
-        ``group_size``. In particular, the rows of ``w`` are divided into groups of
-        size ``group_size`` which are quantized together.
+        together. Hence, the last dimension of ``w`` should be divisible by
+        ``group_size``.

         .. warning::

-          ``quantize`` currently only supports 2D inputs with the second
-          dimension divisible by ``group_size``
+          ``quantize`` only supports inputs with two or more dimensions with
+          the last dimension divisible by ``group_size``

-        The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They
-        are described in more detail below.
+        The supported quantization modes are ``"affine"``, ``"mxfp4"``,
+        ``"mxfp8"``, and ``"nvfp4"``. They are described in more detail below.

         Args:
-          w (array): Matrix to be quantized
+          w (array): Array to be quantized
           group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. Default: ``64``.
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           bits (int, optional): The number of bits occupied by each element of
-            ``w`` in the returned quantized matrix. Default: ``4``.
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
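Also not from the diff, a short sketch of the return-value difference described above: ``affine`` returns a biases array, the floating-point modes do not.

    import mlx.core as mx

    w = mx.random.normal((64, 128))

    wq, scales, biases = mx.quantize(w, mode="affine")  # three arrays
    wq4, s4 = mx.quantize(w, mode="mxfp4")              # two arrays; group_size=32, bits=4
    wqn, sn = mx.quantize(w, mode="nvfp4")              # two arrays; group_size=16, bits=4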
@@ -4268,7 +4271,22 @@ void init_ops(nb::module_& m) {
           * biases (array): The quantization biases (returned for ``mode=="affine"``).

         Notes:
-          The ``affine`` mode quantizes groups of :math:`g` consecutive
+          .. _quantize-modes:
+
+          .. table:: Quantization modes
+
+            ====== ====================== ========================== ============= =====
+            mode   group size             bits                       scale type    bias
+            ====== ====================== ========================== ============= =====
+            affine 32, 64\ :sup:`*`, 128  2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes
+            mxfp4  32\ :sup:`*`           4\ :sup:`*`                e8m0          no
+            mxfp8  32\ :sup:`*`           8\ :sup:`*`                e8m0          no
+            nvfp4  16\ :sup:`*`           4\ :sup:`*`                e4m3          no
+            ====== ====================== ========================== ============= =====
+
+          :sup:`*` indicates the default value when unspecified.
+
+          The ``"affine"`` mode quantizes groups of :math:`g` consecutive
           elements in a row of ``w``. For each group the quantized
           representation of each element :math:`\hat{w_i}` is computed as follows:

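Reading the table as data, a hypothetical mode-to-defaults mapping, purely for illustration (``QUANT_DEFAULTS`` is not an mlx name; the library resolves these internally):

    # mode -> (default group_size, default bits); affine also accepts
    # group_size 32/128 and bits 2/3/5/6/8
    QUANT_DEFAULTS = {
        "affine": (64, 4),
        "mxfp4": (32, 4),
        "mxfp8": (32, 8),
        "nvfp4": (16, 4),
    }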
@@ -4291,11 +4309,17 @@ void init_ops(nb::module_& m) {
           :math:`\beta` which are the returned ``scales`` and
           ``biases`` respectively.

-          The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements
-          of ``w``. For ``mxfp4`` the group size must be ``32``. The elements
-          are quantized to 4-bit precision floating-point values (E2M1) with a
-          shared 8-bit scale per group. Unlike ``affine`` quantization,
-          ``mxfp4`` does not have a bias value. More details on the format can
+          The ``"mxfp4"``, ``"mxfp8"``, and ``"nvfp4"`` modes similarly
+          quantize groups of :math:`g` elements of ``w``. For the ``"mx"``
+          modes, the group size must be ``32``. For ``"nvfp4"`` the group
+          size must be ``16``. The elements are quantized to 4-bit or 8-bit
+          precision floating-point values: E2M1 for ``"fp4"`` and E4M3 for
+          ``"fp8"``. There is a shared 8-bit scale per group. The ``"mx"``
+          modes use an E8M0 scale and the ``"nv"`` mode uses an E4M3 scale.
+          Unlike ``affine`` quantization, these modes do not have a bias
+          value.
+
+          More details on the ``"mx"`` formats can
           be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.
       )pbdoc");
   m.def(
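To make the affine scheme concrete, a NumPy sketch of one group's round trip following the docstring's formulas (s = (alpha - beta) / (2^b - 1), w_hat_i = round((w_i - beta) / s)). This mirrors the math, not the actual kernels.

    import numpy as np

    def affine_quantize_group(w, bits=4):
        # w: one group of group_size consecutive elements
        alpha, beta = w.max(), w.min()
        scale = (alpha - beta) / (2**bits - 1) or 1.0  # guard constant groups
        q = np.round((w - beta) / scale)
        return q, scale, beta

    def affine_dequantize_group(q, scale, beta):
        return q * scale + beta  # w_i is approximately s * w_hat_i + beta

    group = np.random.randn(64).astype(np.float32)
    q, s, b = affine_quantize_group(group)
    max_err = np.abs(affine_dequantize_group(q, s, b) - group).max()  # at most s / 2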
@@ -4304,13 +4328,14 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "scales"_a,
       "biases"_a = nb::none(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
+      "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.

@@ -4320,16 +4345,23 @@ void init_ops(nb::module_& m) {
           biases (array, optional): The biases to use per ``group_size``
             elements of ``w``. Default: ``None``.
           group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          dtype (Dtype, optional): The data type of the dequantized output. If
+            ``None`` the return type is inferred from the scales and biases
+            when possible and otherwise defaults to ``bfloat16``.
+            Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
           array: The dequantized version of ``w``

         Notes:
-          The currently supported quantization modes are ``"affine"`` and ``mxfp4``.
+          The currently supported quantization modes are ``"affine"``,
+          ``"mxfp4"``, ``"mxfp8"``, and ``"nvfp4"``.

           For ``affine`` quantization, given the notation in :func:`quantize`,
           we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s`
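An illustrative use of the new ``dtype`` argument (not part of the diff; shapes arbitrary): without it the output type is inferred, with it you can force, say, ``float32``.

    import mlx.core as mx

    w = mx.random.normal((64, 128)).astype(mx.bfloat16)
    wq, scales = mx.quantize(w, mode="mxfp8")

    w_hat = mx.dequantize(wq, scales, mode="mxfp8")                    # dtype inferred
    w_f32 = mx.dequantize(wq, scales, mode="mxfp8", dtype=mx.float32)  # forced float32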
@@ -4349,14 +4381,14 @@ void init_ops(nb::module_& m) {
       "lhs_indices"_a = nb::none(),
       "rhs_indices"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.

@@ -4379,10 +4411,12 @@ void init_ops(nb::module_& m) {
           transpose (bool, optional): Defines whether to multiply with the
             transposed ``w`` or not, namely whether we are performing
            ``x @ w.T`` or ``x @ w``. Default: ``True``.
-          group_size (int, optional): The size of the group in ``w`` that
-            shares a scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the quantized array. See supported values and defaults in the
+            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.
           sorted_indices (bool, optional): May allow a faster implementation
             if the passed indices are sorted. Default: ``False``.
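For context, a sketch of the gather path mirroring the new test further down; ``group_size`` and ``bits`` are again left to the mode.

    import mlx.core as mx

    x = mx.random.normal((32, 256))
    w = mx.random.normal((4, 32, 256))    # a stack of 4 quantizable matrices
    wq, s = mx.quantize(w, mode="mxfp4")  # group_size=32, bits=4 inferred

    indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
    out = mx.gather_qmm(x, wq, s, rhs_indices=indices, mode="mxfp4")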
@@ -55,26 +55,109 @@ class TestQuantized(mlx_tests.MLXTestCase):
         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
+            mx.quantize(w, bits=3, mode="mxfp4")

         with self.assertRaises(ValueError):
-            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
+            mx.quantize(w, group_size=64, mode="mxfp4")

-        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(w, mode="mxfp4")
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, bits=3, mode="mxfp4")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
+            mx.dequantize(w_q, scales, group_size=64, mode="mxfp4")

+        # Invalid output type
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
+            mx.dequantize(
+                w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
+            )

-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))

         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
-        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(a, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))

+    def test_mxfp8_quantize_dequantize(self):
+        w = 2 * mx.random.uniform(shape=(512, 32)) - 1
+        w = w.astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
+        w_q, scales = mx.quantize(w, group_size=32, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=16, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=4, mode="mxfp8")
+
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp8")
+
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp8")
+        self.assertTrue(mx.all(w_hat == 0))
+
+    def test_nvfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 16)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, mode="nvfp4")
+
+        w_q, scales = mx.quantize(w, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=3, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, mode="nvfp4")
+
+        w_hat = mx.dequantize(w_q, scales, mode="nvfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="nvfp4")
+        self.assertTrue(mx.all(w_hat == 0))
+
     def test_qmm(self):
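Aside (not in the diff): the ``lut`` in ``test_nvfp4_quantize_dequantize`` enumerates exactly the 16 E2M1-representable values, sign times {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and pinning ``w[:, 0] = 6`` makes each group's shared e4m3 scale exact, so the round trip is lossless up to the 3e-6 perturbation. A condensed version of the same check:

    import mlx.core as mx

    vals = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    w = vals[mx.random.randint(0, 8, shape=(128, 16))]
    w[:, 0] = 6.0  # per-group max of 6 -> a shared scale of exactly 1
    wq, scales = mx.quantize(w, mode="nvfp4")
    w_hat = mx.dequantize(wq, scales, mode="nvfp4")
    assert mx.allclose(w, w_hat).item()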
@@ -662,6 +745,25 @@ class TestQuantized(mlx_tests.MLXTestCase):
         test_shape(32, 512, 32, transpose=False, **kwargs)
         test_shape(1, 512, 32, transpose=False, **kwargs)

+    def test_qmm_mxfp4_type(self):
+        indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
+
+        for t in [mx.bfloat16, mx.float16, mx.float32]:
+            x = mx.random.normal((32, 256)).astype(t)
+
+            w = mx.random.normal((32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+            out = mx.quantized_matmul(x, wq, s, mode="mxfp4", group_size=32, bits=4)
+            self.assertEqual(out.dtype, t)
+
+            w = mx.random.normal((4, 32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+
+            out = mx.gather_qmm(
+                x, wq, s, rhs_indices=indices, mode="mxfp4", group_size=32, bits=4
+            )
+            self.assertEqual(out.dtype, t)
+
     def test_gather_matmul_grad(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)