diff --git a/docs/src/_templates/nn-module-template.rst b/docs/src/_templates/nn-module-template.rst
new file mode 100644
index 000000000..9e8b5cc74
--- /dev/null
+++ b/docs/src/_templates/nn-module-template.rst
@@ -0,0 +1,20 @@
+{{ fullname | escape | underline}}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+
+  {% block methods %}
+
+  {% if methods %}
+  .. rubric:: {{ _('Methods') }}
+
+  .. autosummary::
+  {% for item in methods %}
+    {%- if item not in inherited_members and item != "__init__" %}
+    ~{{ name }}.{{ item }}
+    {%- endif %}
+  {%- endfor %}
+  {% endif %}
+  {% endblock %}
+
diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 2a253ab25..229d295cb 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -173,6 +173,7 @@ In detail:
    :toctree: _autosummary
 
    value_and_grad
+   quantize
 
 .. toctree::
 
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index c0b59b6d4..6fb624d54 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -31,6 +31,7 @@ Layers
    Mish
    MultiHeadAttention
    PReLU
+   QuantizedEmbedding
    QuantizedLinear
    RMSNorm
    ReLU
@@ -43,4 +44,4 @@ Layers
    Softshrink
    Step
    Transformer
-   Upsample
\ No newline at end of file
+   Upsample
diff --git a/docs/src/python/tree_utils.rst b/docs/src/python/tree_utils.rst
index 84d5afa9b..dbd0ebce9 100644
--- a/docs/src/python/tree_utils.rst
+++ b/docs/src/python/tree_utils.rst
@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
    tree_flatten
    tree_unflatten
    tree_map
+   tree_map_with_path
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 3b0856b30..fce721a06 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
 )
 from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
-from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
 from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py
index 18482eddc..a8327a280 100644
--- a/python/mlx/nn/layers/embedding.py
+++ b/python/mlx/nn/layers/embedding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 
@@ -14,7 +14,7 @@ class Embedding(Module):
 
     Args:
         num_embeddings (int): How many possible discrete tokens can we embed.
-           Usually called the vocabulary size.
+            Usually called the vocabulary size.
         dims (int): The dimensionality of the embeddings.
     """
 
@@ -28,3 +28,12 @@ class Embedding(Module):
 
     def __call__(self, x):
         return self.weight[x]
+
+    def as_linear(self, x):
+        """
+        Call the embedding layer as a linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return x @ self.weight.T
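The new ``Embedding.as_linear`` makes weight tying straightforward. A minimal sketch of the intended use (the ``TiedLM`` class and its sizes below are hypothetical, not part of this diff):

    # Hypothetical tied-weight model: the same matrix embeds tokens and
    # projects hidden states back to vocabulary logits via as_linear.
    import mlx.core as mx
    import mlx.nn as nn

    class TiedLM(nn.Module):
        def __init__(self, vocab_size: int, dims: int):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, dims)

        def __call__(self, tokens):
            h = self.embedding(tokens)          # (..., dims) embeddings
            return self.embedding.as_linear(h)  # (..., vocab_size) logits

    model = TiedLM(vocab_size=128, dims=64)
    print(model(mx.array([1, 2, 3])).shape)  # (3, 128)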
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 15eccf0b1..08910467d 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -1,11 +1,143 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
+from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_map_with_path
+
+
+def quantize(
+    model: Module,
+    group_size: int = 64,
+    bits: int = 4,
+    class_predicate: Optional[Callable] = None,
+):
+    """Quantize the sub-modules of a module according to a predicate.
+
+    By default all :obj:`Linear` and :obj:`Embedding` layers will be
+    quantized. Note that the module is updated in-place.
+
+    Args:
+        model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        group_size (int): The quantization group size (see
+            :func:`mlx.core.quantize`). Default: ``64``.
+        bits (int): The number of bits per parameter (see
+            :func:`mlx.core.quantize`). Default: ``4``.
+        class_predicate (Optional[Callable]): A callable which receives the
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
+            it should be quantized and ``False`` otherwise. If ``None``, then
+            all linear and embedding layers are quantized. Default: ``None``.
+    """
+    class_predicate = class_predicate or (
+        lambda _, m: isinstance(m, (Linear, Embedding))
+    )
+
+    def _maybe_quantize(path, m):
+        if class_predicate(path, m):
+            if isinstance(m, Linear):
+                return QuantizedLinear.from_linear(m, group_size, bits)
+            elif isinstance(m, Embedding):
+                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            else:
+                raise ValueError(f"Unable to quantize model of type {type(m)}")
+        else:
+            return m
+
+    leaves = model.leaf_modules()
+    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
+    model.update_modules(leaves)
+
+
+class QuantizedEmbedding(Module):
+    """The same as :obj:`Embedding` but with a quantized weight matrix.
+
+    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
+    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
+    layers.
+
+    Args:
+        num_embeddings (int): How many possible discrete tokens can we embed.
+            Usually called the vocabulary size.
+        dims (int): The dimensionality of the embeddings.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1 / dims)
+        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.num_embeddings = num_embeddings
+        self.dims = dims
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def __call__(self, x):
+        s = x.shape
+        x = x.flatten()
+        out = mx.dequantize(
+            self["weight"][x],
+            scales=self["scales"][x],
+            biases=self["biases"][x],
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        return out.reshape(*s, -1)
+
+    def as_linear(self, x):
+        """
+        Call the quantized embedding layer as a quantized linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_embeddings}, {self.dims}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+    ):
+        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        num_embeddings, dims = embedding_layer.weight.shape
+        ql = cls(num_embeddings, dims, group_size, bits)
+        ql.weight, ql.scales, ql.biases = mx.quantize(
+            embedding_layer.weight, group_size, bits
+        )
+        return ql
 
 
 class QuantizedLinear(Module):
+ """ + return mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + + def _extra_repr(self): + return ( + f"{self.num_embeddings}, {self.dims}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + @classmethod + def from_embedding( + cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 + ): + """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" + embedding_dims, dims = embedding_layer.weight.shape + ql = cls(embedding_dims, dims, group_size, bits) + ql.weight, ql.scales, ql.biases = mx.quantize( + embedding_layer.weight, group_size, bits + ) + return ql class QuantizedLinear(Module): @@ -15,23 +147,18 @@ class QuantizedLinear(Module): parameters are frozen and will not be included in any gradient computation but this will probably change in the future. - QuantizedLinear also provides two useful classmethods to convert linear - layers to QuantizedLinear layers. - - - :meth:`from_linear` returns a QuantizedLinear layer that applies the same - linear transformation up to the quantization error. - - :meth:`quantize_module` swaps all the linear layers of the passed module - with QuantizedLinear ones. + :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to + convert linear layers to :obj:`QuantizedLinear` layers. Args: - input_dims (int): The dimensionality of the input features - output_dims (int): The dimensionality of the output features + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. bias (bool, optional): If set to ``False`` then the layer will not use - a bias. (default: True). + a bias. Default: ``True``. group_size (int, optional): The group size to use for the quantized - weight. See :func:`~mlx.core.quantize`. (default: 64) + weight. See :func:`~mlx.core.quantize`. Default: ``64``. bits (int, optional): The bit width to use for the quantized weight. - See :func:`~mlx.core.quantize`. (default: 4) + See :func:`~mlx.core.quantize`. Default: ``4``. 
""" def __init__( @@ -94,8 +221,7 @@ class QuantizedLinear(Module): @classmethod def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): - """Create a QuantizedLinear layer from the parameters of a provided - linear layer.""" + """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits) ql.weight, ql.scales, ql.biases = mx.quantize( @@ -105,21 +231,3 @@ class QuantizedLinear(Module): ql.bias = linear_layer.bias return ql - - @classmethod - def quantize_module( - cls, - model: Module, - group_size: int = 64, - bits: int = 4, - linear_class_predicate=lambda m: isinstance(m, Linear), - ): - def _quantize_if_linear(m): - if linear_class_predicate(m): - return cls.from_linear(m, group_size, bits) - else: - return m - - leaves = model.leaf_modules() - leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module) - model.update_modules(leaves) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 802b03831..31e94a2a1 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -3,7 +3,7 @@ from collections import defaultdict def tree_map(fn, tree, *rest, is_leaf=None): - """Applies ``fn`` to the leaves of the python tree ``tree`` and + """Applies ``fn`` to the leaves of the Python tree ``tree`` and returns a new collection with the results. If ``rest`` is provided, every item is assumed to be a superset of ``tree`` @@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None): model.update(tree_map(lambda x: x*x, model.parameters())) Args: - fn (Callable): The function that processes the leaves of the tree - tree (Any): The main python tree that will be iterated upon - rest (Tuple[Any]): Extra trees to be iterated together with tree - is_leaf (Optional[Callable]): An optional callable that returns True if - the passed object is considered a leaf or False otherwise. + fn (callable): The function that processes the leaves of the tree. + tree (Any): The main Python tree that will be iterated upon. + rest (tuple[Any]): Extra trees to be iterated together with ``tree``. + is_leaf (callable, optional): An optional callable that returns ``True`` + if the passed object is considered a leaf or ``False`` otherwise. Returns: - A python tree with the new values returned by ``fn``. + A Python tree with the new values returned by ``fn``. """ if is_leaf is not None and is_leaf(tree): return fn(tree, *rest) @@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None): return fn(tree, *rest) +def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None): + """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and + returns a new collection with the results. + + This function is the same :func:`tree_map` but the ``fn`` takes the path as + the first argument followed by the remaining tree nodes. + + Args: + fn (callable): The function that processes the leaves of the tree. + tree (Any): The main Python tree that will be iterated upon. + rest (tuple[Any]): Extra trees to be iterated together with ``tree``. + is_leaf (callable, optional): An optional callable that returns ``True`` + if the passed object is considered a leaf or ``False`` otherwise. + + Returns: + A Python tree with the new values returned by ``fn``. 
diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 802b03831..31e94a2a1 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -3,7 +3,7 @@ from collections import defaultdict
 
 
 def tree_map(fn, tree, *rest, is_leaf=None):
-    """Applies ``fn`` to the leaves of the python tree ``tree`` and
+    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
 
     If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
         model.update(tree_map(lambda x: x*x, model.parameters()))
 
     Args:
-        fn (Callable): The function that processes the leaves of the tree
-        tree (Any): The main python tree that will be iterated upon
-        rest (Tuple[Any]): Extra trees to be iterated together with tree
-        is_leaf (Optional[Callable]): An optional callable that returns True if
-            the passed object is considered a leaf or False otherwise.
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.
 
     Returns:
-        A python tree with the new values returned by ``fn``.
+        A Python tree with the new values returned by ``fn``.
     """
     if is_leaf is not None and is_leaf(tree):
         return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
     return fn(tree, *rest)
 
 
+def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
+    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
+    returns a new collection with the results.
+
+    This function is the same as :func:`tree_map` but the ``fn`` takes the path
+    as the first argument followed by the remaining tree nodes.
+
+    Args:
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.
+
+    Returns:
+        A Python tree with the new values returned by ``fn``.
+
+    Example:
+        >>> from mlx.utils import tree_map_with_path
+        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
+        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
+        model.0.w
+        model.0.b
+        model.1.w
+        model.1.b
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return fn(path, tree, *rest)
+    elif isinstance(tree, (list, tuple)):
+        prefix = f"{path}." if path else ""
+        TreeType = type(tree)
+        return TreeType(
+            tree_map_with_path(
+                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
+            )
+            for i, child in enumerate(tree)
+        )
+    elif isinstance(tree, dict):
+        prefix = f"{path}." if path else ""
+        return {
+            k: tree_map_with_path(
+                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
+            )
+            for k, child in tree.items()
+        }
+    else:
+        return fn(path, tree, *rest)
+
+
 def tree_flatten(tree, prefix="", is_leaf=None):
-    """Flattens a python tree to a list of key, value tuples.
+    """Flattens a Python tree to a list of key, value tuples.
 
     The keys are using the dot notation to define trees of arbitrary depth and
     complexity.
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
         # [("hello.0.0.0", 0)]
 
     .. note::
-        Dictionaries should have keys that are valid python identifiers.
+        Dictionaries should have keys that are valid Python identifiers.
 
     Args:
-        tree (Any): The python tree to be flattened.
+        tree (Any): The Python tree to be flattened.
         prefix (str): A prefix to use for the keys. The first character is
             always discarded.
-        is_leaf (Callable): An optional callable that returns True if the
+        is_leaf (callable): An optional callable that returns True if the
             passed object is considered a leaf or False otherwise.
 
     Returns:
-        List[Tuple[str, Any]]: The flat representation of the python tree.
+        List[Tuple[str, Any]]: The flat representation of the Python tree.
     """
     flat_tree = []
 
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
 
 
 def tree_unflatten(tree):
-    """Recreate a python tree from its flat representation.
+    """Recreate a Python tree from its flat representation.
 
     .. code-block:: python
 
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
         # {"hello": {"world": 42}}
 
     Args:
-        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
-            For instance as returned by :meth:`tree_flatten`.
+        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+            For instance as returned by :meth:`tree_flatten`.
 
     Returns:
-        A python tree.
+        A Python tree.
     """
     if len(tree) == 1 and tree[0][0] == "":
         return tree[0][1]
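Note that ``tree_map_with_path`` threads extra trees through just like ``tree_map``: the callback receives the dot-separated path first, then one leaf per tree. A small sketch beyond the docstring example (the parameter and gradient trees here are made up for illustration):

    from mlx.utils import tree_map_with_path

    params = {"layers": [{"w": 1.0}, {"w": 2.0}]}
    grads = {"layers": [{"w": 0.1}, {"w": 0.2}]}

    # Update each parameter with its gradient and log where it lives.
    def step(path, p, g):
        print(path, p, g)  # e.g. "layers.0.w 1.0 0.1"
        return p - 0.5 * g

    new_params = tree_map_with_path(step, params, grads)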
""" if len(tree) == 1 and tree[0][0] == "": return tree[0][1] diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index f5e8f6d8d..4ca5e465b 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase): self.assertFalse(m.update(params_dict).eval()._training) self.assertTrue(m.train()._training) + def test_quantize(self): + m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256)) + nn.quantize(m) + self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding)) + self.assertTrue(isinstance(m.layers[1], nn.ReLU)) + self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + + m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256)) + nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear)) + self.assertTrue(isinstance(m.layers[0], nn.Embedding)) + self.assertTrue(isinstance(m.layers[1], nn.ReLU)) + self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): @@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertEqual(h_out.shape, (44, 12)) self.assertEqual(c_out.shape, (44, 12)) + def test_quantized_embedding(self): + emb = nn.Embedding(32, 256) + qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8) + x = mx.array([2, 6, 9, 3, 0, 3]) + y = emb(x) + yq = qemb(x) + self.assertLess((y - yq).abs().max(), 1e-3) + + x = mx.random.uniform(shape=(2, 256)) + y = emb.as_linear(x) + yq = qemb.as_linear(x) + self.assertLess((y - yq).abs().max(), 1e-2) + if __name__ == "__main__": unittest.main()