Mirror of https://github.com/ml-explore/mlx.git

Quantize embedding (#994)
* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test
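In effect, this change lets ``nn.quantize`` swap ``Embedding`` layers for ``QuantizedEmbedding`` ones alongside ``Linear`` layers. A minimal sketch, assuming the post-merge API (the toy model is illustrative, not from the PR):

    import mlx.nn as nn

    # Toy model (illustrative); dims must divide evenly by group_size (64).
    model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))
    nn.quantize(model)  # Embedding -> QuantizedEmbedding, Linear -> QuantizedLinear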
python/mlx/nn/layers/__init__.py
@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
 )
 from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
-from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
 from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
python/mlx/nn/layers/embedding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
@@ -14,7 +14,7 @@ class Embedding(Module):
 
     Args:
         num_embeddings (int): How many possible discrete tokens can we embed.
-                              Usually called the vocabulary size.
+            Usually called the vocabulary size.
         dims (int): The dimensionality of the embeddings.
     """
 
@@ -28,3 +28,12 @@ class Embedding(Module):
 
     def __call__(self, x):
         return self.weight[x]
+
+    def as_linear(self, x):
+        """
+        Call the embedding layer as a linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return x @ self.weight.T
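The tied-weights use case reads roughly like this (a sketch; sizes and variable names are mine, not from the diff):

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(num_embeddings=100, dims=32)
    tokens = mx.array([3, 17, 42])

    h = emb(tokens)            # (3, 32): looks up rows of emb.weight
    logits = emb.as_linear(h)  # (3, 100): h @ emb.weight.T, reusing the same weights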
python/mlx/nn/layers/quantized.py
@@ -1,11 +1,143 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
+from typing import Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_map_with_path
+
+
+def quantize(
+    model: Module,
+    group_size: int = 64,
+    bits: int = 4,
+    class_predicate: Optional[Callable] = None,
+):
+    """Quantize the sub-modules of a module according to a predicate.
+
+    By default all :obj:`Linear` and :obj:`Embedding` layers will be
+    quantized. Note also, the module is updated in-place.
+
+    Args:
+        model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        group_size (int): The quantization group size (see
+            :func:`mlx.core.quantize`). Default: ``64``.
+        bits (int): The number of bits per parameter (see
+            :func:`mlx.core.quantize`). Default: ``4``.
+        class_predicate (Optional[Callable]): A callable which receives the
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
+            it should be quantized and ``False`` otherwise. If ``None``, then
+            all linear and embedding layers are quantized. Default: ``None``.
+    """
+    class_predicate = class_predicate or (
+        lambda _, m: isinstance(m, (Linear, Embedding))
+    )
+
+    def _maybe_quantize(path, m):
+        if class_predicate(path, m):
+            if isinstance(m, Linear):
+                return QuantizedLinear.from_linear(m, group_size, bits)
+            elif isinstance(m, Embedding):
+                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            else:
+                raise ValueError(f"Unable to quantize model of type {type(m)}")
+        else:
+            return m
+
+    leaves = model.leaf_modules()
+    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
+    model.update_modules(leaves)
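A usage sketch of the predicate hook (the ``Sequential`` model is an assumption for illustration, as is the ``"layers.2"`` path string I assume its third child gets):

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))

    # Quantize every Linear except the final projection, keyed off the tree path.
    nn.quantize(
        model,
        class_predicate=lambda path, m: isinstance(m, nn.Linear)
        and not path.endswith("layers.2"),
    )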
+
+
+class QuantizedEmbedding(Module):
+    """The same as :obj:`Embedding` but with a quantized weight matrix.
+
+    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
+    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
+    layers.
+
+    Args:
+        num_embeddings (int): How many possible discrete tokens can we embed.
+            Usually called the vocabulary size.
+        dims (int): The dimensionality of the embeddings.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1 / dims)
+        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.num_embeddings = num_embeddings
+        self.dims = dims
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def __call__(self, x):
+        s = x.shape
+        x = x.flatten()
+        out = mx.dequantize(
+            self["weight"][x],
+            scales=self["scales"][x],
+            biases=self["biases"][x],
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        return out.reshape(*s, -1)
+
+    def as_linear(self, x):
+        """
+        Call the quantized embedding layer as a quantized linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_embeddings}, {self.dims}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+    ):
+        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        embedding_dims, dims = embedding_layer.weight.shape
+        ql = cls(embedding_dims, dims, group_size, bits)
+        ql.weight, ql.scales, ql.biases = mx.quantize(
+            embedding_layer.weight, group_size, bits
+        )
+        return ql
+
+
 class QuantizedLinear(Module):
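Converting and using a quantized embedding, as a short sketch (toy sizes; the embedding ``dims`` must divide evenly by ``group_size``):

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(num_embeddings=1024, dims=128)
    qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=64, bits=4)

    tokens = mx.array([1, 2, 3])
    h = qemb(tokens)            # dequantized lookup, shape (3, 128)
    logits = qemb.as_linear(h)  # quantized matmul against the same table, (3, 1024)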
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
     parameters are frozen and will not be included in any gradient computation
     but this will probably change in the future.
 
-    QuantizedLinear also provides two useful classmethods to convert linear
-    layers to QuantizedLinear layers.
-
-    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
-      linear transformation up to the quantization error.
-    - :meth:`quantize_module` swaps all the linear layers of the passed module
-      with QuantizedLinear ones.
+    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
+    convert linear layers to :obj:`QuantizedLinear` layers.
 
     Args:
-        input_dims (int): The dimensionality of the input features
-        output_dims (int): The dimensionality of the output features
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
         bias (bool, optional): If set to ``False`` then the layer will not use
-            a bias. (default: True).
+            a bias. Default: ``True``.
         group_size (int, optional): The group size to use for the quantized
-            weight. See :func:`~mlx.core.quantize`. (default: 64)
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
-            See :func:`~mlx.core.quantize`. (default: 4)
+            See :func:`~mlx.core.quantize`. Default: ``4``.
     """
 
     def __init__(
@@ -94,8 +221,7 @@ class QuantizedLinear(Module):
 
     @classmethod
     def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
-        """Create a QuantizedLinear layer from the parameters of a provided
-        linear layer."""
+        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
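The conversion itself is a one-liner, sketched here with illustrative sizes (input dims must divide evenly by ``group_size``):

    import mlx.nn as nn

    lin = nn.Linear(64, 32)
    qlin = nn.QuantizedLinear.from_linear(lin, group_size=64, bits=4)
    # qlin applies the same transformation as lin, up to the quantization error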
@@ -105,21 +231,3 @@ class QuantizedLinear(Module):
             ql.bias = linear_layer.bias
 
         return ql
-
-    @classmethod
-    def quantize_module(
-        cls,
-        model: Module,
-        group_size: int = 64,
-        bits: int = 4,
-        linear_class_predicate=lambda m: isinstance(m, Linear),
-    ):
-        def _quantize_if_linear(m):
-            if linear_class_predicate(m):
-                return cls.from_linear(m, group_size, bits)
-            else:
-                return m
-
-        leaves = model.leaf_modules()
-        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
-        model.update_modules(leaves)
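For code that called the removed classmethod, a hedged migration sketch (toy model; sizes chosen to divide by the group size):

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 2))

    # Before this PR: nn.QuantizedLinear.quantize_module(model, 64, 4)
    # After: the module-level function, which now also covers Embedding layers
    nn.quantize(model, group_size=64, bits=4)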