Add mode parameter for quantization (#2499)
* add mode parameter for quantization
* mxfp4 quantize/dequantize + start of optional biases
* mxfp4 works
* speedup
* cpu mxfp4
* fix
* fix test tol
* fix
* refactor
* add quant mode enum
@@ -39,6 +39,6 @@ class Embedding(Module):
         """
         return x @ self.weight.T

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)
@@ -70,9 +70,9 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode)


 class Bilinear(Module):
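For orientation, a minimal usage sketch of the new parameter (not part of the diff; note that ``mxfp4`` requires ``bits=4`` and ``group_size=32``, per the op docs and tests further down):

    import mlx.nn as nn

    lin = nn.Linear(512, 512)
    qlin_affine = lin.to_quantized(group_size=64, bits=4)  # mode="affine" by default
    qlin_mxfp4 = lin.to_quantized(group_size=32, bits=4, mode="mxfp4")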
@@ -12,6 +12,8 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    *,
+    mode: str = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -26,6 +28,8 @@ def quantize(
             :func:`mlx.core.quantize`). Default: ``64``.
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
             dict of params for `to_quantized` if it should be quantized and
@@ -39,7 +43,7 @@ def quantize(
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(group_size=group_size, bits=bits, mode=mode)
                 elif isinstance(bool_or_params, dict):
                     return m.to_quantized(**bool_or_params)
                 else:
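A sketch of the updated ``nn.quantize`` (not part of the diff): a plain ``True`` from the predicate picks up the top-level ``mode``, while a dict is forwarded verbatim to ``to_quantized``, so per-module settings are possible.

    import mlx.nn as nn

    model = nn.Sequential(nn.Embedding(1024, 256), nn.ReLU(), nn.Linear(256, 256))
    nn.quantize(model, group_size=32, mode="mxfp4")  # quantize everything with mxfp4

    # Per-module settings via a dict-returning predicate
    def predicate(path, module):
        if isinstance(module, nn.Linear):
            return {"group_size": 32, "bits": 4, "mode": "mxfp4"}
        return hasattr(module, "to_quantized")

    model = nn.Sequential(nn.Embedding(1024, 256), nn.ReLU(), nn.Linear(256, 256))
    nn.quantize(model, class_predicate=predicate)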
@@ -72,6 +76,8 @@ class QuantizedEmbedding(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """

     def __init__(
@@ -80,17 +86,23 @@ class QuantizedEmbedding(Module):
         dims: int,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            (self.scales,) = scales_biases
         self.num_embeddings = num_embeddings
         self.dims = dims
@@ -98,12 +110,14 @@ class QuantizedEmbedding(Module):
         self.freeze()

     def __call__(self, x):
+        biases = self.get("biases")
         return mx.dequantize(
             self["weight"][x],
             scales=self["scales"][x],
-            biases=self["biases"][x],
+            biases=biases[x] if biases is not None else None,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )

     def as_linear(self, x):
@@ -117,28 +131,40 @@ class QuantizedEmbedding(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )

     def _extra_repr(self):
         return (
             f"{self.num_embeddings}, {self.dims}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )

     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
-        ql = cls(embedding_dims, dims, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            embedding_layer.weight, group_size, bits
-        )
+        ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
+        ql.weight, *scales_biases = mx.quantize(
+            embedding_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
+        )
+        if mode == "affine":
+            ql.scales, ql.biases = scales_biases
+        else:
+            (ql.scales,) = scales_biases
         return ql
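A sketch of the converted embedding in use (not part of the diff), including the tied ``as_linear`` projection:

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(1000, 128)
    qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=32, bits=4, mode="mxfp4")

    tokens = mx.array([1, 42, 7])
    vecs = qemb(tokens)            # dequantized lookup, shape (3, 128)
    logits = qemb.as_linear(vecs)  # reuse the same quantized weight as an output projection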
@@ -161,6 +187,8 @@ class QuantizedLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """

     def __init__(
@@ -170,12 +198,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -184,7 +214,11 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            (self.scales,) = scales_biases

         # And bias if needed
         if bias:
@@ -198,7 +232,7 @@ class QuantizedLinear(Module):
         in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )

     def __call__(self, x):
@@ -206,23 +240,38 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
-        )
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
+        ql.weight, *scales_biases = mx.quantize(
+            linear_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
+        )
+        if mode == "affine":
+            ql.scales, ql.biases = scales_biases
+        else:
+            (ql.scales,) = scales_biases

         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
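The equivalent sketch for ``QuantizedLinear.from_linear`` (not part of the diff); with ``mode="mxfp4"`` the layer stores no per-group ``biases``, while a layer-level ``bias`` is carried over unchanged:

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(256, 128, bias=True)
    ql = nn.QuantizedLinear.from_linear(lin, group_size=32, bits=4, mode="mxfp4")

    x = mx.random.normal(shape=(8, 256))
    y = ql(x)  # shape (8, 128)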
@@ -4153,14 +4153,15 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4171,7 +4172,8 @@ void init_ops(nb::module_& m) {
           x (array): Input array
           w (array): Quantized matrix packed in unsigned integers
           scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
+          biases (array, optional): The biases to use per ``group_size``
+            elements of ``w``. Default: ``None``.
           transpose (bool, optional): Defines whether to multiply with the
             transposed ``w`` or not, namely whether we are performing
             ``x @ w.T`` or ``x @ w``. Default: ``True``.
@@ -4179,6 +4181,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
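A usage sketch for the updated binding (not part of the diff); ``affine`` still takes ``biases``, while ``mxfp4`` omits them:

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 512))
    w = mx.random.normal(shape=(256, 512))

    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)      # affine
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)  # x @ w.T

    w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")        # no biases
    y = mx.quantized_matmul(x, w_q, scales, transpose=True, group_size=32, mode="mxfp4")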
@@ -4189,10 +4192,11 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4203,30 +4207,11 @@ void init_ops(nb::module_& m) {

         .. warning::

-          ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
+          ``quantize`` currently only supports 2D inputs with the second
+          dimension divisible by ``group_size``

-        Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
-        :math:`w_g` in a row of ``w`` we compute the quantized representation
-        of each element :math:`\hat{w_i}` as follows
-
-        .. math::
-
-          \begin{aligned}
-            \alpha &= \max_i w_i \\
-            \beta &= \min_i w_i \\
-            s &= \frac{\alpha - \beta}{2^b - 1} \\
-            \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
-          \end{aligned}
-
-        After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
-        and is packed in an unsigned 32-bit integer from the lower to upper
-        bits. For instance, for 4-bit quantization we fit 8 elements in an
-        unsigned 32 bit integer where the 1st element occupies the 4 least
-        significant bits, the 2nd bits 4-7 etc.
-
-        In order to be able to dequantize the elements of ``w`` we also need to
-        save :math:`s` and :math:`\beta` which are the returned ``scales`` and
-        ``biases`` respectively.
+        The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They
+        are described in more detail below.

         Args:
           w (array): Matrix to be quantized
@@ -4234,49 +4219,86 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
-          tuple: A tuple containing
+          tuple: A tuple with either two or three elements containing:

           * w_q (array): The quantized version of ``w``
-          * scales (array): The scale to multiply each element with, namely :math:`s`
-          * biases (array): The biases to add to each element, namely :math:`\beta`
+          * scales (array): The quantization scales
+          * biases (array): The quantization biases (returned for ``mode=="affine"``).
+
+        Notes:
+          The ``affine`` mode quantizes groups of :math:`g` consecutive
+          elements in a row of ``w``. For each group the quantized
+          representation of each element :math:`\hat{w_i}` is computed as follows:
+
+          .. math::
+
+            \begin{aligned}
+              \alpha &= \max_i w_i \\
+              \beta &= \min_i w_i \\
+              s &= \frac{\alpha - \beta}{2^b - 1} \\
+              \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
+            \end{aligned}
+
+          After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
+          and is packed in an unsigned 32-bit integer from the lower to upper
+          bits. For instance, for 4-bit quantization we fit 8 elements in an
+          unsigned 32 bit integer where the 1st element occupies the 4 least
+          significant bits, the 2nd bits 4-7 etc.
+
+          To dequantize the elements of ``w``, we also save :math:`s` and
+          :math:`\beta` which are the returned ``scales`` and
+          ``biases`` respectively.
+
+          The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements
+          of ``w``. For ``mxfp4`` the group size must be ``32``. The elements
+          are quantized to 4-bit precision floating-point values (E2M1) with a
+          shared 8-bit scale per group. Unlike ``affine`` quantization,
+          ``mxfp4`` does not have a bias value. More details on the format can
+          be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.
       )pbdoc");
   m.def(
       "dequantize",
       &mx::dequantize,
       nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using the provided ``scales`` and
|
||||
``biases`` and the ``group_size`` and ``bits`` configuration.
|
||||
|
||||
Formally, given the notation in :func:`quantize`, we compute
|
||||
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
|
||||
:math:`\beta` as follows
|
||||
|
||||
.. math::
|
||||
|
||||
w_i = s \hat{w_i} + \beta
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
w (array): Matrix to be dequantized
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``.
|
||||
biases (array, optional): The biases to use per ``group_size``
|
||||
elements of ``w``. Default: ``None``.
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
mode (str, optional): The quantization mode. Default: ``"affine"``.
|
||||
|
||||
Returns:
|
||||
array: The dequantized version of ``w``
|
||||
|
||||
Notes:
|
||||
The currently supported quantization modes are ``"affine"`` and ``mxfp4``.
|
||||
|
||||
For ``affine`` quantization, given the notation in :func:`quantize`,
|
||||
we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s`
|
||||
and :math:`\beta` as follows
|
||||
|
||||
.. math::
|
||||
|
||||
w_i = s \hat{w_i} + \beta
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gather_qmm",
|
||||
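A round-trip sketch covering both modes (not part of the diff):

    import mlx.core as mx

    w = mx.random.normal(shape=(128, 256))

    # affine: three outputs, and biases are required to dequantize
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode="affine")

    # mxfp4: two outputs, and no biases are passed when dequantizing
    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")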
@@ -4284,17 +4306,18 @@ void init_ops(nb::module_& m) {
       nb::arg(),
      nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
       "lhs_indices"_a = nb::none(),
       "rhs_indices"_a = nb::none(),
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
@@ -4310,7 +4333,8 @@ void init_ops(nb::module_& m) {
           x (array): Input array
           w (array): Quantized matrix packed in unsigned integers
           scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
+          biases (array, optional): The biases to use per ``group_size``
+            elements of ``w``. Default: ``None``.
           lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
           rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
           transpose (bool, optional): Defines whether to multiply with the
@@ -4320,6 +4344,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.
           sorted_indices (bool, optional): May allow a faster implementation
             if the passed indices are sorted. Default: ``False``.
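A sketch of ``gather_qmm`` as a mixture-of-experts style matmul (not part of the diff; shapes follow the tests further down):

    import mlx.core as mx

    E, D, K, L = 4, 32, 64, 6  # experts, out dims, in dims, tokens
    x = mx.random.normal(shape=(L, 1, 1, K))
    w = mx.random.normal(shape=(E, D, K))
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
    indices = (mx.random.uniform(shape=(L, 1)) * E).astype(mx.uint32)

    y = mx.gather_qmm(
        x,
        w_q,
        scales,
        biases,
        rhs_indices=indices,
        transpose=True,
        group_size=32,
    )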
@@ -48,6 +48,8 @@ cuda_skip = {
     "TestQuantized.test_qmm_shapes",
     "TestQuantized.test_qmm_vjp",
     "TestQuantized.test_qmv",
+    "TestQuantized.test_mxfp4_qmv",
+    "TestQuantized.test_mxfp4_qvm",
     "TestQuantized.test_qvm",
     "TestQuantized.test_qvm_splitk",
     "TestQuantized.test_small_matrix",
@@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))

+        nn.quantize(m, group_size=32, mode="mxfp4")
+        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+        self.assertTrue(isinstance(m.layers[2].scales, mx.array))
+
     def test_quantize_freeze(self):
         lin = nn.Linear(512, 512)
         qlin = lin.to_quantized()
@@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
         a_hat = mx.dequantize(w_q, scales, biases, gs, b)
         self.assertTrue(mx.all(a_hat == 0))

+    def test_mxfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 32)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
+
+        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        self.assertTrue(mx.all(w_hat == 0))
+
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
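For reference, the LUT in the test above is exactly the 16 values representable in E2M1. The following hedged sketch (not part of the diff) decodes the packed format by hand; it assumes eight 4-bit codes per uint32, low bits first, and an E8M0 power-of-two scale per group of 32 as in the OCP MX spec — if the stored layout differs, the final comparison simply prints False:

    import mlx.core as mx

    lut = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

    w = mx.random.normal(shape=(64, 256))
    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

    # Unpack eight 4-bit E2M1 codes from each uint32
    shifts = mx.arange(8, dtype=mx.uint32) * 4
    codes = ((w_q[..., None] >> shifts) & 15).reshape(64, -1)
    # Decode the assumed E8M0 group scales and rescale the LUT values
    group_scales = mx.power(2.0, scales.astype(mx.float32) - 127)
    manual = (lut[codes].reshape(64, -1, 32) * group_scales[..., None]).reshape(64, -1)

    print(mx.allclose(manual, mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")))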
@@ -168,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)

+    def test_mxfp4_qmv(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [256, 512, 67],  # M
+            [64, 128],  # N
+            [0, 1, 3, 8],  # B
+        )
+        for M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=32):
+                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
+                w_shape = (M, N) if B == 0 else (B, M, N)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
+                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    transpose=True,
+                    group_size=32,
+                    mode="mxfp4",
+                )
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
     def test_qvm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
@@ -233,6 +311,103 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 2e-3)

+    def test_mxfp4_qvm(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [32, 128, 256],  # M
+            [128, 256, 67],  # N
+            [0, 1, 3, 8],  # B
+        )
+        # Add a splitk
+        tests = list(tests)
+        tests.append((128, 16384, 0))
+
+        for M, N, B in tests:
+            with self.subTest(shape=(B, M, N)):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
+                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    transpose=False,
+                    group_size=32,
+                    mode="mxfp4",
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
+    def test_mode_error_cases(self):
+        w = mx.random.normal(shape=(256, 256))
+        x = mx.random.normal(shape=(1, 256))
+
+        # Invalid mode
+        with self.assertRaises(ValueError):
+            mx.quantize(w, mode="xyz")
+
+        wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(
+                x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
+            )
+
+        rhs_indices = mx.array(0)
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x,
+                wq,
+                scales,
+                biases,
+                rhs_indices=rhs_indices,
+                bits=4,
+                group_size=32,
+                mode="xyz",
+            )
+
+        # Only quantize floating point types
+        with self.assertRaises(ValueError):
+            mx.quantize(mx.zeros((128, 128), mx.int32))
+
+        with self.assertRaises(ValueError):
+            mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")
+
+        # Must have bias for affine
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, None, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
+            )
+
+        # Must be floating point
+        x = mx.zeros(shape=(256,), dtype=mx.int32)
+        scales = mx.zeros(scales.shape, dtype=mx.int32)
+        biases = mx.zeros(scales.shape, dtype=mx.int32)
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, biases, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
+            )
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
@@ -360,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_gather_qmm(self):
-        def quantize(w, transpose=True, group_size=64, bits=4):
-            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
-            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+        def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
+            if mode == "affine":
+                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+            else:
+                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+                b = None
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
             if transpose:
                 w_hat = w_hat.swapaxes(-1, -2)
             return w_hat, qw, s, b
@@ -379,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
             transpose=True,
             group_size=64,
             bits=4,
+            mode="affine",
         ):
             with self.subTest(
                 M=M,
@@ -392,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 transpose=transpose,
                 group_size=group_size,
                 bits=bits,
+                mode=mode,
             ):
                 x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                 w = mx.random.normal(
                     shape=batch_B + ((N, K) if transpose else (K, N))
                 ).astype(dtype)
-                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
+                w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)

                 if lhs_indices is not None:
                     lhs_indices = mx.array(lhs_indices)
@@ -415,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
                     transpose=transpose,
                     group_size=group_size,
                     bits=bits,
+                    mode=mode,
                 )

                 self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

         inputs = (
@@ -460,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 "batch_B": (4, 1),
                 "rhs_indices": ((2,), (0,), (1,)),
             },
+            {
+                "batch_A": (1,),
+                "lhs_indices": (0,),
+                "batch_B": (3,),
+                "rhs_indices": (2, 1),
+                "group_size": 32,
+                "mode": "mxfp4",
+            },
         )

         for kwargs in inputs:
@@ -503,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

     def test_gather_qmm_sorted(self):
-        def quantize(w, transpose=True, group_size=64, bits=4):
-            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
-            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+        def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
+            if mode == "affine":
+                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+            else:
+                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+                b = None
+
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
             if transpose:
                 w_hat = w_hat.swapaxes(-1, -2)
             return w_hat, qw, s, b
@@ -525,19 +719,23 @@ class TestQuantized(mlx_tests.MLXTestCase):

         parameters = [
             # L, K, D, E, I, transpose
-            (32, 512, 512, 4, 2, True),
-            (32, 512, 544, 4, 2, True),
-            (133, 512, 512, 4, 2, True),
-            (133, 512, 555, 4, 2, True),
-            (133, 512, 512, 4, 2, True),
-            (64, 512, 512, 4, 2, False),
-            (64, 512, 544, 4, 2, False),
-            (133, 512, 512, 4, 2, False),
-            (133, 512, 544, 4, 2, False),
-            (133, 512, 555, 4, 2, False),
-            (64, 512, 512, 4, 2, False),
+            (32, 512, 512, 4, 2, True, "affine"),
+            (32, 512, 544, 4, 2, True, "mxfp4"),
+            (133, 512, 512, 4, 2, True, "affine"),
+            (133, 512, 555, 4, 2, True, "affine"),
+            (133, 512, 512, 4, 2, True, "affine"),
+            (64, 512, 512, 4, 2, False, "affine"),
+            (64, 512, 544, 4, 2, False, "mxfp4"),
+            (133, 512, 512, 4, 2, False, "affine"),
+            (133, 512, 544, 4, 2, False, "affine"),
+            (133, 512, 555, 4, 2, False, "affine"),
+            (64, 512, 512, 4, 2, False, "affine"),
         ]
-        for L, K, D, E, I, transpose in parameters:
+        for L, K, D, E, I, transpose, mode in parameters:
+            if mode == "mxfp4":
+                group_size = 32
+            else:
+                group_size = 64
             K, D = (K, D) if transpose else (D, K)
             ishape = (L, I)
             xshape = (L, 1, 1, K)
@@ -546,14 +744,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
             indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
             x = mx.random.normal(xshape) / K**0.5
             w = mx.random.normal(wshape) / K**0.5
-            w, *wq = quantize(w, transpose=transpose)
+            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)

             y1 = mx.gather_mm(x, w, rhs_indices=indices)
-            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
+            y2 = mx.gather_qmm(
+                x,
+                *wq,
+                group_size=group_size,
+                mode=mode,
+                transpose=transpose,
+                rhs_indices=indices
+            )
             xs, idx, inv_order = gather_sort(x, indices)
             y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)

             y4 = mx.gather_qmm(
-                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
+                xs,
+                *wq,
+                group_size=group_size,
+                mode=mode,
+                rhs_indices=idx,
+                transpose=transpose,
+                sorted_indices=True
             )
             y3 = scatter_unsort(y3, inv_order, indices.shape)
             y4 = scatter_unsort(y4, inv_order, indices.shape)