Quantize embedding (#994)

* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
Awni Hannun 2024-04-15 16:42:10 -07:00 committed by GitHub
parent 2e7c02d5cd
commit cd9e184529
9 changed files with 269 additions and 54 deletions

View File

@@ -0,0 +1,20 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -173,6 +173,7 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
.. toctree::

View File

@@ -31,6 +31,7 @@ Layers
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU

View File

@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path

View File

@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
from mlx.nn.layers.transformer import (
MultiHeadAttention,

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
@@ -28,3 +28,12 @@ class Embedding(Module):
def __call__(self, x):
return self.weight[x]
def as_linear(self, x):
"""
Call the embedding layer as a linear layer.
Use this, for example, when the input embedding and output projection
weights are tied.
"""
return x @ self.weight.T
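A brief usage sketch (the layer sizes here are illustrative): with tied weights, the same ``Embedding`` maps token ids to vectors on the way in and projects hidden states back to vocabulary logits on the way out.

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> emb = nn.Embedding(100, 32)
>>> h = emb(mx.array([3, 1, 4]))  # token ids -> embeddings, shape (3, 32)
>>> logits = emb.as_linear(h)     # embeddings -> vocab logits, shape (3, 100)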

View File

@@ -1,11 +1,143 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map
from mlx.utils import tree_map_with_path
def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[Callable] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
By default, all :obj:`Linear` and :obj:`Embedding` layers are
quantized. Note that the module is updated in place.
Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized.
group_size (int): The quantization group size (see
:func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see
:func:`mlx.core.quantize`). Default: ``4``.
class_predicate (Optional[Callable]): A callable that receives the
:obj:`Module` path and the :obj:`Module` itself and returns ``True`` if
it should be quantized and ``False`` otherwise. If ``None``, then
all linear and embedding layers are quantized. Default: ``None``.
"""
class_predicate = class_predicate or (
lambda _, m: isinstance(m, (Linear, Embedding))
)
def _maybe_quantize(path, m):
if class_predicate(path, m):
if isinstance(m, Linear):
return QuantizedLinear.from_linear(m, group_size, bits)
elif isinstance(m, Embedding):
return QuantizedEmbedding.from_embedding(m, group_size, bits)
else:
raise ValueError(f"Unable to quantize model of type {type(m)}")
else:
return m
leaves = model.leaf_modules()
leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
model.update_modules(leaves)
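For illustration, a sketch that mirrors the new test below: the default predicate converts both layer types in place, while a custom predicate (which now receives the path as well as the module) restricts quantization to ``Linear`` layers. The sizes are chosen so 256 is divisible by the default ``group_size`` of 64.

>>> m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
>>> nn.quantize(m)  # Embedding -> QuantizedEmbedding, Linear -> QuantizedLinear
>>> m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
>>> nn.quantize(m, class_predicate=lambda path, m: isinstance(m, nn.Linear))
>>> # only the Linear layer is quantized; the Embedding is left untouched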
class QuantizedEmbedding(Module):
"""The same as :obj:`Embedding` but with a quantized weight matrix.
:obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
layers.
Args:
num_embeddings (int): The number of possible discrete tokens that can
be embedded, usually called the vocabulary size.
dims (int): The dimensionality of the embeddings.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
"""
def __init__(
self,
num_embeddings: int,
dims: int,
group_size: int = 64,
bits: int = 4,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1 / dims)
weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
self.num_embeddings = num_embeddings
self.dims = dims
# Freeze this model's parameters
self.freeze()
def __call__(self, x):
s = x.shape
x = x.flatten()
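        # Gather the quantized rows, along with their per-group scales and
        # biases, for the requested tokens, then dequantize just those rows.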
out = mx.dequantize(
self["weight"][x],
scales=self["scales"][x],
biases=self["biases"][x],
group_size=self.group_size,
bits=self.bits,
)
return out.reshape(*s, -1)
def as_linear(self, x):
"""
Call the quantized embedding layer as a quantized linear layer.
Use this, for example, when the input embedding and output projection
weights are tied.
"""
return mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
def _extra_repr(self):
return (
f"{self.num_embeddings}, {self.dims}, "
f"group_size={self.group_size}, bits={self.bits}"
)
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
embedding_layer.weight, group_size, bits
)
return ql
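A usage sketch mirroring the new ``test_quantized_embedding`` test below: convert an existing ``Embedding``, then use it both for lookups and as a tied output projection.

>>> emb = nn.Embedding(32, 256)
>>> qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
>>> yq = qemb(mx.array([2, 6, 9]))  # matches emb(...) up to quantization error
>>> logits = qemb.as_linear(mx.random.uniform(shape=(2, 256)))  # tied projection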
class QuantizedLinear(Module):
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
parameters are frozen and will not be included in any gradient computation
but this will probably change in the future.
QuantizedLinear also provides two useful classmethods to convert linear
layers to QuantizedLinear layers.
- :meth:`from_linear` returns a QuantizedLinear layer that applies the same
linear transformation up to the quantization error.
- :meth:`quantize_module` swaps all the linear layers of the passed module
with QuantizedLinear ones.
:obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
convert linear layers to :obj:`QuantizedLinear` layers.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. (default: True).
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. (default: 64)
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. (default: 4)
See :func:`~mlx.core.quantize`. Default: ``4``.
"""
def __init__(
@@ -94,8 +221,7 @@ class QuantizedLinear(Module):
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
"""Create a QuantizedLinear layer from the parameters of a provided
linear layer."""
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -105,21 +231,3 @@ class QuantizedLinear(Module):
ql.bias = linear_layer.bias
return ql
@classmethod
def quantize_module(
cls,
model: Module,
group_size: int = 64,
bits: int = 4,
linear_class_predicate=lambda m: isinstance(m, Linear),
):
def _quantize_if_linear(m):
if linear_class_predicate(m):
return cls.from_linear(m, group_size, bits)
else:
return m
leaves = model.leaf_modules()
leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
model.update_modules(leaves)
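The removed ``QuantizedLinear.quantize_module`` classmethod is superseded by the module-level ``quantize`` above. A hedged migration sketch; note that the predicate now takes the module path as its first argument:

# Before this change:
#   nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)
# After this change (equivalent behavior, quantizing only Linear layers):
nn.quantize(model, group_size=64, bits=4,
            class_predicate=lambda path, m: isinstance(m, nn.Linear))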

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
def tree_map(fn, tree, *rest, is_leaf=None):
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
model.update(tree_map(lambda x: x*x, model.parameters()))
Args:
fn (Callable): The function that processes the leaves of the tree
tree (Any): The main python tree that will be iterated upon
rest (Tuple[Any]): Extra trees to be iterated together with tree
is_leaf (Optional[Callable]): An optional callable that returns True if
the passed object is considered a leaf or False otherwise.
fn (callable): The function that processes the leaves of the tree.
tree (Any): The main Python tree that will be iterated upon.
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
is_leaf (callable, optional): An optional callable that returns ``True``
if the passed object is considered a leaf or ``False`` otherwise.
Returns:
A python tree with the new values returned by ``fn``.
A Python tree with the new values returned by ``fn``.
"""
if is_leaf is not None and is_leaf(tree):
return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
return fn(tree, *rest)
def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
This function is the same as :func:`tree_map`, but ``fn`` takes the path as
its first argument, followed by the remaining tree nodes.
Args:
fn (callable): The function that processes the leaves of the tree.
tree (Any): The main Python tree that will be iterated upon.
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
is_leaf (callable, optional): An optional callable that returns ``True``
if the passed object is considered a leaf or ``False`` otherwise.
Returns:
A Python tree with the new values returned by ``fn``.
Example:
>>> from mlx.utils import tree_map_with_path
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
model.0.w
model.0.b
model.1.w
model.1.b
"""
if is_leaf is not None and is_leaf(tree):
return fn(path, tree, *rest)
elif isinstance(tree, (list, tuple)):
prefix = f"{path}." if path else ""
TreeType = type(tree)
return TreeType(
tree_map_with_path(
fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
)
for i, child in enumerate(tree)
)
elif isinstance(tree, dict):
prefix = f"{path}." if path else ""
return {
k: tree_map_with_path(
fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
)
for k, child in tree.items()
}
else:
return fn(path, tree, *rest)
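In the same spirit as the doctest above, a small sketch of using the path to transform leaves selectively; this is essentially how ``quantize`` uses the helper (with ``is_leaf=Module.is_module``):

>>> from mlx.utils import tree_map_with_path
>>> tree = {"a": [1, 2], "b": 3}
>>> tree_map_with_path(lambda path, x: x * 10 if path.startswith("a") else x, tree)
{'a': [10, 20], 'b': 3}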
def tree_flatten(tree, prefix="", is_leaf=None):
"""Flattens a python tree to a list of key, value tuples.
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
complexity.
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
# [("hello.0.0.0", 0)]
.. note::
Dictionaries should have keys that are valid python identifiers.
Dictionaries should have keys that are valid Python identifiers.
Args:
tree (Any): The python tree to be flattened.
tree (Any): The Python tree to be flattened.
prefix (str): A prefix to use for the keys. The first character is
always discarded.
is_leaf (Callable): An optional callable that returns True if the
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
Returns:
List[Tuple[str, Any]]: The flat representation of the python tree.
List[Tuple[str, Any]]: The flat representation of the Python tree.
"""
flat_tree = []
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
def tree_unflatten(tree):
"""Recreate a python tree from its flat representation.
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
# {"hello": {"world": 42}}
Args:
tree (List[Tuple[str, Any]]): The flat representation of a python tree.
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A python tree.
A Python tree.
"""
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]

View File

@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertFalse(m.update(params_dict).eval()._training)
self.assertTrue(m.train()._training)
def test_quantize(self):
m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(m)
self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear))
self.assertTrue(isinstance(m.layers[0], nn.Embedding))
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):
@@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(h_out.shape, (44, 12))
self.assertEqual(c_out.shape, (44, 12))
def test_quantized_embedding(self):
emb = nn.Embedding(32, 256)
qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
x = mx.array([2, 6, 9, 3, 0, 3])
y = emb(x)
yq = qemb(x)
self.assertLess((y - yq).abs().max(), 1e-3)
x = mx.random.uniform(shape=(2, 256))
y = emb.as_linear(x)
yq = qemb.as_linear(x)
self.assertLess((y - yq).abs().max(), 1e-2)
if __name__ == "__main__":
unittest.main()