Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-20 03:48:15 +08:00)
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
Committed by GitHub
parent 4912ff3ec2
commit 57fe918cf8
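As a quick orientation, here is a minimal round-trip sketch of the new de-/quantize ops. The mx.quantize call mirrors its usage in this diff; the mx.dequantize argument order is an assumption (only quantize and quantized_matmul appear directly in the hunks below):

import mlx.core as mx

# Round-trip sketch of the new core ops. mx.quantize packs a weight
# matrix into width-bit elements with per-group scales and biases;
# mx.dequantize is assumed to mirror the same argument order.
w = mx.random.uniform(shape=(256, 512))
w_q, scales, biases = mx.quantize(w, 64, 4)  # 4-bit weights, groups of 64
w_hat = mx.dequantize(w_q, scales, biases, 64, 4)
print(mx.abs(w - w_hat).max())  # small, bounded by each group's scale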
@@ -38,6 +38,7 @@ from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
+from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     TransformerEncoder,
@@ -258,6 +258,44 @@ class Module(dict):
         filter_fn = filter_fn or Module.valid_parameter_filter
         self.update(self.filter_and_map(filter_fn, map_fn))
 
+    def update_modules(self, modules: dict):
+        """Replace the child modules of this :class:`Module` instance with the
+        provided ones in the dict of dicts and lists.
+
+        It is the equivalent of :meth:`Module.update` but for modules instead
+        of parameters, and it allows us to flexibly edit complex architectures
+        by programmatically swapping layers.
+
+        The passed-in modules dictionary need not be a full dictionary
+        similar to :meth:`parameters`. Only the provided locations will be
+        updated.
+
+        Args:
+            modules (dict): A complete or partial dictionary of the module's
+                submodules.
+        """
+
+        def apply(dst, modules):
+            if isinstance(modules, dict):
+                for k in modules:
+                    if k in dst:
+                        current_value = dst[k]
+                        new_value = modules[k]
+                        if self.is_module(current_value) and self.is_module(new_value):
+                            dst[k] = new_value
+                        elif isinstance(current_value, (dict, list)):
+                            apply(current_value, new_value)
+            elif isinstance(modules, list):
+                for i in range(len(dst)):
+                    current_value = dst[i]
+                    new_value = modules[i]
+                    if self.is_module(current_value) and self.is_module(new_value):
+                        dst[i] = new_value
+                    elif isinstance(current_value, (dict, list)):
+                        apply(current_value, new_value)
+
+        apply(self, modules)
+
     def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
         """Apply a function to all the modules in this instance (including this
         instance).
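A minimal sketch of the new method in use (the toy module below is illustrative, not part of the diff): swap a child module in place instead of editing parameters:

import mlx.nn as nn

class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(8, 32)
        self.down = nn.Linear(32, 8)

    def __call__(self, x):
        return self.down(self.up(x))

model = TwoLayer()
# Only the provided location is touched; `up` stays as-is.
model.update_modules({"down": nn.Linear(32, 2)})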
python/mlx/nn/layers/quantized.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# Copyright © 2023 Apple Inc.

import math

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map


class QuantizedLinear(Module):
    """Applies an affine transformation to the input using a quantized weight matrix.

    It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
    parameters are frozen and will not be included in any gradient computation,
    but this will probably change in the future.

    QuantizedLinear also provides two useful classmethods to convert linear
    layers to QuantizedLinear layers.

    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
      linear transformation up to the quantization error.
    - :meth:`quantize_module` swaps all the linear layers of the passed module
      with QuantizedLinear ones.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool): If set to ``False`` then the layer will not use a bias.
            (default: True)
        groups (int): The group size to use for the quantized weight. See
            :func:`~mlx.core.quantize`. (default: 64)
        width (int): The bit width to use for the quantized weight. See
            :func:`~mlx.core.quantize`. (default: 4)
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        groups: int = 64,
        width: int = 4,
    ):
        super().__init__()

        # Quantization config
        self.groups = groups
        self.width = width

        # Initialize the quantized weight
        scale = math.sqrt(1 / input_dims)
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self):
        out_dims, in_dims = self.weight.shape
        # Each 32-bit word of the packed weight stores 32 // width elements.
        in_dims *= 32 // self.width
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"groups={self.groups}, width={self.width}"
        )

    def __call__(self, x):
        x = mx.quantized_matmul(
            x,
            self.weight.T,
            scales=self.scales,
            biases=self.biases,
            groups=self.groups,
            width=self.width,
        )
        if "bias" in self:
            x = x + self.bias
        return x

    @classmethod
    def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
        """Create a QuantizedLinear layer from the parameters of a provided
        linear layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, groups, width)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, groups, width
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias

        return ql

    @classmethod
    def quantize_module(
        cls,
        model: Module,
        groups: int = 64,
        width: int = 4,
        linear_class_predicate=lambda m: isinstance(m, Linear),
    ):
        def _quantize_if_linear(m):
            if linear_class_predicate(m):
                return cls.from_linear(m, groups, width)
            else:
                return m

        leaves = model.leaf_modules()
        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
        model.update_modules(leaves)
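A hedged usage sketch of the two classmethods (the toy model is illustrative; everything else follows the code above, and QuantizedLinear is re-exported from mlx.nn by the import added in this commit):

import mlx.core as mx
import mlx.nn as nn

# Convert a single layer and compare against the float output.
layer = nn.Linear(512, 256)
qlayer = nn.QuantizedLinear.from_linear(layer, groups=64, width=4)
x = mx.random.uniform(shape=(1, 512))
print(mx.abs(layer(x) - qlayer(x)).max())  # small quantization error

# Or swap every nn.Linear inside a module tree in place.
class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(512, 1024)
        self.down = nn.Linear(1024, 512)

    def __call__(self, x):
        return self.down(self.up(x))

model = TwoLayer()
nn.QuantizedLinear.quantize_module(model, groups=64, width=4)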
@@ -445,13 +445,13 @@ class Adamax(Adam):


class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have a larger norm than those of other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv

@@ -486,7 +486,7 @@ class Lion(Optimizer):
    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
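The docstring's "In detail," refers to the paper's update rule. A hedged sketch of that rule (names illustrative, assuming decoupled weight decay as in the paper; this is not a copy of the diff's code):

import mlx.core as mx

# Sketch of the Lion update from [1]. The step direction uses only the
# sign of an interpolated momentum, so every element moves by exactly
# lr in magnitude, hence the larger-norm remark above.
def lion_step(parameter, gradient, m, lr, b1=0.9, b2=0.99, wd=0.0):
    c = b1 * m + (1 - b1) * gradient   # interpolation used for the sign
    parameter = parameter - lr * (mx.sign(c) + wd * parameter)
    m = b2 * m + (1 - b2) * gradient   # momentum stored in the state
    return parameter, m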
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 
-def tree_map(fn, tree, *rest):
+def tree_map(fn, tree, *rest, is_leaf=None):
     """Applies ``fn`` to the leaves of the python tree ``tree`` and
     returns a new collection with the results.
 
@@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest):
     ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
     than to :func:`map`.
 
+    The keyword argument ``is_leaf`` decides what constitutes a leaf from
+    ``tree``, similar to :func:`tree_flatten`.
+
     .. code-block:: python
 
         import mlx.nn as nn
 
@@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest):
         fn (Callable): The function that processes the leaves of the tree
         tree (Any): The main python tree that will be iterated upon
         rest (Tuple[Any]): Extra trees to be iterated together with tree
+        is_leaf (Optional[Callable]): An optional callable that returns True if
+            the passed object is considered a leaf or False otherwise.
 
     Returns:
         A python tree with the new values returned by ``fn``.
     """
-    if isinstance(tree, list):
+    if is_leaf is not None and is_leaf(tree):
+        return fn(tree, *rest)
+    elif isinstance(tree, list):
         return [
-            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+            for i, child in enumerate(tree)
         ]
     elif isinstance(tree, tuple):
         return tuple(
-            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+            for i, child in enumerate(tree)
         )
     elif isinstance(tree, dict):
         return {
-            k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
+            k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
+            for k, child in tree.items()
        }
     else:
         return fn(tree, *rest)
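A short sketch of the new hook, mirroring how quantize_module uses it to treat whole modules as leaves (the dict values here are illustrative):

from mlx.utils import tree_map

tree = {"a": [1, 2], "b": [3, 4]}

# Without is_leaf, fn sees the integers; with is_leaf, whole lists are
# treated as leaves and fn sees them instead.
print(tree_map(lambda x: x * 2, tree))                             # {'a': [2, 4], 'b': [6, 8]}
print(tree_map(sum, tree, is_leaf=lambda x: isinstance(x, list)))  # {'a': 3, 'b': 7}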