Quantize embedding (#994)

* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
Awni Hannun 2024-04-15 16:42:10 -07:00 committed by GitHub
parent 2e7c02d5cd
commit cd9e184529
9 changed files with 269 additions and 54 deletions

View File

@@ -0,0 +1,20 @@
+{{ fullname | escape | underline}}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+
+{% block methods %}
+{% if methods %}
+.. rubric:: {{ _('Methods') }}
+
+.. autosummary::
+{% for item in methods %}
+{%- if item not in inherited_members and item != "__init__" %}
+   ~{{ name }}.{{ item }}
+{%- endif %}
+{%- endfor %}
+{% endif %}
+{% endblock %}

View File

@@ -173,6 +173,7 @@ In detail:
   :toctree: _autosummary

   value_and_grad
+  quantize

.. toctree::

View File

@@ -31,6 +31,7 @@ Layers
   Mish
   MultiHeadAttention
   PReLU
+  QuantizedEmbedding
   QuantizedLinear
   RMSNorm
   ReLU
@@ -43,4 +44,4 @@ Layers
   Softshrink
   Step
   Transformer
   Upsample

View File

@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
   tree_flatten
   tree_unflatten
   tree_map
+  tree_map_with_path

View File

@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
-from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
from mlx.nn.layers.transformer import (
    MultiHeadAttention,

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

import math
@@ -14,7 +14,7 @@ class Embedding(Module):
    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """
@@ -28,3 +28,12 @@ class Embedding(Module):
    def __call__(self, x):
        return self.weight[x]

+    def as_linear(self, x):
+        """
+        Call the embedding layer as a linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return x @ self.weight.T
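For illustration, a minimal sketch of the weight-tying pattern `as_linear` enables; the `TinyLM` module and its sizes are hypothetical, not part of this change:

import mlx.core as mx
import mlx.nn as nn

class TinyLM(nn.Module):
    # Toy decoder whose output projection reuses the embedding weights.
    def __init__(self, vocab_size: int, dims: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)

    def __call__(self, tokens):
        h = self.embedding(tokens)  # (..., dims) token embeddings
        # Tied output projection: logits over the vocabulary,
        # with no separate output weight matrix.
        return self.embedding.as_linear(h)

model = TinyLM(vocab_size=128, dims=64)
logits = model(mx.array([1, 2, 3]))  # shape: (3, 128)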

View File

@@ -1,11 +1,143 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

import math
+from typing import Callable, Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
+from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_map_with_path
+
+
+def quantize(
+    model: Module,
+    group_size: int = 64,
+    bits: int = 4,
+    class_predicate: Optional[Callable] = None,
+):
+    """Quantize the sub-modules of a module according to a predicate.
+
+    By default all :obj:`Linear` and :obj:`Embedding` layers will be
+    quantized. Note also, the module is updated in-place.
+
+    Args:
+        model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        group_size (int): The quantization group size (see
+            :func:`mlx.core.quantize`). Default: ``64``.
+        bits (int): The number of bits per parameter (see
+            :func:`mlx.core.quantize`). Default: ``4``.
+        class_predicate (Optional[Callable]): A callable which receives the
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
+            it should be quantized and ``False`` otherwise. If ``None``, then
+            all linear and embedding layers are quantized. Default: ``None``.
+    """
+    class_predicate = class_predicate or (
+        lambda _, m: isinstance(m, (Linear, Embedding))
+    )
+
+    def _maybe_quantize(path, m):
+        if class_predicate(path, m):
+            if isinstance(m, Linear):
+                return QuantizedLinear.from_linear(m, group_size, bits)
+            elif isinstance(m, Embedding):
+                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            else:
+                raise ValueError(f"Unable to quantize model of type {type(m)}")
+        else:
+            return m
+
+    leaves = model.leaf_modules()
+    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
+    model.update_modules(leaves)
+
+
+class QuantizedEmbedding(Module):
+    """The same as :obj:`Embedding` but with a quantized weight matrix.
+
+    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
+    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
+    layers.
+
+    Args:
+        num_embeddings (int): How many possible discrete tokens can we embed.
+            Usually called the vocabulary size.
+        dims (int): The dimensionality of the embeddings.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1 / dims)
+        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.num_embeddings = num_embeddings
+        self.dims = dims
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def __call__(self, x):
+        s = x.shape
+        x = x.flatten()
+        out = mx.dequantize(
+            self["weight"][x],
+            scales=self["scales"][x],
+            biases=self["biases"][x],
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        return out.reshape(*s, -1)
+
+    def as_linear(self, x):
+        """
+        Call the quantized embedding layer as a quantized linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_embeddings}, {self.dims}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+    ):
+        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        embedding_dims, dims = embedding_layer.weight.shape
+        ql = cls(embedding_dims, dims, group_size, bits)
+        ql.weight, ql.scales, ql.biases = mx.quantize(
+            embedding_layer.weight, group_size, bits
+        )
+        return ql
+
+
class QuantizedLinear(Module):
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
    parameters are frozen and will not be included in any gradient computation
    but this will probably change in the future.

-    QuantizedLinear also provides two useful classmethods to convert linear
-    layers to QuantizedLinear layers.
+    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
+    convert linear layers to :obj:`QuantizedLinear` layers.

-    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
-      linear transformation up to the quantization error.
-    - :meth:`quantize_module` swaps all the linear layers of the passed module
-      with QuantizedLinear ones.

    Args:
-        input_dims (int): The dimensionality of the input features
-        output_dims (int): The dimensionality of the output features
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
-            a bias. (default: True).
+            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
-            weight. See :func:`~mlx.core.quantize`. (default: 64)
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
-            See :func:`~mlx.core.quantize`. (default: 4)
+            See :func:`~mlx.core.quantize`. Default: ``4``.
    """

    def __init__(
@@ -94,8 +221,7 @@ class QuantizedLinear(Module):
    @classmethod
    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
-        """Create a QuantizedLinear layer from the parameters of a provided
-        linear layer."""
+        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -105,21 +231,3 @@ class QuantizedLinear(Module):
        ql.bias = linear_layer.bias

        return ql
-
-    @classmethod
-    def quantize_module(
-        cls,
-        model: Module,
-        group_size: int = 64,
-        bits: int = 4,
-        linear_class_predicate=lambda m: isinstance(m, Linear),
-    ):
-        def _quantize_if_linear(m):
-            if linear_class_predicate(m):
-                return cls.from_linear(m, group_size, bits)
-            else:
-                return m
-
-        leaves = model.leaf_modules()
-        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
-        model.update_modules(leaves)
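As a usage sketch of the new `nn.quantize` entry point (the model below is illustrative; note that `mx.quantize` requires the quantized dimension to be divisible by `group_size`):

import mlx.nn as nn

# Quantize every Linear and Embedding layer in place (the default predicate).
model = nn.Sequential(nn.Embedding(100, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(model, group_size=64, bits=4)

# Or quantize selectively with a predicate over the module path and module,
# here restricting quantization to Linear layers only.
model = nn.Sequential(nn.Embedding(100, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(model, class_predicate=lambda path, m: isinstance(m, nn.Linear))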

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
def tree_map(fn, tree, *rest, is_leaf=None):
-    """Applies ``fn`` to the leaves of the python tree ``tree`` and
+    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
    returns a new collection with the results.

    If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
-        fn (Callable): The function that processes the leaves of the tree
-        tree (Any): The main python tree that will be iterated upon
-        rest (Tuple[Any]): Extra trees to be iterated together with tree
-        is_leaf (Optional[Callable]): An optional callable that returns True if
-            the passed object is considered a leaf or False otherwise.
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.

    Returns:
-        A python tree with the new values returned by ``fn``.
+        A Python tree with the new values returned by ``fn``.
    """
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
        return fn(tree, *rest)
+
+
+def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
+    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
+    returns a new collection with the results.
+
+    This function is the same as :func:`tree_map` but ``fn`` takes the path as
+    the first argument followed by the remaining tree nodes.
+
+    Args:
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.
+
+    Returns:
+        A Python tree with the new values returned by ``fn``.
+
+    Example:
+        >>> from mlx.utils import tree_map_with_path
+        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
+        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
+        model.0.w
+        model.0.b
+        model.1.w
+        model.1.b
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return fn(path, tree, *rest)
+    elif isinstance(tree, (list, tuple)):
+        prefix = f"{path}." if path else ""
+        TreeType = type(tree)
+        return TreeType(
+            tree_map_with_path(
+                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
+            )
+            for i, child in enumerate(tree)
+        )
+    elif isinstance(tree, dict):
+        prefix = f"{path}." if path else ""
+        return {
+            k: tree_map_with_path(
+                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
+            )
+            for k, child in tree.items()
+        }
+    else:
+        return fn(path, tree, *rest)
+

def tree_flatten(tree, prefix="", is_leaf=None):
-    """Flattens a python tree to a list of key, value tuples.
+    """Flattens a Python tree to a list of key, value tuples.

    The keys are using the dot notation to define trees of arbitrary depth and
    complexity.
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
        # [("hello.0.0.0", 0)]

    .. note::
-       Dictionaries should have keys that are valid python identifiers.
+       Dictionaries should have keys that are valid Python identifiers.

    Args:
-        tree (Any): The python tree to be flattened.
+        tree (Any): The Python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
-        is_leaf (Callable): An optional callable that returns True if the
+        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.

    Returns:
-        List[Tuple[str, Any]]: The flat representation of the python tree.
+        List[Tuple[str, Any]]: The flat representation of the Python tree.
    """

    flat_tree = []
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
def tree_unflatten(tree):
-    """Recreate a python tree from its flat representation.
+    """Recreate a Python tree from its flat representation.

    .. code-block:: python
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
        # {"hello": {"world": 42}}

    Args:
-        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
+        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
-        A python tree.
+        A Python tree.
    """
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]
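A small illustration of how the path argument supports per-leaf decisions, which is the mechanism `quantize` relies on (the tree and values below are arbitrary):

from mlx.utils import tree_map_with_path

tree = {"model": [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]}

# Zero every leaf whose dotted path ends in ".b"; leave other leaves unchanged.
new_tree = tree_map_with_path(
    lambda path, x: 0.0 if path.endswith(".b") else x, tree
)
# {"model": [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 0.0}]}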

View File

@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
        self.assertFalse(m.update(params_dict).eval()._training)
        self.assertTrue(m.train()._training)

+    def test_quantize(self):
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m)
+        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear))
+        self.assertTrue(isinstance(m.layers[0], nn.Embedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))

class TestLayers(mlx_tests.MLXTestCase):
    def test_identity(self):
@@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(h_out.shape, (44, 12))
        self.assertEqual(c_out.shape, (44, 12))
+
+    def test_quantized_embedding(self):
+        emb = nn.Embedding(32, 256)
+        qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
+
+        x = mx.array([2, 6, 9, 3, 0, 3])
+        y = emb(x)
+        yq = qemb(x)
+        self.assertLess((y - yq).abs().max(), 1e-3)
+
+        x = mx.random.uniform(shape=(2, 256))
+        y = emb.as_linear(x)
+        yq = qemb.as_linear(x)
+        self.assertLess((y - yq).abs().max(), 1e-2)

if __name__ == "__main__":
    unittest.main()