	Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
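The snippet below is a rough usage sketch of the pieces this commit adds (the small ``MLP`` class is hypothetical, and it assumes ``QuantizedLinear`` is re-exported under ``mlx.nn`` as the import change below suggests):

    import mlx.core as mx
    import mlx.nn as nn

    # Core ops: pack a float matrix into 4-bit elements, 64 per scale/bias group
    w = mx.random.normal(shape=(128, 256))
    w_q, scales, biases = mx.quantize(w, groups=64, width=4)
    w_hat = mx.dequantize(w_q, scales, biases, groups=64, width=4)

    # Multiply directly against the packed weights
    x = mx.random.normal(shape=(2, 256))
    y = mx.quantized_matmul(x, w_q.T, scales, biases, groups=64, width=4)

    # nn utilities: swap every Linear in a model for a QuantizedLinear, in place
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(256, 512)
            self.l2 = nn.Linear(512, 256)

        def __call__(self, x):
            return self.l2(mx.maximum(self.l1(x), 0))

    model = MLP()
    nn.QuantizedLinear.quantize_module(model, groups=64, width=4)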
This commit is contained in:
Angelos Katharopoulos, committed by GitHub
parent 4912ff3ec2
commit 57fe918cf8
			| @@ -38,6 +38,7 @@ from mlx.nn.layers.embedding import Embedding | ||||
| from mlx.nn.layers.linear import Linear | ||||
| from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm | ||||
| from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding | ||||
| from mlx.nn.layers.quantized import QuantizedLinear | ||||
| from mlx.nn.layers.transformer import ( | ||||
|     MultiHeadAttention, | ||||
|     TransformerEncoder, | ||||
|   | ||||
| @@ -258,6 +258,44 @@ class Module(dict): | ||||
|         filter_fn = filter_fn or Module.valid_parameter_filter | ||||
|         self.update(self.filter_and_map(filter_fn, map_fn)) | ||||
|  | ||||
|     def update_modules(self, modules: dict): | ||||
|         """Replace the child modules of this :class:`Module` instance with the | ||||
|         provided ones in the dict of dicts and lists. | ||||
|  | ||||
|         It is the equivalent of :meth:`Module.update` but for modules instead | ||||
|         of parameters and allows us to flexibly edit complex architectures by | ||||
|         programmatically swapping layers. | ||||
|  | ||||
|         The passed in modules dictionary need not be a full dictionary | ||||
|         similar to :meth:`parameters`. Only the provided locations will be | ||||
|         updated. | ||||
|  | ||||
|         Args: | ||||
|             modules (dict): A complete or partial dictionary of the module's | ||||
|                 submodules. | ||||
|         """ | ||||
|  | ||||
|         def apply(dst, modules): | ||||
|             if isinstance(modules, dict): | ||||
|                 for k in modules: | ||||
|                     if k in dst: | ||||
|                         current_value = dst[k] | ||||
|                         new_value = modules[k] | ||||
|                         if self.is_module(current_value) and self.is_module(new_value): | ||||
|                             dst[k] = new_value | ||||
|                         elif isinstance(current_value, (dict, list)): | ||||
|                             apply(current_value, new_value) | ||||
|             elif isinstance(modules, list): | ||||
|                 for i in range(len(dst)): | ||||
|                     current_value = dst[i] | ||||
|                     new_value = modules[i] | ||||
|                     if self.is_module(current_value) and self.is_module(new_value): | ||||
|                         dst[i] = new_value | ||||
|                     elif isinstance(current_value, (dict, list)): | ||||
|                         apply(current_value, new_value) | ||||
|  | ||||
|         apply(self, modules) | ||||
|  | ||||
|     def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]): | ||||
|         """Apply a function to all the modules in this instance (including this | ||||
|         instance). | ||||
|   | ||||
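For illustration, ``update_modules`` with a partial dictionary might be used like this (the ``Model`` class is hypothetical; only the listed children are swapped, which is what ``QuantizedLinear.quantize_module`` later in this commit relies on):

    import mlx.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(64, 64)
            self.head = nn.Linear(64, 10)

        def __call__(self, x):
            return self.head(self.proj(x))

    m = Model()
    # Replace only the projection layer; `head` is left untouched
    m.update_modules({"proj": nn.QuantizedLinear.from_linear(m.proj)})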
							
								
								
									
python/mlx/nn/layers/quantized.py (new file, 124 lines)
							| @@ -0,0 +1,124 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import math | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.nn.layers.base import Module | ||||
| from mlx.nn.layers.linear import Linear | ||||
| from mlx.utils import tree_flatten, tree_map | ||||
|  | ||||
|  | ||||
| class QuantizedLinear(Module): | ||||
|     """Applies an affine transformation to the input using a quantized weight matrix. | ||||
|  | ||||
|     It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its | ||||
|     parameters are frozen and will not be included in any gradient computation, | ||||
|     but this will probably change in the future. | ||||
|  | ||||
|     QuantizedLinear also provides two useful classmethods to convert linear | ||||
|     layers to QuantizedLinear layers. | ||||
|  | ||||
|     - :meth:`from_linear` returns a QuantizedLinear layer that applies the same | ||||
|       linear transformation up to the quantization error. | ||||
|     - :meth:`quantize_module` swaps all the linear layers of the passed module | ||||
|       with QuantizedLinear ones. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features | ||||
|         output_dims (int): The dimensionality of the output features | ||||
|         bias (bool): If set to ``False`` then the layer will not use a bias. | ||||
|             (default: True). | ||||
|         groups (int): The group size to use for the quantized weight. See | ||||
|             :func:`~mlx.core.quantize`. (default: 64) | ||||
|         width (int): The bit width to use for the quantized weight. See | ||||
|             :func:`~mlx.core.quantize`. (default: 4) | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_dims: int, | ||||
|         output_dims: int, | ||||
|         bias: bool = True, | ||||
|         groups: int = 64, | ||||
|         width: int = 4, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # Quantization config | ||||
|         self.groups = groups | ||||
|         self.width = width | ||||
|  | ||||
|         # Initialize the quantized weight | ||||
|         scale = math.sqrt(1 / input_dims) | ||||
|         weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims, input_dims), | ||||
|         ) | ||||
|         self.weight, self.scales, self.biases = mx.quantize(weight, groups, width) | ||||
|  | ||||
|         # And bias if needed | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((output_dims,)) | ||||
|  | ||||
|         # Freeze this model's parameters | ||||
|         self.freeze() | ||||
|  | ||||
|     def unfreeze(self, *args, **kwargs): | ||||
|         """Wrap unfreeze so that we unfreeze any layers we might contain but | ||||
|         our parameters will remain frozen.""" | ||||
|         super().unfreeze(*args, **kwargs) | ||||
|         self.freeze(recurse=False) | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         in_dims *= 32 // self.width | ||||
|         return ( | ||||
|             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}," | ||||
|             f"groups={self.groups}, width={self.width}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         x = mx.quantized_matmul( | ||||
|             x, | ||||
|             self.weight.T, | ||||
|             scales=self.scales, | ||||
|             biases=self.biases, | ||||
|             groups=self.groups, | ||||
|             width=self.width, | ||||
|         ) | ||||
|         if "bias" in self: | ||||
|             x = x + self.bias | ||||
|         return x | ||||
|  | ||||
|     @classmethod | ||||
|     def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4): | ||||
|         """Create a QuantizedLinear layer from the parameters of a provided | ||||
|         linear layer.""" | ||||
|         output_dims, input_dims = linear_layer.weight.shape | ||||
|         ql = cls(input_dims, output_dims, False, groups, width) | ||||
|         ql.weight, ql.scales, ql.biases = mx.quantize( | ||||
|             linear_layer.weight, groups, width | ||||
|         ) | ||||
|         if "bias" in linear_layer: | ||||
|             ql.bias = linear_layer.bias | ||||
|  | ||||
|         return ql | ||||
|  | ||||
|     @classmethod | ||||
|     def quantize_module( | ||||
|         cls, | ||||
|         model: Module, | ||||
|         groups: int = 64, | ||||
|         width: int = 4, | ||||
|         linear_class_predicate=lambda m: isinstance(m, Linear), | ||||
|     ): | ||||
|         def _quantize_if_linear(m): | ||||
|             if linear_class_predicate(m): | ||||
|                 return cls.from_linear(m, groups, width) | ||||
|             else: | ||||
|                 return m | ||||
|  | ||||
|         leaves = model.leaf_modules() | ||||
|         leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module) | ||||
|         model.update_modules(leaves) | ||||
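A quick sanity check of ``from_linear`` (a sketch, assuming default normal-initialized weights): the converted layer should match the float layer up to the quantization error.

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(256, 64)
    qlin = nn.QuantizedLinear.from_linear(lin, groups=64, width=4)

    x = mx.random.normal(shape=(8, 256))
    # Small but generally nonzero difference due to 4-bit rounding
    print((lin(x) - qlin(x)).abs().max())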
| @@ -445,13 +445,13 @@ class Adamax(Adam): | ||||
|  | ||||
|  | ||||
| class Lion(Optimizer): | ||||
|     r"""Implementation of the Lion optimizer [1].  | ||||
|     r"""Implementation of the Lion optimizer [1]. | ||||
|  | ||||
|     Since updates are computed through the sign operation, they tend to  | ||||
|     have larger norm than for other optimizers such as SGD and Adam.  | ||||
|     We recommend a learning rate that is 3-10x smaller than AdamW and a  | ||||
|     weight decay 3-10x larger than AdamW to maintain the strength  | ||||
|     (lr * wd). Our Lion implementation follows the original paper. In  | ||||
|     Since updates are computed through the sign operation, they tend to | ||||
|     have larger norm than for other optimizers such as SGD and Adam. | ||||
|     We recommend a learning rate that is 3-10x smaller than AdamW and a | ||||
|     weight decay 3-10x larger than AdamW to maintain the strength | ||||
|     (lr * wd). Our Lion implementation follows the original paper. In | ||||
|     detail, | ||||
|  | ||||
|     [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv | ||||
| @@ -486,7 +486,7 @@ class Lion(Optimizer): | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|         """Performs the Lion parameter update and stores :math:`m`  | ||||
|         """Performs the Lion parameter update and stores :math:`m` | ||||
|         in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         b1, b2 = self.betas | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
|  | ||||
| def tree_map(fn, tree, *rest): | ||||
| def tree_map(fn, tree, *rest, is_leaf=None): | ||||
|     """Applies ``fn`` to the leaves of the python tree ``tree`` and | ||||
|     returns a new collection with the results. | ||||
|  | ||||
| @@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest): | ||||
|     ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap` | ||||
|     than to :func:`map`. | ||||
|  | ||||
|     The keyword argument ``is_leaf`` decides what constitutes a leaf from | ||||
|     ``tree`` similar to :func:`tree_flatten`. | ||||
|  | ||||
|     .. code-block:: python | ||||
|  | ||||
|         import mlx.nn as nn | ||||
| @@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest): | ||||
|         fn (Callable): The function that processes the leaves of the tree | ||||
|         tree (Any): The main python tree that will be iterated upon | ||||
|         rest (Tuple[Any]): Extra trees to be iterated together with tree | ||||
|         is_leaf (Optional[Callable]): An optional callable that returns True if | ||||
|             the passed object is considered a leaf or False otherwise. | ||||
|  | ||||
|     Returns: | ||||
|         A python tree with the new values returned by ``fn``. | ||||
|     """ | ||||
|     if isinstance(tree, list): | ||||
|     if is_leaf is not None and is_leaf(tree): | ||||
|         return fn(tree, *rest) | ||||
|     elif isinstance(tree, list): | ||||
|         return [ | ||||
|             tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) | ||||
|             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) | ||||
|             for i, child in enumerate(tree) | ||||
|         ] | ||||
|     elif isinstance(tree, tuple): | ||||
|         return tuple( | ||||
|             tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) | ||||
|             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) | ||||
|             for i, child in enumerate(tree) | ||||
|         ) | ||||
|     elif isinstance(tree, dict): | ||||
|         return { | ||||
|             k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items() | ||||
|             k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf) | ||||
|             for k, child in tree.items() | ||||
|         } | ||||
|     else: | ||||
|         return fn(tree, *rest) | ||||
|   | ||||
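A small sketch of the new ``is_leaf`` hook with plain Python values; ``QuantizedLinear.quantize_module`` uses the same mechanism with ``is_leaf=Module.is_module`` to treat whole submodules as leaves:

    from mlx.utils import tree_map

    tree = {"a": [1, 2], "b": {"c": 3}}

    # Default: fn is applied to the scalar leaves 1, 2 and 3
    print(tree_map(lambda x: 2 * x, tree))
    # {'a': [2, 4], 'b': {'c': 6}}

    # With is_leaf, whole lists count as leaves and reach fn unsplit
    print(tree_map(lambda x: 2 * x, tree, is_leaf=lambda n: isinstance(n, list)))
    # {'a': [1, 2, 1, 2], 'b': {'c': 6}}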
| @@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) { | ||||
|         Returns: | ||||
|           result (array): The result of the multiplication of ``x`` with ``w``. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "quantize", | ||||
|       &quantize, | ||||
|       "w"_a, | ||||
|       py::pos_only(), | ||||
|       "groups"_a = 128, | ||||
|       "width"_a = 4, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array] | ||||
|  | ||||
|         Quantize the matrix ``w`` using ``width`` bits per element. | ||||
|  | ||||
|         Note that every ``groups`` consecutive elements in a row of ``w`` are | ||||
|         quantized together. Hence, the number of columns of ``w`` should be | ||||
|         divisible by ``groups``; each row is split into groups of size | ||||
|         ``groups`` that share a scale and bias. | ||||
|  | ||||
|         .. warning:: | ||||
|  | ||||
|           ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32 | ||||
|  | ||||
|         Formally, for a group of :math:`g` consecutive elements :math:`w_1` to | ||||
|         :math:`w_g` in a row of ``w`` we compute the quantized representation | ||||
|         of each element :math:`\hat{w_i}` as follows | ||||
|  | ||||
|         .. math:: | ||||
|  | ||||
|           \begin{aligned} | ||||
|             \alpha &= \max_i w_i \\ | ||||
|             \beta &= \min_i w_i \\ | ||||
|             s &= \frac{\alpha - \beta}{2^b - 1} \\ | ||||
|             \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). | ||||
|           \end{aligned} | ||||
|  | ||||
|         After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits | ||||
|         and is packed in an unsigned 32-bit integer from the lower to upper | ||||
|         bits. For instance, for 4-bit quantization we fit 8 elements in an | ||||
|         unsigned 32-bit integer where the 1st element occupies the 4 least | ||||
|         significant bits, the 2nd element occupies bits 4-7, and so on. | ||||
|  | ||||
|         In order to be able to dequantize the elements of ``w`` we also need to | ||||
|         save :math:`s` and :math:`\beta` which are the returned ``scales`` and | ||||
|         ``biases`` respectively. | ||||
|  | ||||
|         Args: | ||||
|           w (array): Matrix to be quantized | ||||
|           groups (int, optional): The size of the group in ``w`` that shares a | ||||
|             scale and bias. (default: 128) | ||||
|           width (int, optional): The bitwidth of the elements in ``w``. | ||||
|             (default: 4) | ||||
|  | ||||
|         Returns: | ||||
|           (tuple): A tuple containing | ||||
|  | ||||
|             - w_q (array): The quantized version of ``w`` | ||||
|             - scales (array): The scale to multiply each element with, namely :math:`s` | ||||
|             - biases (array): The biases to add to each element, namely :math:`\beta` | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "dequantize", | ||||
|       &dequantize, | ||||
|       "w"_a, | ||||
|       py::pos_only(), | ||||
|       "scales"_a, | ||||
|       "biases"_a, | ||||
|       "groups"_a = 128, | ||||
|       "width"_a = 4, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Dequantize the matrix ``w`` using the provided ``scales`` and | ||||
|         ``biases`` and the ``groups`` and ``width`` configuration. | ||||
|  | ||||
|         Formally, given the notation in :func:`quantize`, we compute | ||||
|         :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and | ||||
|         :math:`\beta` as follows | ||||
|  | ||||
|         .. math:: | ||||
|  | ||||
|           w_i = s \hat{w_i} + \beta | ||||
|  | ||||
|         Args: | ||||
|           w (array): Matrix to be dequantized | ||||
|           scales (array): The scales to use per ``groups`` elements of ``w`` | ||||
|           biases (array): The biases to use per ``groups`` elements of ``w`` | ||||
|           groups (int, optional): The size of the group in ``w`` that shares a | ||||
|             scale and bias. (default: 128) | ||||
|           width (int, optional): The bitwidth of the elements in ``w``. | ||||
|             (default: 4) | ||||
|  | ||||
|         Returns: | ||||
|           result (array): The dequantized version of ``w`` | ||||
|       )pbdoc"); | ||||
| } | ||||
|   | ||||
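For intuition, here is a rough pure-Python sketch of the per-group formula documented above (it omits the packing of the rounded values into uint32 words that the real op performs):

    import mlx.core as mx

    def quantize_group(wg, width):
        # One group of `groups` consecutive elements from a row of w
        alpha, beta = wg.max(), wg.min()
        s = (alpha - beta) / (2**width - 1)
        w_hat = mx.round((wg - beta) / s)      # integers in [0, 2**width - 1]
        return w_hat, s, beta

    w = mx.random.normal(shape=(64,))          # a single group of 64 elements
    w_hat, s, beta = quantize_group(w, width=4)
    w_rec = s * w_hat + beta                   # dequantize
    print((w - w_rec).abs().max())             # at most s / 2 (rounding error)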
| @@ -6,48 +6,14 @@ import mlx.core as mx | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| def select_bits(w, width, start): | ||||
|     shift_left = 32 - (start + width) | ||||
|     shift_right = shift_left + start | ||||
|     return (w * (2**shift_left)) // (2**shift_right) | ||||
|  | ||||
|  | ||||
| def dequantize(w, scales, biases, width): | ||||
|     w_full = mx.concatenate( | ||||
|         [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1 | ||||
|     ) | ||||
|     w_full = w_full.reshape(len(w), scales.shape[-1], -1) | ||||
|     w_full = scales[..., None] * w_full + biases[..., None] | ||||
|     w_full = w_full.reshape(len(w), -1) | ||||
|  | ||||
|     return w_full | ||||
|  | ||||
|  | ||||
| def quantize(w, width, groups): | ||||
|     w = w.reshape(len(w), -1, groups) | ||||
|     w_max = w.max(-1, keepdims=True) | ||||
|     w_min = w.min(-1, keepdims=True) | ||||
|     delta = (w_max - w_min) / (2**width - 1) | ||||
|  | ||||
|     w_int = mx.round((w - w_min) / delta).astype(mx.uint32) | ||||
|     scales = delta.squeeze(-1) | ||||
|     biases = w_min.squeeze(-1) | ||||
|  | ||||
|     shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32) | ||||
|     w_int = w_int.reshape(len(w), -1, 32 // width) | ||||
|     w_int = w_int * shifts[None, None] | ||||
|     packed_w = w_int.sum(-1) | ||||
|  | ||||
|     return packed_w, scales, biases | ||||
|  | ||||
|  | ||||
| class TestQuantized(mlx_tests.MLXTestCase): | ||||
|     def test_quantize_dequantize(self): | ||||
|         w = mx.random.normal(shape=(128, 128)) | ||||
|         w_q, scales, biases = quantize(w, 4, 64) | ||||
|         w_hat = dequantize(w_q, scales, biases, 4) | ||||
|         w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4) | ||||
|         self.assertLess((w_hat - w_hat2).abs().max(), 1e-6) | ||||
|         for b in [2, 4, 8]: | ||||
|             w_q, scales, biases = mx.quantize(w, 64, b) | ||||
|             w_hat = mx.dequantize(w_q, scales, biases, 64, b) | ||||
|             errors = (w - w_hat).abs().reshape(*scales.shape, -1) | ||||
|             self.assertTrue((errors <= scales[..., None] / 2).all()) | ||||
|  | ||||
|     def test_qmm(self): | ||||
|         key = mx.random.key(0) | ||||
| @@ -62,14 +28,16 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|                             ): | ||||
|                                 x = mx.random.normal(shape=(M, K), key=k1) | ||||
|                                 w = mx.random.normal(shape=(N, K), key=k2) | ||||
|                                 w_q, scales, biases = quantize(w, width, groups) | ||||
|                                 w_hat = dequantize(w_q, scales, biases, width) | ||||
|                                 w_q, scales, biases = mx.quantize(w, groups, width) | ||||
|                                 w_hat = mx.dequantize( | ||||
|                                     w_q, scales, biases, groups, width | ||||
|                                 ) | ||||
|                                 y_q = mx.quantized_matmul( | ||||
|                                     x, w_q.T, scales, biases, width=width, groups=groups | ||||
|                                 ) | ||||
|                                 y_hat = x @ w_hat.T | ||||
|                                 self.assertEqual(y_q.shape, y_hat.shape) | ||||
|                                 self.assertLess((y_q - y_hat).abs().max(), 0.1) | ||||
|                                 self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|  | ||||
|     def test_qmm_shapes(self): | ||||
|         key = mx.random.key(0) | ||||
| @@ -77,8 +45,8 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         groups = 64 | ||||
|         width = 4 | ||||
|         w = mx.random.normal(shape=(32, 128), key=k2) | ||||
|         w_q, scales, biases = quantize(w, width, groups) | ||||
|         w_hat = dequantize(w_q, scales, biases, width) | ||||
|         w_q, scales, biases = mx.quantize(w, groups, width) | ||||
|         w_hat = mx.dequantize(w_q, scales, biases, groups, width) | ||||
|         for s in [(3, 128), (2, 1, 7, 128)]: | ||||
|             x = mx.random.normal(shape=s, key=k1) | ||||
|             y_q = mx.quantized_matmul( | ||||
| @@ -86,7 +54,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|             y_hat = x @ w_hat.T | ||||
|             self.assertEqual(y_q.shape, y_hat.shape) | ||||
|             self.assertLess((y_q - y_hat).abs().max(), 0.1) | ||||
|             self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|  | ||||
|     def test_qmv(self): | ||||
|         key = mx.random.key(0) | ||||
| @@ -95,17 +63,17 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|             for width in [2, 4, 8]: | ||||
|                 for M in [512, 1024]: | ||||
|                     for N in [512, 1024]: | ||||
|                         # with self.subTest(shape=(M, N), groups=groups, width=width): | ||||
|                         x = mx.random.normal(shape=(1, N), key=k1) | ||||
|                         w = mx.random.normal(shape=(M, N), key=k2) | ||||
|                         w_q, scales, biases = quantize(w, width, groups) | ||||
|                         w_hat = dequantize(w_q, scales, biases, width) | ||||
|                         y_q = mx.quantized_matmul( | ||||
|                             x, w_q.T, scales, biases, width=width, groups=groups | ||||
|                         ) | ||||
|                         y_hat = x @ w_hat.T | ||||
|                         self.assertEqual(y_q.shape, y_hat.shape) | ||||
|                         self.assertLess((y_q - y_hat).abs().max(), 0.1) | ||||
|                         with self.subTest(shape=(M, N), groups=groups, width=width): | ||||
|                             x = mx.random.normal(shape=(1, N), key=k1) | ||||
|                             w = mx.random.normal(shape=(M, N), key=k2) | ||||
|                             w_q, scales, biases = mx.quantize(w, groups, width) | ||||
|                             w_hat = mx.dequantize(w_q, scales, biases, groups, width) | ||||
|                             y_q = mx.quantized_matmul( | ||||
|                                 x, w_q.T, scales, biases, width=width, groups=groups | ||||
|                             ) | ||||
|                             y_hat = x @ w_hat.T | ||||
|                             self.assertEqual(y_q.shape, y_hat.shape) | ||||
|                             self.assertLess((y_q - y_hat).abs().max(), 1e-3) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||