From 57fe918cf843a9ef66acd8a9c94ab26465ff15c5 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 20 Dec 2023 14:17:38 -0800 Subject: [PATCH] Adds C++ and nn quantization utilities (#230) * Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module --- docs/src/python/nn/layers.rst | 1 + docs/src/python/ops.rst | 3 + mlx/ops.cpp | 107 ++++++++++++++++++++++++++ mlx/ops.h | 16 ++++ python/mlx/nn/layers/__init__.py | 1 + python/mlx/nn/layers/base.py | 38 +++++++++ python/mlx/nn/layers/quantized.py | 124 ++++++++++++++++++++++++++++++ python/mlx/optimizers.py | 14 ++-- python/mlx/utils.py | 20 +++-- python/src/ops.cpp | 97 +++++++++++++++++++++++ python/tests/test_quantized.py | 80 ++++++------------- tests/ops_tests.cpp | 18 +++++ 12 files changed, 451 insertions(+), 68 deletions(-) create mode 100644 python/mlx/nn/layers/quantized.py diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index 5628134d6..fab3ff785 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -26,3 +26,4 @@ Layers RoPE MultiHeadAttention Sequential + QuantizedLinear diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 731300047..7e391ec4c 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -34,6 +34,7 @@ Operations conv2d cos cosh + dequantize divide equal erf @@ -73,6 +74,8 @@ Operations partition pad prod + quantize + quantized_matmul reciprocal reshape round diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 88fb30f53..8aae5596d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2649,4 +2649,111 @@ array quantized_matmul( return out; } +std::tuple quantize( + const array& w, + int groups /* = 128 */, + int width /* = 4 */, + StreamOrDevice s /* = {} */) { + if (w.ndim() != 2) { + throw std::invalid_argument("[quantize] Only matrices supported for now"); + } + + if ((w.shape(0) % 32) != 0) { + throw std::invalid_argument( + "[quantize] All dimensions should be divisible by 32 for now"); + } + + if ((w.shape(-1) % groups) != 0) { + std::ostringstream msg; + msg << "[quantize] The last dimension of the matrix needs to be divisible by " + << "the quantization group size " << groups + << ". 
However the provided matrix" + << " has shape " << w.shape(); + throw std::invalid_argument(msg.str()); + } + + // Compute some constants used for the quantization + int n_bins = (1 << width) - 1; // 2**width - 1 + int el_per_int = 32 / width; + array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s); + shifts = reshape(shifts, {1, 1, -1}, s); + + // Compute scales and biases + array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s); + array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s); + array scales = squeeze(delta, -1, s); + array biases = squeeze(w_min, -1, s); + + // Quantize and pack w + packed_w = + astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32); + packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s); + packed_w = sum( + multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + + return std::make_tuple(packed_w, scales, biases); +} + +array dequantize( + const array& w, + const array& scales, + const array& biases, + int groups /* = 128 */, + int width /* = 4 */, + StreamOrDevice s /* = {} */) { + if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) { + throw std::invalid_argument("[dequantize] Only matrices supported for now"); + } + + if ((w.shape(0) % 32) != 0) { + throw std::invalid_argument( + "[dequantize] All dimensions should be divisible by 32 for now"); + } + + if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) { + throw std::invalid_argument( + "[dequantize] Shape of scales and biases does not match the matrix"); + } + + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + // Compute some constants for the dequantization + int el_per_int = 32 / width; + + if (w.shape(1) * el_per_int != scales.shape(1) * groups) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales and biases does not match the matrix " + << "given the quantization parameters. 
Provided matrix of shape " + << w.shape() << " and scales/biases of shape " << scales.shape() + << " with groups=" << groups << " and width=" << width << "."; + throw std::invalid_argument(msg.str()); + } + + // Extract the pieces from the passed quantized matrix + std::vector parts; + for (int start = 0; start < 32; start += width) { + // TODO: Implement bitwise operators for integral types + int shift_left = 32 - (start + width); + int shift_right = shift_left + start; + array p = multiply(w, array(1 << shift_left, uint32), s); + p = floor_divide(p, array(1 << shift_right, uint32), s); + p = expand_dims(p, -1, s); + parts.push_back(p); + } + array w_full = concatenate(parts, -1, s); + + // Dequantize + w_full = reshape(w_full, {w.shape(0), -1, groups}, s); + w_full = multiply(w_full, expand_dims(scales, -1, s), s); + w_full = add(w_full, expand_dims(biases, -1, s), s); + w_full = reshape(w_full, {w.shape(0), -1}, s); + + return w_full; +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 192ddb8d7..0c2c2916a 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1041,4 +1041,20 @@ array quantized_matmul( int width = 4, StreamOrDevice s = {}); +/** Quantize a matrix along its last axis */ +std::tuple quantize( + const array& w, + int groups = 128, + int width = 4, + StreamOrDevice s = {}); + +/** Dequantize a matrix produced by quantize() */ +array dequantize( + const array& w, + const array& scales, + const array& biases, + int groups = 128, + int width = 4, + StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 04557843b..aa22e495b 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -38,6 +38,7 @@ from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Linear from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding +from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( MultiHeadAttention, TransformerEncoder, diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 2d80553a1..dcf079457 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -258,6 +258,44 @@ class Module(dict): filter_fn = filter_fn or Module.valid_parameter_filter self.update(self.filter_and_map(filter_fn, map_fn)) + def update_modules(self, modules: dict): + """Replace the child modules of this :class:`Module` instance with the + provided ones in the dict of dicts and lists. + + It is the equivalent of :meth:`Module.update` but for modules instead + of parameters and allows us to flexibly edit complex architectures by + programmatically swapping layers. + + The passed in parameters dictionary need not be a full dictionary + similar to :meth:`parameters`. Only the provided locations will be + updated. + + Args: + modules (dict): A complete or partial dictionary of the modules + submodules. 
+ """ + + def apply(dst, modules): + if isinstance(modules, dict): + for k in modules: + if k in dst: + current_value = dst[k] + new_value = modules[k] + if self.is_module(current_value) and self.is_module(new_value): + dst[k] = new_value + elif isinstance(current_value, (dict, list)): + apply(current_value, new_value) + elif isinstance(modules, list): + for i in range(len(dst)): + current_value = dst[i] + new_value = modules[i] + if self.is_module(current_value) and self.is_module(new_value): + dst[i] = new_value + elif isinstance(current_value, (dict, list)): + apply(current_value, new_value) + + apply(self, modules) + def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]): """Apply a function to all the modules in this instance (including this instance). diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py new file mode 100644 index 000000000..e311049d3 --- /dev/null +++ b/python/mlx/nn/layers/quantized.py @@ -0,0 +1,124 @@ +# Copyright © 2023 Apple Inc. + +import math + +import mlx.core as mx +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear +from mlx.utils import tree_flatten, tree_map + + +class QuantizedLinear(Module): + """Applies an affine transformation to the input using a quantized weight matrix. + + It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its + parameters are frozen and will not be included in any gradient computation + but this will probably change in the future. + + QuantizedLinear also provides two useful classmethods to convert linear + layers to QuantizedLinear layers. + + - :meth:`from_linear` returns a QuantizedLinear layer that applies the same + linear transformation up to the quantization error. + - :meth:`quantize_module` swaps all the linear layers of the passed module + with QuantizedLinear ones. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool): If set to ``False`` then the layer will not use a bias. + (default: True). + groups (int): The group size to use for the quantized weight. See + :func:`~mlx.core.quantize`. (default: 128) + width (int): The bit width to use for the quantized weight. See + :func:`~mlx.core.quantize`. 
(default: 4) + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + groups: int = 64, + width: int = 4, + ): + super().__init__() + + # Quantization config + self.groups = groups + self.width = width + + # Initialize the quantized weight + scale = math.sqrt(1 / input_dims) + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, groups, width) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self): + out_dims, in_dims = self.weight.shape + in_dims *= 32 // self.width + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}," + f"groups={self.groups}, width={self.width}" + ) + + def __call__(self, x): + x = mx.quantized_matmul( + x, + self.weight.T, + scales=self.scales, + biases=self.biases, + groups=self.groups, + width=self.width, + ) + if "bias" in self: + x = x + self.bias + return x + + @classmethod + def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4): + """Create a QuantizedLinear layer from the parameters of a provided + linear layer.""" + output_dims, input_dims = linear_layer.weight.shape + ql = cls(input_dims, output_dims, False, groups, width) + ql.weight, ql.scales, ql.biases = mx.quantize( + linear_layer.weight, groups, width + ) + if "bias" in linear_layer: + ql.bias = linear_layer.bias + + return ql + + @classmethod + def quantize_module( + cls, + model: Module, + groups: int = 64, + width: int = 4, + linear_class_predicate=lambda m: isinstance(m, Linear), + ): + def _quantize_if_linear(m): + if linear_class_predicate(m): + return cls.from_linear(m, groups, width) + else: + return m + + leaves = model.leaf_modules() + leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module) + model.update_modules(leaves) diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py index 4fc2f6eed..17a16c459 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers.py @@ -445,13 +445,13 @@ class Adamax(Adam): class Lion(Optimizer): - r"""Implementation of the Lion optimizer [1]. + r"""Implementation of the Lion optimizer [1]. - Since updates are computed through the sign operation, they tend to - have larger norm than for other optimizers such as SGD and Adam. - We recommend a learning rate that is 3-10x smaller than AdamW and a - weight decay 3-10x larger than AdamW to maintain the strength - (lr * wd). Our Lion implementation follows the original paper. In + Since updates are computed through the sign operation, they tend to + have larger norm than for other optimizers such as SGD and Adam. + We recommend a learning rate that is 3-10x smaller than AdamW and a + weight decay 3-10x larger than AdamW to maintain the strength + (lr * wd). Our Lion implementation follows the original paper. In detail, [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. 
arXiv @@ -486,7 +486,7 @@ class Lion(Optimizer): def apply_single( self, gradient: mx.array, parameter: mx.array, state: OptimizerState ): - """Performs the Lion parameter update and stores :math:`m` + """Performs the Lion parameter update and stores :math:`m` in the optimizer state.""" lr = self.learning_rate b1, b2 = self.betas diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 1fb6bd498..8cb8e90c8 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. -def tree_map(fn, tree, *rest): +def tree_map(fn, tree, *rest, is_leaf=None): """Applies ``fn`` to the leaves of the python tree ``tree`` and returns a new collection with the results. @@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest): ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap` than to :func:`map`. + The keyword argument ``is_leaf`` decides what constitutes a leaf from + ``tree`` similar to :func:`tree_flatten`. + .. code-block:: python import mlx.nn as nn @@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest): fn (Callable): The function that processes the leaves of the tree tree (Any): The main python tree that will be iterated upon rest (Tuple[Any]): Extra trees to be iterated together with tree + is_leaf (Optional[Callable]): An optional callable that returns True if + the passed object is considered a leaf or False otherwise. Returns: A python tree with the new values returned by ``fn``. """ - if isinstance(tree, list): + if is_leaf is not None and is_leaf(tree): + return fn(tree, *rest) + elif isinstance(tree, list): return [ - tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) + tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) + for i, child in enumerate(tree) ] elif isinstance(tree, tuple): return tuple( - tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) + tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) + for i, child in enumerate(tree) ) elif isinstance(tree, dict): return { - k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items() + k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf) + for k, child in tree.items() } else: return fn(tree, *rest) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 3c4bdc018..627dd9a80 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) { Returns: result (array): The result of the multiplication of ``x`` with ``w``. )pbdoc"); + m.def( + "quantize", + &quantize, + "w"_a, + py::pos_only(), + "groups"_a = 128, + "width"_a = 4, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array] + + Quantize the matrix ``w`` using ``width`` bits per element. + + Note, every ``groups`` elements in a row of ``w`` are quantized + together. Hence, number of columns of ``w`` should be divisible by + ``groups``. In particular, the rows of ``w`` are divided into groups of + size ``groups`` which are quantized together. + + .. warning:: + + ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32 + + Formally, for a group of :math:`g` consecutive elements :math:`w_1` to + :math:`w_g` in a row of ``w`` we compute the quantized representation + of each element :math:`\hat{w_i}` as follows + + .. 
math:: + + \begin{aligned} + \alpha &= \max_i w_i \\ + \beta &= \min_i w_i \\ + s &= \frac{\alpha - \beta}{2^b - 1} \\ + \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). + \end{aligned} + + After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits + and is packed in an unsigned 32-bit integer from the lower to upper + bits. For instance, for 4-bit quantization we fit 8 elements in an + unsigned 32 bit integer where the 1st element occupies the 4 least + significant bits, the 2nd bits 4-7 etc. + + In order to be able to dequantize the elements of ``w`` we also need to + save :math:`s` and :math:`\beta` which are the returned ``scales`` and + ``biases`` respectively. + + Args: + w (array): Matrix to be quantized + groups (int, optional): The size of the group in ``w`` that shares a + scale and bias. (default: 128) + width (int, optional): The bitwidth of the elements in ``w``. + (default: 4) + + Returns: + (tuple): A tuple containing + + - w_q (array): The quantized version of ``w`` + - scales (array): The scale to multiply each element with, namely :math:`s` + - biases (array): The biases to add to each element, namely :math:`\beta` + )pbdoc"); + m.def( + "dequantize", + &dequantize, + "w"_a, + py::pos_only(), + "scales"_a, + "biases"_a, + "groups"_a = 128, + "width"_a = 4, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array + + Dequantize the matrix ``w`` using the provided ``scales`` and + ``biases`` and the ``groups`` and ``width`` configuration. + + Formally, given the notation in :func:`quantize`, we compute + :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and + :math:`\beta` as follows + + .. math:: + + w_i = s \hat{w_i} - \beta + + Args: + w (array): Matrix to be quantized + scales (array): The scales to use per ``groups`` elements of ``w`` + biases (array): The biases to use per ``groups`` elements of ``w`` + groups (int, optional): The size of the group in ``w`` that shares a + scale and bias. (default: 128) + width (int, optional): The bitwidth of the elements in ``w``. 
+ (default: 4) + + Returns: + result (array): The dequantized version of w + )pbdoc"); } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 48493df26..72af0558c 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -6,48 +6,14 @@ import mlx.core as mx import mlx_tests -def select_bits(w, width, start): - shift_left = 32 - (start + width) - shift_right = shift_left + start - return (w * (2**shift_left)) // (2**shift_right) - - -def dequantize(w, scales, biases, width): - w_full = mx.concatenate( - [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1 - ) - w_full = w_full.reshape(len(w), scales.shape[-1], -1) - w_full = scales[..., None] * w_full + biases[..., None] - w_full = w_full.reshape(len(w), -1) - - return w_full - - -def quantize(w, width, groups): - w = w.reshape(len(w), -1, groups) - w_max = w.max(-1, keepdims=True) - w_min = w.min(-1, keepdims=True) - delta = (w_max - w_min) / (2**width - 1) - - w_int = mx.round((w - w_min) / delta).astype(mx.uint32) - scales = delta.squeeze(-1) - biases = w_min.squeeze(-1) - - shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32) - w_int = w_int.reshape(len(w), -1, 32 // width) - w_int = w_int * shifts[None, None] - packed_w = w_int.sum(-1) - - return packed_w, scales, biases - - class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 128)) - w_q, scales, biases = quantize(w, 4, 64) - w_hat = dequantize(w_q, scales, biases, 4) - w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4) - self.assertLess((w_hat - w_hat2).abs().max(), 1e-6) + for b in [2, 4, 8]: + w_q, scales, biases = mx.quantize(w, 64, b) + w_hat = mx.dequantize(w_q, scales, biases, 64, b) + errors = (w - w_hat).abs().reshape(*scales.shape, -1) + self.assertTrue((errors <= scales[..., None] / 2).all()) def test_qmm(self): key = mx.random.key(0) @@ -62,14 +28,16 @@ class TestQuantized(mlx_tests.MLXTestCase): ): x = mx.random.normal(shape=(M, K), key=k1) w = mx.random.normal(shape=(N, K), key=k2) - w_q, scales, biases = quantize(w, width, groups) - w_hat = dequantize(w_q, scales, biases, width) + w_q, scales, biases = mx.quantize(w, groups, width) + w_hat = mx.dequantize( + w_q, scales, biases, groups, width + ) y_q = mx.quantized_matmul( x, w_q.T, scales, biases, width=width, groups=groups ) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 0.1) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_qmm_shapes(self): key = mx.random.key(0) @@ -77,8 +45,8 @@ class TestQuantized(mlx_tests.MLXTestCase): groups = 64 width = 4 w = mx.random.normal(shape=(32, 128), key=k2) - w_q, scales, biases = quantize(w, width, groups) - w_hat = dequantize(w_q, scales, biases, width) + w_q, scales, biases = mx.quantize(w, groups, width) + w_hat = mx.dequantize(w_q, scales, biases, groups, width) for s in [(3, 128), (2, 1, 7, 128)]: x = mx.random.normal(shape=(3, 128), key=k1) y_q = mx.quantized_matmul( @@ -86,7 +54,7 @@ class TestQuantized(mlx_tests.MLXTestCase): ) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 0.1) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_qmv(self): key = mx.random.key(0) @@ -95,17 +63,17 @@ class TestQuantized(mlx_tests.MLXTestCase): for width in [2, 4, 8]: for M in [512, 1024]: for N in [512, 1024]: - # with self.subTest(shape=(M, N), groups=groups, width=width): - x = mx.random.normal(shape=(1, 
N), key=k1) - w = mx.random.normal(shape=(M, N), key=k2) - w_q, scales, biases = quantize(w, width, groups) - w_hat = dequantize(w_q, scales, biases, width) - y_q = mx.quantized_matmul( - x, w_q.T, scales, biases, width=width, groups=groups - ) - y_hat = x @ w_hat.T - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 0.1) + with self.subTest(shape=(M, N), groups=groups, width=width): + x = mx.random.normal(shape=(1, N), key=k1) + w = mx.random.normal(shape=(M, N), key=k2) + w_q, scales, biases = mx.quantize(w, groups, width) + w_hat = mx.dequantize(w_q, scales, biases, groups, width) + y_q = mx.quantized_matmul( + x, w_q.T, scales, biases, width=width, groups=groups + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) if __name__ == "__main__": diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 4e87ffc33..5bee4b5c8 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2215,3 +2215,21 @@ TEST_CASE("test linspace") { expected = array(std::initializer_list{}, {0}); CHECK(array_equal(x, expected).item()); } + +TEST_CASE("test quantize dequantize") { + auto x1 = ones({128, 1}); + auto x2 = expand_dims(arange(0, 128, float32), 0); + auto x = x1 * x2; + + for (int i = 2; i <= 8; i *= 2) { + int el_per_int = 32 / i; + auto [x_q, scales, biases] = quantize(x, 128, i); + CHECK_EQ(x_q.shape(), std::vector{128, 128 / el_per_int}); + CHECK_EQ(scales.shape(), std::vector{128, 1}); + CHECK_EQ(biases.shape(), std::vector{128, 1}); + + auto x_hat = dequantize(x_q, scales, biases, 128, i); + auto max_diff = max(abs(x - x_hat)).item(); + CHECK(max_diff <= 127.0 / (1 << i)); + } +}
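A minimal usage sketch of the new Python ops (mx.quantize, mx.dequantize, mx.quantized_matmul), following the signatures of the bindings and the updated tests above; the shapes and the printed tolerance are only illustrative.

import mlx.core as mx

# Weights must be 2D, with dimensions that are multiples of 32 and a last
# dimension divisible by `groups` (see the checks added in mlx/ops.cpp).
w = mx.random.normal(shape=(512, 1024))

# 4-bit quantization with 64 elements per group sharing one scale and bias.
w_q, scales, biases = mx.quantize(w, 64, 4)

# Approximate reconstruction; the per-element error is at most about half a
# quantization step (scales / 2), which is what the new Python test asserts.
w_hat = mx.dequantize(w_q, scales, biases, 64, 4)

# Multiply directly against the packed weights.
x = mx.random.normal(shape=(8, 1024))
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, groups=64, width=4)
y_ref = x @ w_hat.T
print((y_q - y_ref).abs().max())  # small; the tests use a 1e-3 tolerance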
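The packing layout itself — 32 // width quantized values per uint32, least-significant bits first, with one scale and bias per group of the last axis — can be spelled out in a few lines of Python. This sketch mirrors the reference helpers that the old test file used and the C++ logic in mlx/ops.cpp; it documents the format and is not the implementation the ops dispatch to.

import mlx.core as mx

def quantize_ref(w, groups=128, width=4):
    # One scale and bias per `groups` consecutive elements of each row.
    w_grouped = w.reshape(len(w), -1, groups)
    w_max = w_grouped.max(-1, keepdims=True)
    w_min = w_grouped.min(-1, keepdims=True)
    delta = (w_max - w_min) / (2**width - 1)
    w_int = mx.round((w_grouped - w_min) / delta).astype(mx.uint32)

    # Pack 32 // width quantized values into every uint32, low bits first.
    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
    w_int = w_int.reshape(len(w), -1, 32 // width)
    packed = (w_int * shifts[None, None]).sum(-1)
    return packed, delta.squeeze(-1), w_min.squeeze(-1)

def dequantize_ref(w_q, scales, biases, width=4):
    # Emulate bit shifts with multiply / floor-divide on uint32 (bitwise ops
    # on integer arrays are not implemented yet, per the TODO in ops.cpp).
    parts = []
    for start in range(0, 32, width):
        shift_left = 32 - (start + width)
        shift_right = shift_left + start
        parts.append(((w_q * (2**shift_left)) // (2**shift_right))[..., None])
    w_full = mx.concatenate(parts, axis=-1)
    w_full = w_full.reshape(len(w_q), scales.shape[-1], -1)
    w_full = scales[..., None] * w_full + biases[..., None]
    return w_full.reshape(len(w_q), -1)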
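On the nn side, the new QuantizedLinear can be constructed directly or used to post-quantize an existing float model. A short sketch; the toy Sequential model and the layer sizes are made up for illustration.

import mlx.core as mx
import mlx.nn as nn

# A standalone 4-bit linear layer with randomly initialized, frozen weights.
layer = nn.QuantizedLinear(1024, 1024, bias=True, groups=64, width=4)
y = layer(mx.random.normal(shape=(2, 1024)))

# Convert one existing layer, reusing its weights up to quantization error.
qlinear = nn.QuantizedLinear.from_linear(nn.Linear(1024, 1024), groups=64, width=4)

# Or swap every nn.Linear in a model for a QuantizedLinear, in place.
model = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(), nn.Linear(2048, 1024))
nn.QuantizedLinear.quantize_module(model, groups=64, width=4)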
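quantize_module is built on the two smaller additions in this patch: tree_map() now takes an is_leaf callable so whole sub-modules can be treated as leaves, and Module.update_modules() merges a (possibly partial) tree of modules back into a model. Roughly, the swap works as below; the two-layer model is hypothetical.

import mlx.nn as nn
from mlx.nn.layers.base import Module
from mlx.utils import tree_map

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))

# Stop recursion at module boundaries instead of descending into arrays.
swapped = tree_map(
    lambda m: nn.QuantizedLinear.from_linear(m) if isinstance(m, nn.Linear) else m,
    model.leaf_modules(),
    is_leaf=Module.is_module,
)

# Write the converted sub-modules back into the model, leaving everything
# else (and any locations not present in `swapped`) untouched.
model.update_modules(swapped)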