Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00

Quantize embedding (#994)

* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test

parent 2e7c02d5cd
commit cd9e184529
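
As orientation before the diff: a minimal usage sketch (not part of the commit; the model and layer sizes are made up for illustration) of the API this change adds. nn.quantize replaces QuantizedLinear.quantize_module and converts both Linear and Embedding sub-modules in place.

    import mlx.nn as nn

    # Toy model; sizes chosen so the default group_size=64 divides the feature dims.
    model = nn.Sequential(nn.Embedding(100, 256), nn.ReLU(), nn.Linear(256, 256))

    # New helper added by this commit: quantizes Linear and Embedding layers in place.
    nn.quantize(model, group_size=64, bits=4)

    print(type(model.layers[0]))  # nn.QuantizedEmbedding
    print(type(model.layers[2]))  # nn.QuantizedLinear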
docs/src/_templates/nn-module-template.rst (new file, 20 lines)
@@ -0,0 +1,20 @@
+{{ fullname | escape | underline}}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+
+  {% block methods %}
+
+  {% if methods %}
+  .. rubric:: {{ _('Methods') }}
+
+  .. autosummary::
+  {% for item in methods %}
+    {%- if item not in inherited_members and item != "__init__" %}
+     ~{{ name }}.{{ item }}
+    {%- endif %}
+  {%- endfor %}
+  {% endif %}
+  {% endblock %}
+
docs/src/python/nn.rst
@@ -173,6 +173,7 @@ In detail:
    :toctree: _autosummary

    value_and_grad
+   quantize

 .. toctree::

docs/src/python/nn/layers.rst
@@ -31,6 +31,7 @@ Layers
    Mish
    MultiHeadAttention
    PReLU
+   QuantizedEmbedding
    QuantizedLinear
    RMSNorm
    ReLU
docs/src/python/tree_utils.rst
@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
    tree_flatten
    tree_unflatten
    tree_map
+   tree_map_with_path
python/mlx/nn/layers/__init__.py
@@ -61,7 +61,7 @@ from mlx.nn.layers.normalization import (
 )
 from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
-from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
 from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
python/mlx/nn/layers/embedding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math

@@ -14,7 +14,7 @@ class Embedding(Module):

     Args:
         num_embeddings (int): How many possible discrete tokens can we embed.
             Usually called the vocabulary size.
         dims (int): The dimensionality of the embeddings.
     """

@@ -28,3 +28,12 @@ class Embedding(Module):

     def __call__(self, x):
         return self.weight[x]
+
+    def as_linear(self, x):
+        """
+        Call the embedding layer as a linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return x @ self.weight.T
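
A hedged illustration of the new Embedding.as_linear: the toy tied-weight model below is not from this commit, and its names and sizes are invented for the example.

    import mlx.core as mx
    import mlx.nn as nn

    class TiedLM(nn.Module):
        # Toy model with tied input embedding / output projection weights.
        def __init__(self, vocab_size=100, dims=64):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, dims)

        def __call__(self, tokens):
            h = self.embedding(tokens)          # (..., dims) token embeddings
            # ... transformer blocks would normally go here ...
            return self.embedding.as_linear(h)  # (..., vocab_size) logits

    logits = TiedLM()(mx.array([[1, 2, 3]]))
    print(logits.shape)  # (1, 3, 100)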
python/mlx/nn/layers/quantized.py
@@ -1,11 +1,143 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
+from typing import Callable, Optional

 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_map_with_path
+
+
+def quantize(
+    model: Module,
+    group_size: int = 64,
+    bits: int = 4,
+    class_predicate: Optional[callable] = None,
+):
+    """Quantize the sub-modules of a module according to a predicate.
+
+    By default all :obj:`Linear` and :obj:`Embedding` layers will be
+    quantized. Note also, the module is updated in-place.
+
+    Args:
+        model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        group_size (int): The quantization group size (see
+            :func:`mlx.core.quantize`). Default: ``64``.
+        bits (int): The number of bits per parameter (see
+            :func:`mlx.core.quantize`). Default: ``4``.
+        class_predicate (Optional[Callable]): A callable which receives the
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
+            it should be quantized and ``False`` otherwise. If ``None``, then
+            all linear and embedding layers are quantized. Default: ``None``.
+    """
+    class_predicate = class_predicate or (
+        lambda _, m: isinstance(m, (Linear, Embedding))
+    )
+
+    def _maybe_quantize(path, m):
+        if class_predicate(path, m):
+            if isinstance(m, Linear):
+                return QuantizedLinear.from_linear(m, group_size, bits)
+            elif isinstance(m, Embedding):
+                return QuantizedEmbedding.from_embedding(m, group_size, bits)
+            else:
+                raise ValueError(f"Unable to quantize model of type {type(m)}")
+        else:
+            return m
+
+    leaves = model.leaf_modules()
+    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)
+    model.update_modules(leaves)
+
+
+class QuantizedEmbedding(Module):
+    """The same as :obj:`Embedding` but with a quantized weight matrix.
+
+    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`
+    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`
+    layers.
+
+    Args:
+        num_embeddings (int): How many possible discrete tokens can we embed.
+            Usually called the vocabulary size.
+        dims (int): The dimensionality of the embeddings.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+    """

+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1 / dims)
+        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.num_embeddings = num_embeddings
+        self.dims = dims
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def __call__(self, x):
+        s = x.shape
+        x = x.flatten()
+        out = mx.dequantize(
+            self["weight"][x],
+            scales=self["scales"][x],
+            biases=self["biases"][x],
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        return out.reshape(*s, -1)
+
+    def as_linear(self, x):
+        """
+        Call the quantized embedding layer as a quantized linear layer.
+
+        Use this for example when input embedding and output projection
+        weights are tied.
+        """
+        return mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_embeddings}, {self.dims}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+    ):
+        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        embedding_dims, dims = embedding_layer.weight.shape
+        ql = cls(embedding_dims, dims, group_size, bits)
+        ql.weight, ql.scales, ql.biases = mx.quantize(
+            embedding_layer.weight, group_size, bits
+        )
+        return ql


 class QuantizedLinear(Module):
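
A short usage sketch of the QuantizedEmbedding layer defined in the hunk above (not part of the diff; the shapes are illustrative). Per the test added below, its outputs match the float layer up to quantization error.

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(1000, 512)
    qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=64, bits=4)

    tokens = mx.array([3, 41, 59])
    print(qemb(tokens).shape)  # (3, 512): quantized table lookup

    h = mx.random.uniform(shape=(2, 512))
    print(qemb.as_linear(h).shape)  # (2, 1000): quantized output projection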
@@ -15,23 +147,18 @@ class QuantizedLinear(Module):
     parameters are frozen and will not be included in any gradient computation
     but this will probably change in the future.

-    QuantizedLinear also provides two useful classmethods to convert linear
-    layers to QuantizedLinear layers.
-
-    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
-      linear transformation up to the quantization error.
-    - :meth:`quantize_module` swaps all the linear layers of the passed module
-      with QuantizedLinear ones.
+    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to
+    convert linear layers to :obj:`QuantizedLinear` layers.

     Args:
-        input_dims (int): The dimensionality of the input features
-        output_dims (int): The dimensionality of the output features
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
         bias (bool, optional): If set to ``False`` then the layer will not use
-            a bias. (default: True).
+            a bias. Default: ``True``.
         group_size (int, optional): The group size to use for the quantized
-            weight. See :func:`~mlx.core.quantize`. (default: 64)
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
-            See :func:`~mlx.core.quantize`. (default: 4)
+            See :func:`~mlx.core.quantize`. Default: ``4``.
     """

     def __init__(
@@ -94,8 +221,7 @@ class QuantizedLinear(Module):

     @classmethod
     def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
-        """Create a QuantizedLinear layer from the parameters of a provided
-        linear layer."""
+        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -105,21 +231,3 @@ class QuantizedLinear(Module):
             ql.bias = linear_layer.bias

         return ql
-
-    @classmethod
-    def quantize_module(
-        cls,
-        model: Module,
-        group_size: int = 64,
-        bits: int = 4,
-        linear_class_predicate=lambda m: isinstance(m, Linear),
-    ):
-        def _quantize_if_linear(m):
-            if linear_class_predicate(m):
-                return cls.from_linear(m, group_size, bits)
-            else:
-                return m
-
-        leaves = model.leaf_modules()
-        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
-        model.update_modules(leaves)
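
Since quantize_module is removed above, callers migrate to the module-level nn.quantize. A hedged before/after sketch (the model is invented; note the predicate now receives the module path as well as the module):

    import mlx.nn as nn

    model = nn.Sequential(nn.Embedding(50, 256), nn.ReLU(), nn.Linear(256, 256))

    # Removed by this commit:
    #   nn.QuantizedLinear.quantize_module(
    #       model, linear_class_predicate=lambda m: isinstance(m, nn.Linear)
    #   )

    # New equivalent: the predicate takes (path, module), and Embedding layers
    # are also quantized unless the predicate filters them out.
    nn.quantize(model, class_predicate=lambda path, m: isinstance(m, nn.Linear))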
python/mlx/utils.py
@@ -3,7 +3,7 @@ from collections import defaultdict


 def tree_map(fn, tree, *rest, is_leaf=None):
-    """Applies ``fn`` to the leaves of the python tree ``tree`` and
+    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.

     If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
         model.update(tree_map(lambda x: x*x, model.parameters()))

     Args:
-        fn (Callable): The function that processes the leaves of the tree
-        tree (Any): The main python tree that will be iterated upon
-        rest (Tuple[Any]): Extra trees to be iterated together with tree
-        is_leaf (Optional[Callable]): An optional callable that returns True if
-            the passed object is considered a leaf or False otherwise.
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.

     Returns:
-        A python tree with the new values returned by ``fn``.
+        A Python tree with the new values returned by ``fn``.
     """
     if is_leaf is not None and is_leaf(tree):
         return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
         return fn(tree, *rest)


+def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
+    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
+    returns a new collection with the results.
+
+    This function is the same :func:`tree_map` but the ``fn`` takes the path as
+    the first argument followed by the remaining tree nodes.
+
+    Args:
+        fn (callable): The function that processes the leaves of the tree.
+        tree (Any): The main Python tree that will be iterated upon.
+        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
+        is_leaf (callable, optional): An optional callable that returns ``True``
+            if the passed object is considered a leaf or ``False`` otherwise.
+
+    Returns:
+        A Python tree with the new values returned by ``fn``.
+
+    Example:
+        >>> from mlx.utils import tree_map_with_path
+        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
+        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
+        model.0.w
+        model.0.b
+        model.1.w
+        model.1.b
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return fn(path, tree, *rest)
+    elif isinstance(tree, (list, tuple)):
+        prefix = f"{path}." if path else ""
+        TreeType = type(tree)
+        return TreeType(
+            tree_map_with_path(
+                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
+            )
+            for i, child in enumerate(tree)
+        )
+    elif isinstance(tree, dict):
+        prefix = f"{path}." if path else ""
+        return {
+            k: tree_map_with_path(
+                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
+            )
+            for k, child in tree.items()
+        }
+    else:
+        return fn(path, tree, *rest)
+
+
 def tree_flatten(tree, prefix="", is_leaf=None):
-    """Flattens a python tree to a list of key, value tuples.
+    """Flattens a Python tree to a list of key, value tuples.

     The keys are using the dot notation to define trees of arbitrary depth and
     complexity.
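
Beyond the docstring example above, a hedged sketch of tree_map_with_path with an extra tree passed through ``rest``; the parameter/gradient trees here are made up for illustration.

    from mlx.utils import tree_map_with_path

    params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}
    grads = {"layers": [{"w": 0.1, "b": 0.2}, {"w": 0.3, "b": 0.4}]}

    # The callback receives the dotted path first, then the matching leaves
    # from `tree` and each extra tree in `rest`.
    updated = tree_map_with_path(
        lambda path, p, g: p - 0.01 * g if path.endswith(".w") else p,
        params,
        grads,
    )
    print(updated)  # only the ".w" leaves are updated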
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
         # [("hello.0.0.0", 0)]

     .. note::
-       Dictionaries should have keys that are valid python identifiers.
+       Dictionaries should have keys that are valid Python identifiers.

     Args:
-        tree (Any): The python tree to be flattened.
+        tree (Any): The Python tree to be flattened.
         prefix (str): A prefix to use for the keys. The first character is
             always discarded.
-        is_leaf (Callable): An optional callable that returns True if the
+        is_leaf (callable): An optional callable that returns True if the
             passed object is considered a leaf or False otherwise.

     Returns:
-        List[Tuple[str, Any]]: The flat representation of the python tree.
+        List[Tuple[str, Any]]: The flat representation of the Python tree.
     """
     flat_tree = []

@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):


 def tree_unflatten(tree):
-    """Recreate a python tree from its flat representation.
+    """Recreate a Python tree from its flat representation.

     .. code-block:: python

|
|||||||
# {"hello": {"world": 42}}
|
# {"hello": {"world": 42}}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (List[Tuple[str, Any]]): The flat representation of a python tree.
|
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
|
||||||
For instance as returned by :meth:`tree_flatten`.
|
For instance as returned by :meth:`tree_flatten`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A python tree.
|
A Python tree.
|
||||||
"""
|
"""
|
||||||
if len(tree) == 1 and tree[0][0] == "":
|
if len(tree) == 1 and tree[0][0] == "":
|
||||||
return tree[0][1]
|
return tree[0][1]
|
||||||
|
python/tests/test_nn.py
@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertFalse(m.update(params_dict).eval()._training)
         self.assertTrue(m.train()._training)

+    def test_quantize(self):
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m)
+        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+
+        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))
+        nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear))
+        self.assertTrue(isinstance(m.layers[0], nn.Embedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+

 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
|
|||||||
self.assertEqual(h_out.shape, (44, 12))
|
self.assertEqual(h_out.shape, (44, 12))
|
||||||
self.assertEqual(c_out.shape, (44, 12))
|
self.assertEqual(c_out.shape, (44, 12))
|
||||||
|
|
||||||
|
def test_quantized_embedding(self):
|
||||||
|
emb = nn.Embedding(32, 256)
|
||||||
|
qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)
|
||||||
|
x = mx.array([2, 6, 9, 3, 0, 3])
|
||||||
|
y = emb(x)
|
||||||
|
yq = qemb(x)
|
||||||
|
self.assertLess((y - yq).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 256))
|
||||||
|
y = emb.as_linear(x)
|
||||||
|
yq = qemb.as_linear(x)
|
||||||
|
self.assertLess((y - yq).abs().max(), 1e-2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||