Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
Author: Awni Hannun
Date: 2025-08-28 06:45:26 -07:00 (committed by GitHub)
Commit: 70560b6bd5 (parent: 7ef8a6f2d5)
28 changed files with 3635 additions and 757 deletions


@@ -39,6 +39,6 @@ class Embedding(Module):
"""
return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)
return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)
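(Not part of the diff: a minimal usage sketch of the new ``mode`` argument on ``to_quantized``, assuming the usual ``mlx`` imports; layer sizes are illustrative.)

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(1000, 64)
    # Default affine quantization stores scales and biases.
    q_affine = emb.to_quantized(group_size=64, bits=4, mode="affine")
    # mxfp4 requires group_size=32 and stores only scales.
    q_mxfp4 = emb.to_quantized(group_size=32, bits=4, mode="mxfp4")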


@@ -70,9 +70,9 @@ class Linear(Module):
x = x @ self["weight"].T
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
return QuantizedLinear.from_linear(self, group_size, bits, mode)
class Bilinear(Module):


@@ -12,6 +12,8 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
*,
mode: str = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
@@ -26,6 +28,8 @@ def quantize(
:func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see
:func:`mlx.core.quantize`). Default: ``4``.
mode (str): The quantization method to use (see
:func:`mlx.core.quantize`). Default: ``"affine"``.
class_predicate (Optional[Callable]): A callable which receives the
:obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
dict of params for `to_quantized` if it should be quantized and
@@ -39,7 +43,7 @@ def quantize(
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
return m.to_quantized(group_size=group_size, bits=bits, mode=mode)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
else:
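(Not part of the diff: a sketch of how ``class_predicate`` interacts with the new ``mode``. A ``dict`` return value bypasses the top-level settings, so per-layer modes can be mixed; the model and predicate below are illustrative.)

    import mlx.nn as nn

    model = nn.Sequential(nn.Embedding(1000, 64), nn.ReLU(), nn.Linear(64, 64))

    def predicate(path, module):
        # Keep embeddings in affine mode, quantize everything else with mxfp4.
        if isinstance(module, nn.Embedding):
            return {"group_size": 64, "bits": 4, "mode": "affine"}
        return hasattr(module, "to_quantized")

    nn.quantize(model, group_size=32, bits=4, mode="mxfp4", class_predicate=predicate)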
@@ -72,6 +76,8 @@ class QuantizedEmbedding(Module):
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
mode (str): The quantization method to use (see
:func:`mlx.core.quantize`). Default: ``"affine"``.
"""
def __init__(
@@ -80,17 +86,23 @@ class QuantizedEmbedding(Module):
dims: int,
group_size: int = 64,
bits: int = 4,
mode: str = "affine",
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode
# Initialize the quantized weight
scale = math.sqrt(1 / dims)
weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
if mode == "affine":
self.scales, self.biases = scales_biases
else:
(self.scales,) = scales_biases
self.num_embeddings = num_embeddings
self.dims = dims
@@ -98,12 +110,14 @@ class QuantizedEmbedding(Module):
self.freeze()
def __call__(self, x):
biases = self.get("biases")
return mx.dequantize(
self["weight"][x],
scales=self["scales"][x],
biases=self["biases"][x],
biases=biases[x] if biases is not None else None,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
def as_linear(self, x):
@@ -117,28 +131,40 @@ class QuantizedEmbedding(Module):
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
biases=self.get("biases"),
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
def _extra_repr(self):
return (
f"{self.num_embeddings}, {self.dims}, "
f"group_size={self.group_size}, bits={self.bits}"
f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
)
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: str = "affine",
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
embedding_layer.weight, group_size, bits
ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
ql.weight, *scales_biases = mx.quantize(
embedding_layer.weight,
group_size,
bits,
mode=mode,
)
if mode == "affine":
ql.scales, ql.biases = scales_biases
else:
(ql.scales,) = scales_biases
return ql
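(Not part of the diff: a sketch of ``from_embedding`` with ``mode="mxfp4"``; in that mode the layer stores no ``biases``, which is why the lookups above use ``self.get("biases")``. Sizes are illustrative.)

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(100, 32)
    q = nn.QuantizedEmbedding.from_embedding(emb, group_size=32, bits=4, mode="mxfp4")
    print("biases" in q)             # False for mxfp4, True for affine
    vecs = q(mx.array([1, 2, 3]))    # (3, 32) dequantized lookups
    logits = q.as_linear(vecs)       # (3, 100), reuses the quantized weight as a projection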
@@ -161,6 +187,8 @@ class QuantizedLinear(Module):
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
mode (str): The quantization method to use (see
:func:`mlx.core.quantize`). Default: ``"affine"``.
"""
def __init__(
@@ -170,12 +198,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
mode: str = "affine",
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@@ -184,7 +214,11 @@ class QuantizedLinear(Module):
high=scale,
shape=(output_dims, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
if mode == "affine":
self.scales, self.biases = scales_biases
else:
(self.scales,) = scales_biases
# And bias if needed
if bias:
@@ -198,7 +232,7 @@ class QuantizedLinear(Module):
in_dims *= 32 // self.bits
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
)
def __call__(self, x):
@@ -206,23 +240,38 @@ class QuantizedLinear(Module):
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
biases=self.get("biases"),
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: str = "affine",
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
ql.weight, *scales_biases = mx.quantize(
linear_layer.weight,
group_size,
bits,
mode=mode,
)
if mode == "affine":
ql.scales, ql.biases = scales_biases
else:
(ql.scales,) = scales_biases
if "bias" in linear_layer:
ql.bias = linear_layer.bias
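(Not part of the diff: a sketch of ``from_linear``; the bias of the original layer is carried over unquantized, and the new ``mode`` shows up in the layer's repr. Sizes are illustrative.)

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(256, 128, bias=True)
    qlin = nn.QuantizedLinear.from_linear(lin, group_size=32, bits=4, mode="mxfp4")
    print(qlin)                                  # repr includes group_size, bits, and mode
    y = qlin(mx.random.normal(shape=(8, 256)))   # (8, 128), original bias applied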


@@ -4153,14 +4153,15 @@ void init_ops(nb::module_& m) {
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@@ -4171,7 +4172,8 @@ void init_ops(nb::module_& m) {
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
biases (array, optional): The biases to use per ``group_size``
elements of ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
@@ -4179,6 +4181,7 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
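(Not part of the diff: a sketch of calling ``quantized_matmul`` with the now-optional ``biases``; shapes are illustrative and follow the signature above.)

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 256))
    w = mx.random.normal(shape=(128, 256))

    # affine: pass both scales and biases
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)   # (2, 128)

    # mxfp4: leave biases as None and use group_size=32
    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    y = mx.quantized_matmul(x, w_q, scales, transpose=True, group_size=32, mode="mxfp4")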
@@ -4189,10 +4192,11 @@ void init_ops(nb::module_& m) {
nb::arg(),
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4203,30 +4207,11 @@ void init_ops(nb::module_& m) {
.. warning::
``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
``quantize`` currently only supports 2D inputs with the second
dimension divisible by ``group_size``
Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
:math:`w_g` in a row of ``w`` we compute the quantized representation
of each element :math:`\hat{w_i}` as follows
.. math::
\begin{aligned}
\alpha &= \max_i w_i \\
\beta &= \min_i w_i \\
s &= \frac{\alpha - \beta}{2^b - 1} \\
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
\end{aligned}
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
and is packed in an unsigned 32-bit integer from the lower to upper
bits. For instance, for 4-bit quantization we fit 8 elements in an
unsigned 32 bit integer where the 1st element occupies the 4 least
significant bits, the 2nd bits 4-7 etc.
In order to be able to dequantize the elements of ``w`` we also need to
save :math:`s` and :math:`\beta` which are the returned ``scales`` and
``biases`` respectively.
The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They
are described in more detail below.
Args:
w (array): Matrix to be quantized
@@ -4234,49 +4219,86 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:
tuple: A tuple containing
tuple: A tuple with either two or three elements containing:
* w_q (array): The quantized version of ``w``
* scales (array): The scale to multiply each element with, namely :math:`s`
* biases (array): The biases to add to each element, namely :math:`\beta`
* scales (array): The quantization scales
* biases (array): The quantization biases (returned for ``mode=="affine"``).
Notes:
The ``affine`` mode quantizes groups of :math:`g` consecutive
elements in a row of ``w``. For each group the quantized
representation of each element :math:`\hat{w_i}` is computed as follows:
.. math::
\begin{aligned}
\alpha &= \max_i w_i \\
\beta &= \min_i w_i \\
s &= \frac{\alpha - \beta}{2^b - 1} \\
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
\end{aligned}
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
and is packed in an unsigned 32-bit integer from the lower to upper
bits. For instance, for 4-bit quantization we fit 8 elements in an
unsigned 32 bit integer where the 1st element occupies the 4 least
significant bits, the 2nd bits 4-7 etc.
To dequantize the elements of ``w``, we also save :math:`s` and
:math:`\beta` which are the returned ``scales`` and
``biases`` respectively.
The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements
of ``w``. For ``mxfp4`` the group size must be ``32``. The elements
are quantized to 4-bit precision floating-point values (E2M1) with a
shared 8-bit scale per group. Unlike ``affine`` quantization,
``mxfp4`` does not have a bias value. More details on the format can
be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.
)pbdoc");
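(Not part of the diff: a sketch of the two return arities described above; ``mxfp4`` returns no biases and requires ``group_size=32``.)

    import mlx.core as mx

    w = mx.random.normal(shape=(64, 128))

    # affine: three outputs
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

    # mxfp4: two outputs
    w_q4, scales4 = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    w_hat4 = mx.dequantize(w_q4, scales4, group_size=32, bits=4, mode="mxfp4")

    # per-mode reconstruction error
    print((w - w_hat).abs().max(), (w - w_hat4).abs().max())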
m.def(
"dequantize",
&mx::dequantize,
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``group_size`` and ``bits`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
:math:`\beta` as follows
.. math::
w_i = s \hat{w_i} + \beta
Dequantize the matrix ``w`` using quantization parameters.
Args:
w (array): Matrix to be quantized
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
w (array): Matrix to be dequantized
scales (array): The scales to use per ``group_size`` elements of ``w``.
biases (array, optional): The biases to use per ``group_size``
elements of ``w``. Default: ``None``.
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns:
array: The dequantized version of ``w``
Notes:
The currently supported quantization modes are ``"affine"`` and ``"mxfp4"``.
For ``affine`` quantization, given the notation in :func:`quantize`,
we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s`
and :math:`\beta` as follows
.. math::
w_i = s \hat{w_i} + \beta
)pbdoc");
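(Not part of the diff: a tiny worked affine example following the formula above; with one group of exactly representable values, the round trip is lossless up to float rounding.)

    import mlx.core as mx

    # One row, one group of 32 values drawn from {0, 1, 2, 3}:
    # alpha = 3, beta = 0, s = 3 / 15 = 0.2, w_hat = round(w / s) in {0, 5, 10, 15}.
    w = mx.array([[0.0, 1.0, 2.0, 3.0] * 8])
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
    w_rec = mx.dequantize(w_q, scales, biases, group_size=32, bits=4)
    print(mx.allclose(w, w_rec))   # expected: True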
m.def(
"gather_qmm",
@@ -4284,17 +4306,18 @@ void init_ops(nb::module_& m) {
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@@ -4310,7 +4333,8 @@ void init_ops(nb::module_& m) {
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
biases (array, optional): The biases to use per ``group_size``
elements of ``w``. Default: ``None``.
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
@@ -4320,6 +4344,7 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
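(Not part of the diff: a sketch of ``gather_qmm`` selecting one quantized expert per input row via ``rhs_indices``, mirroring the new mxfp4 cases in the tests below; shapes and indices are illustrative.)

    import mlx.core as mx

    E, N, K = 4, 64, 128                       # experts, output dim, input dim
    x = mx.random.normal(shape=(2, 1, K))
    w = mx.random.normal(shape=(E, N, K))

    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    rhs_indices = mx.array([3, 1])             # expert id for each row of x
    y = mx.gather_qmm(
        x, w_q, scales,
        rhs_indices=rhs_indices,
        transpose=True,
        group_size=32,
        bits=4,
        mode="mxfp4",
    )                                          # y: (2, 1, N)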


@@ -48,6 +48,8 @@ cuda_skip = {
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_qmm_vjp",
"TestQuantized.test_qmv",
"TestQuantized.test_mxfp4_qmv",
"TestQuantized.test_mxfp4_qvm",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_small_matrix",


@@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
nn.quantize(m, group_size=32, mode="mxfp4")
self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
self.assertTrue(isinstance(m.layers[2].scales, mx.array))
def test_quantize_freeze(self):
lin = nn.Linear(512, 512)
qlin = lin.to_quantized()


@@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
self.assertTrue(mx.all(a_hat == 0))
def test_mxfp4_quantize_dequantize(self):
lut = mx.array(
[
+0.0,
+0.5,
+1.0,
+1.5,
+2.0,
+3.0,
+4.0,
+6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
)
w = lut[mx.random.randint(0, 16, shape=(128, 512))]
w = w.reshape(-1, 32)
w[:, 0] = 6
w = (w + 3e-6).astype(mx.bfloat16)
# Invalid bits / group size
with self.assertRaises(ValueError):
mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
with self.assertRaises(ValueError):
mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
self.assertTrue(mx.all(w_hat == 0))
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -168,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_mxfp4_qmv(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[256, 512, 67], # M
[64, 128], # N
[0, 1, 3, 8], # B
)
for M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=32):
x_shape = (3, 1, N) if B == 0 else (B, 1, N)
w_shape = (M, N) if B == 0 else (B, M, N)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
y_q = mx.quantized_matmul(
x,
w_q,
scales,
transpose=True,
group_size=32,
mode="mxfp4",
)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -233,6 +311,103 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_mxfp4_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
# Add a splitk
tests = list(tests)
tests.append((128, 16384, 0))
for M, N, B in tests:
with self.subTest(shape=(B, M, N)):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
y_q = mx.quantized_matmul(
x,
w_q,
scales,
transpose=False,
group_size=32,
mode="mxfp4",
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_mode_error_cases(self):
w = mx.random.normal(shape=(256, 256))
x = mx.random.normal(shape=(1, 256))
# Invalid mode
with self.assertRaises(ValueError):
mx.quantize(w, mode="xyz")
wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
with self.assertRaises(ValueError):
mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
with self.assertRaises(ValueError):
mx.quantized_matmul(
x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
)
rhs_indices = mx.array(0)
with self.assertRaises(ValueError):
mx.gather_qmm(
x,
wq,
scales,
biases,
rhs_indices=rhs_indices,
bits=4,
group_size=32,
mode="xyz",
)
# Only quantize floating point types
with self.assertRaises(ValueError):
mx.quantize(mx.zeros((128, 128), mx.int32))
with self.assertRaises(ValueError):
mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")
# Must have bias for affine
with self.assertRaises(ValueError):
mx.dequantize(wq, scales, None, bits=4, group_size=32)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
with self.assertRaises(ValueError):
mx.gather_qmm(
x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
)
# Must be floating point
x = mx.zeros(shape=(256,), dtype=mx.int32)
scales = mx.zeros(scales.shape, dtype=mx.int32)
biases = mx.zeros(scales.shape, dtype=mx.int32)
with self.assertRaises(ValueError):
mx.dequantize(wq, scales, biases, bits=4, group_size=32)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
with self.assertRaises(ValueError):
mx.gather_qmm(
x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))
@@ -360,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -379,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=True,
group_size=64,
bits=4,
mode="affine",
):
with self.subTest(
M=M,
@@ -392,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
):
x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
w = mx.random.normal(
shape=batch_B + ((N, K) if transpose else (K, N))
).astype(dtype)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)
if lhs_indices is not None:
lhs_indices = mx.array(lhs_indices)
@@ -415,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
inputs = (
@@ -460,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
"group_size": 32,
"mode": "mxfp4",
},
)
for kwargs in inputs:
@@ -503,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -525,19 +719,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [
# L, K, D, E, I, transpose
(32, 512, 512, 4, 2, True),
(32, 512, 544, 4, 2, True),
(133, 512, 512, 4, 2, True),
(133, 512, 555, 4, 2, True),
(133, 512, 512, 4, 2, True),
(64, 512, 512, 4, 2, False),
(64, 512, 544, 4, 2, False),
(133, 512, 512, 4, 2, False),
(133, 512, 544, 4, 2, False),
(133, 512, 555, 4, 2, False),
(64, 512, 512, 4, 2, False),
(32, 512, 512, 4, 2, True, "affine"),
(32, 512, 544, 4, 2, True, "mxfp4"),
(133, 512, 512, 4, 2, True, "affine"),
(133, 512, 555, 4, 2, True, "affine"),
(133, 512, 512, 4, 2, True, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
(64, 512, 544, 4, 2, False, "mxfp4"),
(133, 512, 512, 4, 2, False, "affine"),
(133, 512, 544, 4, 2, False, "affine"),
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]
for L, K, D, E, I, transpose in parameters:
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
@@ -546,14 +744,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)