Quantize embedding (#994)
Author: Awni Hannun

* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test
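For orientation, here is a minimal sketch of how the new API fits together. The toy model below is hypothetical; `nn.quantize`, `QuantizedEmbedding`, and `QuantizedLinear` come from the diff that follows.

    import mlx.core as mx
    import mlx.nn as nn

    # Hypothetical toy model; any module whose leaves include Linear or
    # Embedding layers works the same way.
    model = nn.Sequential(nn.Embedding(100, 256), nn.ReLU(), nn.Linear(256, 100))

    # Swap every Linear and Embedding for its quantized counterpart, in place
    # (defaults: 4 bits, group size 64).
    nn.quantize(model)

    logits = model(mx.array([1, 2, 3]))  # forward pass now uses quantized weights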
python/mlx/nn/layers/__init__.py

@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
 )
 from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
-from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
 from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
python/mlx/nn/layers/embedding.py

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 
@@ -14,7 +14,7 @@ class Embedding(Module):
 
     Args:
         num_embeddings (int): How many possible discrete tokens can we embed.
-                              Usually called the vocabulary size.
+           Usually called the vocabulary size.
         dims (int): The dimensionality of the embeddings.
     """
 
@@ -28,3 +28,12 @@ class Embedding(Module):
 
     def __call__(self, x):
         return self.weight[x]
+
+    def as_linear(self, x):
+        """
+        Call the embedding layer as a linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return x @ self.weight.T
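The new `as_linear` is aimed at tied input/output embeddings. A minimal sketch, assuming a hypothetical `TinyLM` module; only `Embedding.as_linear` itself comes from the diff above.

    import mlx.core as mx
    import mlx.nn as nn

    class TinyLM(nn.Module):
        # Toy language model that ties the input embedding and output projection.
        def __init__(self, vocab_size: int = 100, dims: int = 256):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, dims)

        def __call__(self, tokens):
            h = self.embedding(tokens)  # (..., dims) token embeddings
            # ... transformer blocks would process h here ...
            return self.embedding.as_linear(h)  # (..., vocab_size) logits

    logits = TinyLM()(mx.array([[1, 2, 3]]))  # shape (1, 3, 100)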
python/mlx/nn/layers/quantized.py

@@ -1,11 +1,143 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
+from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_map_with_path
 
 
+def quantize(
+    model: Module,
+    group_size: int = 64,
+    bits: int = 4,
+    class_predicate: Optional[Callable] = None,
+):
+    """Quantize the sub-modules of a module according to a predicate.
+
+    By default all :obj:`Linear` and :obj:`Embedding` layers will be
+    quantized. Note also, the module is updated in-place.
+
+    Args:
+        model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        group_size (int): The quantization group size (see
+           :func:`mlx.core.quantize`). Default: ``64``.
+        bits (int): The number of bits per parameter (see
+           :func:`mlx.core.quantize`). Default: ``4``.
+        class_predicate (Optional[Callable]): A callable which receives the
+          :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
+          it should be quantized and ``False`` otherwise. If ``None``, then
+          all linear and embedding layers are quantized. Default: ``None``.
+    """
+    class_predicate = class_predicate or (
+        lambda _, m: isinstance(m, (Linear, Embedding))
+    )
+
+    def _maybe_quantize(path, m):
+        if class_predicate(path, m):
+            if isinstance(m, Linear):
+                return QuantizedLinear.from_linear(m, group_size, bits)
+            elif isinstance(m, Embedding):
+                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            else:
+                raise ValueError(f"Unable to quantize model of type {type(m)}")
+        else:
+            return m
+
+    leaves = model.leaf_modules()
+    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
+    model.update_modules(leaves)
+
+
+class QuantizedEmbedding(Module):
+    """The same as :obj:`Embedding` but with a quantized weight matrix.
+
+    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
+    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
+    layers.
+
+    Args:
+        num_embeddings (int): How many possible discrete tokens can we embed.
+           Usually called the vocabulary size.
+        dims (int): The dimensionality of the embeddings.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1 / dims)
+        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.num_embeddings = num_embeddings
+        self.dims = dims
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def __call__(self, x):
+        s = x.shape
+        x = x.flatten()
+        out = mx.dequantize(
+            self["weight"][x],
+            scales=self["scales"][x],
+            biases=self["biases"][x],
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        return out.reshape(*s, -1)
+
+    def as_linear(self, x):
+        """
+        Call the quantized embedding layer as a quantized linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_embeddings}, {self.dims}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+    ):
+        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        embedding_dims, dims = embedding_layer.weight.shape
+        ql = cls(embedding_dims, dims, group_size, bits)
+        ql.weight, ql.scales, ql.biases = mx.quantize(
+            embedding_layer.weight, group_size, bits
+        )
+        return ql
+
+
 class QuantizedLinear(Module):
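A single layer can also be converted directly with `from_embedding`. A small sketch mirroring the test added at the bottom of this commit; the numbers are illustrative only.

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(32, 256)
    qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)

    x = mx.array([2, 6, 9, 3, 0, 3])
    err = (emb(x) - qemb(x)).abs().max()  # small 8-bit quantization error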
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
     parameters are frozen and will not be included in any gradient computation
     but this will probably change in the future.
 
-    QuantizedLinear also provides two useful classmethods to convert linear
-    layers to QuantizedLinear layers.
-
-    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
-      linear transformation up to the quantization error.
-    - :meth:`quantize_module` swaps all the linear layers of the passed module
-      with QuantizedLinear ones.
+    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
+    convert linear layers to :obj:`QuantizedLinear` layers.
 
     Args:
-        input_dims (int): The dimensionality of the input features
-        output_dims (int): The dimensionality of the output features
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
         bias (bool, optional): If set to ``False`` then the layer will not use
-            a bias. (default: True).
+            a bias. Default: ``True``.
         group_size (int, optional): The group size to use for the quantized
-            weight. See :func:`~mlx.core.quantize`. (default: 64)
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
-            See :func:`~mlx.core.quantize`. (default: 4)
+            See :func:`~mlx.core.quantize`. Default: ``4``.
     """
 
     def __init__(
@@ -94,8 +221,7 @@ class QuantizedLinear(Module):
 
     @classmethod
     def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
-        """Create a QuantizedLinear layer from the parameters of a provided
-        linear layer."""
+        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -105,21 +231,3 @@ class QuantizedLinear(Module):
             ql.bias = linear_layer.bias
 
         return ql
-
-    @classmethod
-    def quantize_module(
-        cls,
-        model: Module,
-        group_size: int = 64,
-        bits: int = 4,
-        linear_class_predicate=lambda m: isinstance(m, Linear),
-    ):
-        def _quantize_if_linear(m):
-            if linear_class_predicate(m):
-                return cls.from_linear(m, group_size, bits)
-            else:
-                return m
-
-        leaves = model.leaf_modules()
-        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
-        model.update_modules(leaves)
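The removed `quantize_module` classmethod is superseded by the top-level `nn.quantize` above; the old `linear_class_predicate` becomes a `class_predicate` that also receives the module path. A sketch mirroring the updated tests:

    import mlx.nn as nn

    model = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))

    # Quantize only the Linear layers; the Embedding is left untouched.
    nn.quantize(model, class_predicate=lambda path, m: isinstance(m, nn.Linear))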
python/mlx/utils.py

@@ -3,7 +3,7 @@ from collections import defaultdict
 
 
 def tree_map(fn, tree, *rest, is_leaf=None):
-    """Applies ``fn`` to the leaves of the python tree ``tree`` and
+    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
 
     If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
         model.update(tree_map(lambda x: x*x, model.parameters()))
 
     Args:
-        fn (Callable): The function that processes the leaves of the tree
-        tree (Any): The main python tree that will be iterated upon
-        rest (Tuple[Any]): Extra trees to be iterated together with tree
-        is_leaf (Optional[Callable]): An optional callable that returns True if
-            the passed object is considered a leaf or False otherwise.
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+           if the passed object is considered a leaf or ``False`` otherwise.
 
     Returns:
-        A python tree with the new values returned by ``fn``.
+        A Python tree with the new values returned by ``fn``.
     """
     if is_leaf is not None and is_leaf(tree):
         return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
         return fn(tree, *rest)
 
 
+def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
+    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
+    returns a new collection with the results.
+
+    This function is the same as :func:`tree_map` but the ``fn`` takes the path as
+    the first argument followed by the remaining tree nodes.
+
+    Args:
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+           if the passed object is considered a leaf or ``False`` otherwise.
+
+    Returns:
+        A Python tree with the new values returned by ``fn``.
+
+    Example:
+        >>> from mlx.utils import tree_map_with_path
+        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
+        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
+        model.0.w
+        model.0.b
+        model.1.w
+        model.1.b
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return fn(path, tree, *rest)
+    elif isinstance(tree, (list, tuple)):
+        prefix = f"{path}." if path else ""
+        TreeType = type(tree)
+        return TreeType(
+            tree_map_with_path(
+                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
+            )
+            for i, child in enumerate(tree)
+        )
+    elif isinstance(tree, dict):
+        prefix = f"{path}." if path else ""
+        return {
+            k: tree_map_with_path(
+                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
+            )
+            for k, child in tree.items()
+        }
+    else:
+        return fn(path, tree, *rest)
+
+
 def tree_flatten(tree, prefix="", is_leaf=None):
-    """Flattens a python tree to a list of key, value tuples.
+    """Flattens a Python tree to a list of key, value tuples.
 
     The keys are using the dot notation to define trees of arbitrary depth and
     complexity.
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
        # [("hello.0.0.0", 0)]
 
     .. note::
-       Dictionaries should have keys that are valid python identifiers.
+       Dictionaries should have keys that are valid Python identifiers.
 
     Args:
-        tree (Any): The python tree to be flattened.
+        tree (Any): The Python tree to be flattened.
         prefix (str): A prefix to use for the keys. The first character is
             always discarded.
-        is_leaf (Callable): An optional callable that returns True if the
+        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.
 
     Returns:
-        List[Tuple[str, Any]]: The flat representation of the python tree.
+        List[Tuple[str, Any]]: The flat representation of the Python tree.
     """
     flat_tree = []
 
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
 
 
 def tree_unflatten(tree):
-    """Recreate a python tree from its flat representation.
+    """Recreate a Python tree from its flat representation.
 
     .. code-block:: python
 
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
        # {"hello": {"world": 42}}
 
     Args:
-        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
-                                      For instance as returned by :meth:`tree_flatten`.
+        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+           For instance as returned by :meth:`tree_flatten`.
 
     Returns:
-        A python tree.
+        A Python tree.
     """
     if len(tree) == 1 and tree[0][0] == "":
         return tree[0][1]
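Since ``fn`` receives the dotted path, callers can dispatch on a leaf's location in the tree, which is how the new `nn.quantize` targets specific sub-modules. A small sketch with a toy tree (hypothetical values):

    from mlx.utils import tree_map_with_path

    tree = {"layers": [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]}

    # Zero every leaf whose path ends in ".b"; leave the rest unchanged.
    zeroed = tree_map_with_path(
        lambda path, leaf: 0.0 if path.endswith(".b") else leaf, tree
    )
    # {"layers": [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 0.0}]}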
python/tests/test_nn.py

@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertFalse(m.update(params_dict).eval()._training)
         self.assertTrue(m.train()._training)
 
+    def test_quantize(self):
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m)
+        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear))
+        self.assertTrue(isinstance(m.layers[0], nn.Embedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
@@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(h_out.shape, (44, 12))
         self.assertEqual(c_out.shape, (44, 12))
 
+    def test_quantized_embedding(self):
+        emb = nn.Embedding(32, 256)
+        qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
+        x = mx.array([2, 6, 9, 3, 0, 3])
+        y = emb(x)
+        yq = qemb(x)
+        self.assertLess((y - yq).abs().max(), 1e-3)
+
+        x = mx.random.uniform(shape=(2, 256))
+        y = emb.as_linear(x)
+        yq = qemb.as_linear(x)
+        self.assertLess((y - yq).abs().max(), 1e-2)
+
 
 if __name__ == "__main__":
     unittest.main()