Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Adds C++ and nn quantization utilities (#230)

* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module

Parent: 4912ff3ec2 · Commit: 57fe918cf8
@@ -26,3 +26,4 @@ Layers
    RoPE
    MultiHeadAttention
    Sequential
+   QuantizedLinear
@@ -34,6 +34,7 @@ Operations
    conv2d
    cos
    cosh
+   dequantize
    divide
    equal
    erf
@@ -73,6 +74,8 @@ Operations
    partition
    pad
    prod
+   quantize
+   quantized_matmul
    reciprocal
    reshape
    round
mlx/ops.cpp (+107)
@@ -2649,4 +2649,111 @@ array quantized_matmul(
  return out;
}

std::tuple<array, array, array> quantize(
    const array& w,
    int groups /* = 128 */,
    int width /* = 4 */,
    StreamOrDevice s /* = {} */) {
  if (w.ndim() != 2) {
    throw std::invalid_argument("[quantize] Only matrices supported for now");
  }

  if ((w.shape(0) % 32) != 0) {
    throw std::invalid_argument(
        "[quantize] All dimensions should be divisible by 32 for now");
  }

  if ((w.shape(-1) % groups) != 0) {
    std::ostringstream msg;
    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
        << "the quantization group size " << groups
        << ". However the provided matrix"
        << " has shape " << w.shape();
    throw std::invalid_argument(msg.str());
  }

  // Compute some constants used for the quantization
  int n_bins = (1 << width) - 1; // 2**width - 1
  int el_per_int = 32 / width;
  array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s);
  shifts = reshape(shifts, {1, 1, -1}, s);

  // Compute scales and biases
  array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s);
  array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
  array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
  array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
  array scales = squeeze(delta, -1, s);
  array biases = squeeze(w_min, -1, s);

  // Quantize and pack w
  packed_w =
      astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
  packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
  packed_w = sum(
      multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);

  return std::make_tuple(packed_w, scales, biases);
}
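Aside (not part of the diff): once built, this op is exposed in Python as mx.quantize via the binding added below. A minimal round-trip sketch mirroring the defaults and constraints above; the shapes are illustrative:

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))        # rows must be a multiple of 32
w_q, scales, biases = mx.quantize(w, 64, 4)   # groups=64, width=4 bits
print(w_q.shape, w_q.dtype)                   # (128, 32), uint32: 8 elements per int
w_hat = mx.dequantize(w_q, scales, biases, 64, 4)
print((w - w_hat).abs().max())                # at most half a quantization step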

array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    int groups /* = 128 */,
    int width /* = 4 */,
    StreamOrDevice s /* = {} */) {
  if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
    throw std::invalid_argument("[dequantize] Only matrices supported for now");
  }

  if ((w.shape(0) % 32) != 0) {
    throw std::invalid_argument(
        "[dequantize] All dimensions should be divisible by 32 for now");
  }

  if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
    throw std::invalid_argument(
        "[dequantize] Shape of scales and biases does not match the matrix");
  }

  if (w.dtype() != uint32) {
    throw std::invalid_argument(
        "[dequantize] The matrix should be given as a uint32");
  }

  // Compute some constants for the dequantization
  int el_per_int = 32 / width;

  if (w.shape(1) * el_per_int != scales.shape(1) * groups) {
    std::ostringstream msg;
    msg << "[dequantize] Shape of scales and biases does not match the matrix "
        << "given the quantization parameters. Provided matrix of shape "
        << w.shape() << " and scales/biases of shape " << scales.shape()
        << " with groups=" << groups << " and width=" << width << ".";
    throw std::invalid_argument(msg.str());
  }

  // Extract the pieces from the passed quantized matrix
  std::vector<array> parts;
  for (int start = 0; start < 32; start += width) {
    // TODO: Implement bitwise operators for integral types
    int shift_left = 32 - (start + width);
    int shift_right = shift_left + start;
    array p = multiply(w, array(1 << shift_left, uint32), s);
    p = floor_divide(p, array(1 << shift_right, uint32), s);
    p = expand_dims(p, -1, s);
    parts.push_back(p);
  }
  array w_full = concatenate(parts, -1, s);

  // Dequantize
  w_full = reshape(w_full, {w.shape(0), -1, groups}, s);
  w_full = multiply(w_full, expand_dims(scales, -1, s), s);
  w_full = add(w_full, expand_dims(biases, -1, s), s);
  w_full = reshape(w_full, {w.shape(0), -1}, s);

  return w_full;
}

} // namespace mlx::core
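Aside: because bitwise operators are not implemented yet (see the TODO above), dequantize isolates each packed field with a multiply followed by a floor-divide. The same trick on a plain Python int, with the uint32 wrap-around made explicit; the packed values are made up:

# With width = 4, element i occupies bits [4*i, 4*i + 4) of the packed uint32.
width = 4
values = [3, 15, 0, 7, 9, 1, 12, 5]           # eight 4-bit elements
packed = 0
for i, v in enumerate(values):                # same as the multiply-by-shifts-and-sum packing
    packed += v << (width * i)

for start in range(0, 32, width):
    shift_left = 32 - (start + width)
    shift_right = shift_left + start
    # Multiplying shifts the field to the top (wrapping mod 2**32 like a uint32);
    # the floor division then shifts it back down, dropping the lower fields.
    field = ((packed * 2**shift_left) % 2**32) // 2**shift_right
    print(field)                              # prints 3, 15, 0, 7, 9, 1, 12, 5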
mlx/ops.h (+16)
@@ -1041,4 +1041,20 @@ array quantized_matmul(
    int width = 4,
    StreamOrDevice s = {});

/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
    const array& w,
    int groups = 128,
    int width = 4,
    StreamOrDevice s = {});

/** Dequantize a matrix produced by quantize() */
array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    int groups = 128,
    int width = 4,
    StreamOrDevice s = {});

} // namespace mlx::core
@@ -38,6 +38,7 @@ from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
+from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     TransformerEncoder,
@@ -258,6 +258,44 @@ class Module(dict):
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def update_modules(self, modules: dict):
        """Replace the child modules of this :class:`Module` instance with the
        provided ones in the dict of dicts and lists.

        It is the equivalent of :meth:`Module.update` but for modules instead
        of parameters and allows us to flexibly edit complex architectures by
        programmatically swapping layers.

        The passed in modules dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            modules (dict): A complete or partial dictionary of the module's
                submodules.
        """

        def apply(dst, modules):
            if isinstance(modules, dict):
                for k in modules:
                    if k in dst:
                        current_value = dst[k]
                        new_value = modules[k]
                        if self.is_module(current_value) and self.is_module(new_value):
                            dst[k] = new_value
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(modules, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = modules[i]
                    if self.is_module(current_value) and self.is_module(new_value):
                        dst[i] = new_value
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, modules)

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).
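Aside: a small sketch of what update_modules enables, on a hypothetical two-layer module (the names MLP, fc1 and fc2 are illustrative, not from the commit):

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 4)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

model = MLP()
# Swap only fc2; locations not mentioned in the dict keep their current modules.
model.update_modules({"fc2": nn.Linear(32, 4, bias=False)})
y = model(mx.zeros((1, 16)))  # runs with the swapped layer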
python/mlx/nn/layers/quantized.py (new file, +124)
@@ -0,0 +1,124 @@
# Copyright © 2023 Apple Inc.

import math

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map


class QuantizedLinear(Module):
    """Applies an affine transformation to the input using a quantized weight matrix.

    It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
    parameters are frozen and will not be included in any gradient computation
    but this will probably change in the future.

    QuantizedLinear also provides two useful classmethods to convert linear
    layers to QuantizedLinear layers.

    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
      linear transformation up to the quantization error.
    - :meth:`quantize_module` swaps all the linear layers of the passed module
      with QuantizedLinear ones.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool): If set to ``False`` then the layer will not use a bias.
            (default: True).
        groups (int): The group size to use for the quantized weight. See
            :func:`~mlx.core.quantize`. (default: 64)
        width (int): The bit width to use for the quantized weight. See
            :func:`~mlx.core.quantize`. (default: 4)
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        groups: int = 64,
        width: int = 4,
    ):
        super().__init__()

        # Quantization config
        self.groups = groups
        self.width = width

        # Initialize the quantized weight
        scale = math.sqrt(1 / input_dims)
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self):
        out_dims, in_dims = self.weight.shape
        in_dims *= 32 // self.width
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"groups={self.groups}, width={self.width}"
        )

    def __call__(self, x):
        x = mx.quantized_matmul(
            x,
            self.weight.T,
            scales=self.scales,
            biases=self.biases,
            groups=self.groups,
            width=self.width,
        )
        if "bias" in self:
            x = x + self.bias
        return x

    @classmethod
    def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
        """Create a QuantizedLinear layer from the parameters of a provided
        linear layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, groups, width)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, groups, width
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias

        return ql

    @classmethod
    def quantize_module(
        cls,
        model: Module,
        groups: int = 64,
        width: int = 4,
        linear_class_predicate=lambda m: isinstance(m, Linear),
    ):
        def _quantize_if_linear(m):
            if linear_class_predicate(m):
                return cls.from_linear(m, groups, width)
            else:
                return m

        leaves = model.leaf_modules()
        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
        model.update_modules(leaves)
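Aside: converting an existing model with the new classmethod might look like the sketch below (TinyMLP and its dimensions are hypothetical, chosen to satisfy the divisibility constraints of mx.quantize):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.quantized import QuantizedLinear

class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 64)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

model = TinyMLP()
# Replace every nn.Linear leaf with a QuantizedLinear built from its weights.
QuantizedLinear.quantize_module(model, groups=64, width=4)

x = mx.random.normal(shape=(8, 256))
y = model(x)  # the forward pass now goes through mx.quantized_matmul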
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.


-def tree_map(fn, tree, *rest):
+def tree_map(fn, tree, *rest, is_leaf=None):
     """Applies ``fn`` to the leaves of the python tree ``tree`` and
     returns a new collection with the results.

@@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest):
     ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
     than to :func:`map`.

+    The keyword argument ``is_leaf`` decides what constitutes a leaf from
+    ``tree`` similar to :func:`tree_flatten`.
+
     .. code-block:: python

         import mlx.nn as nn
@@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest):
         fn (Callable): The function that processes the leaves of the tree
         tree (Any): The main python tree that will be iterated upon
         rest (Tuple[Any]): Extra trees to be iterated together with tree
+        is_leaf (Optional[Callable]): An optional callable that returns True if
+            the passed object is considered a leaf or False otherwise.

     Returns:
         A python tree with the new values returned by ``fn``.
     """
-    if isinstance(tree, list):
+    if is_leaf is not None and is_leaf(tree):
+        return fn(tree, *rest)
+    elif isinstance(tree, list):
         return [
-            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+            for i, child in enumerate(tree)
         ]
     elif isinstance(tree, tuple):
         return tuple(
-            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+            for i, child in enumerate(tree)
         )
     elif isinstance(tree, dict):
         return {
-            k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
+            k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
+            for k, child in tree.items()
         }
     else:
         return fn(tree, *rest)
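Aside: a tiny illustration of the new is_leaf hook, which quantize_module uses with Module.is_module to stop recursion at module boundaries (the tree values here are made up):

from mlx.utils import tree_map

tree = {"a": [1, 2], "b": [3, 4, 5]}

# Default behaviour: recurse all the way down and map the scalars.
print(tree_map(lambda x: 10 * x, tree))   # {'a': [10, 20], 'b': [30, 40, 50]}

# With is_leaf, whole sub-lists are handed to fn as single leaves.
print(tree_map(sum, tree, is_leaf=lambda t: isinstance(t, list)))  # {'a': 3, 'b': 12}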
@@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) {
      Returns:
        result (array): The result of the multiplication of ``x`` with ``w``.
  )pbdoc");
  m.def(
      "quantize",
      &quantize,
      "w"_a,
      py::pos_only(),
      "groups"_a = 128,
      "width"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]

        Quantize the matrix ``w`` using ``width`` bits per element.

        Note, every ``groups`` elements in a row of ``w`` are quantized
        together. Hence, the number of columns of ``w`` should be divisible by
        ``groups``. In particular, the rows of ``w`` are divided into groups of
        size ``groups`` which are quantized together.

        .. warning::

          ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32

        Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
        :math:`w_g` in a row of ``w`` we compute the quantized representation
        of each element :math:`\hat{w_i}` as follows

        .. math::

          \begin{aligned}
            \alpha &= \max_i w_i \\
            \beta &= \min_i w_i \\
            s &= \frac{\alpha - \beta}{2^b - 1} \\
            \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
          \end{aligned}

        After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
        and is packed in an unsigned 32-bit integer from the lower to upper
        bits. For instance, for 4-bit quantization we fit 8 elements in an
        unsigned 32-bit integer where the 1st element occupies the 4 least
        significant bits, the 2nd bits 4-7 etc.

        In order to be able to dequantize the elements of ``w`` we also need to
        save :math:`s` and :math:`\beta` which are the returned ``scales`` and
        ``biases`` respectively.

        Args:
          w (array): Matrix to be quantized
          groups (int, optional): The size of the group in ``w`` that shares a
            scale and bias. (default: 128)
          width (int, optional): The bitwidth of the elements in ``w``.
            (default: 4)

        Returns:
          (tuple): A tuple containing

          - w_q (array): The quantized version of ``w``
          - scales (array): The scale to multiply each element with, namely :math:`s`
          - biases (array): The biases to add to each element, namely :math:`\beta`
  )pbdoc");
  m.def(
      "dequantize",
      &dequantize,
      "w"_a,
      py::pos_only(),
      "scales"_a,
      "biases"_a,
      "groups"_a = 128,
      "width"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array

        Dequantize the matrix ``w`` using the provided ``scales`` and
        ``biases`` and the ``groups`` and ``width`` configuration.

        Formally, given the notation in :func:`quantize`, we compute
        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
        :math:`\beta` as follows

        .. math::

          w_i = s \hat{w_i} + \beta

        Args:
          w (array): Matrix to be dequantized
          scales (array): The scales to use per ``groups`` elements of ``w``
          biases (array): The biases to use per ``groups`` elements of ``w``
          groups (int, optional): The size of the group in ``w`` that shares a
            scale and bias. (default: 128)
          width (int, optional): The bitwidth of the elements in ``w``.
            (default: 4)

        Returns:
          result (array): The dequantized version of ``w``
  )pbdoc");
}
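Aside: the quantization formula above, worked out for one tiny group with b = 2 bits (all numbers are made up):

w = [0.1, -0.5, 0.3, 0.7]                    # one group of g = 4 weights
alpha, beta = max(w), min(w)                 # 0.7, -0.5
s = (alpha - beta) / (2**2 - 1)              # (0.7 + 0.5) / 3 = 0.4
w_hat = [round((wi - beta) / s) for wi in w]
print(w_hat)                                 # [2, 0, 2, 3], each fits in 2 bits
print([s * q + beta for q in w_hat])         # ~[0.3, -0.5, 0.3, 0.7]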
@@ -6,48 +6,14 @@ import mlx.core as mx
 import mlx_tests


-def select_bits(w, width, start):
-    shift_left = 32 - (start + width)
-    shift_right = shift_left + start
-    return (w * (2**shift_left)) // (2**shift_right)
-
-
-def dequantize(w, scales, biases, width):
-    w_full = mx.concatenate(
-        [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
-    )
-    w_full = w_full.reshape(len(w), scales.shape[-1], -1)
-    w_full = scales[..., None] * w_full + biases[..., None]
-    w_full = w_full.reshape(len(w), -1)
-
-    return w_full
-
-
-def quantize(w, width, groups):
-    w = w.reshape(len(w), -1, groups)
-    w_max = w.max(-1, keepdims=True)
-    w_min = w.min(-1, keepdims=True)
-    delta = (w_max - w_min) / (2**width - 1)
-
-    w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
-    scales = delta.squeeze(-1)
-    biases = w_min.squeeze(-1)
-
-    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
-    w_int = w_int.reshape(len(w), -1, 32 // width)
-    w_int = w_int * shifts[None, None]
-    packed_w = w_int.sum(-1)
-
-    return packed_w, scales, biases
-
-
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 128))
-        w_q, scales, biases = quantize(w, 4, 64)
-        w_hat = dequantize(w_q, scales, biases, 4)
-        w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
-        self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)
+        for b in [2, 4, 8]:
+            w_q, scales, biases = mx.quantize(w, 64, b)
+            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
+            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+            self.assertTrue((errors <= scales[..., None] / 2).all())

     def test_qmm(self):
         key = mx.random.key(0)
@@ -62,14 +28,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
             ):
                 x = mx.random.normal(shape=(M, K), key=k1)
                 w = mx.random.normal(shape=(N, K), key=k2)
-                w_q, scales, biases = quantize(w, width, groups)
-                w_hat = dequantize(w_q, scales, biases, width)
+                w_q, scales, biases = mx.quantize(w, groups, width)
+                w_hat = mx.dequantize(
+                    w_q, scales, biases, groups, width
+                )
                 y_q = mx.quantized_matmul(
                     x, w_q.T, scales, biases, width=width, groups=groups
                 )
                 y_hat = x @ w_hat.T
                 self.assertEqual(y_q.shape, y_hat.shape)
-                self.assertLess((y_q - y_hat).abs().max(), 0.1)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_qmm_shapes(self):
         key = mx.random.key(0)
@@ -77,8 +45,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         groups = 64
         width = 4
         w = mx.random.normal(shape=(32, 128), key=k2)
-        w_q, scales, biases = quantize(w, width, groups)
-        w_hat = dequantize(w_q, scales, biases, width)
+        w_q, scales, biases = mx.quantize(w, groups, width)
+        w_hat = mx.dequantize(w_q, scales, biases, groups, width)
         for s in [(3, 128), (2, 1, 7, 128)]:
             x = mx.random.normal(shape=(3, 128), key=k1)
             y_q = mx.quantized_matmul(
@@ -86,7 +54,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
             )
             y_hat = x @ w_hat.T
             self.assertEqual(y_q.shape, y_hat.shape)
-            self.assertLess((y_q - y_hat).abs().max(), 0.1)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_qmv(self):
         key = mx.random.key(0)
@@ -95,17 +63,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
         for width in [2, 4, 8]:
             for M in [512, 1024]:
                 for N in [512, 1024]:
-                    # with self.subTest(shape=(M, N), groups=groups, width=width):
-                    x = mx.random.normal(shape=(1, N), key=k1)
-                    w = mx.random.normal(shape=(M, N), key=k2)
-                    w_q, scales, biases = quantize(w, width, groups)
-                    w_hat = dequantize(w_q, scales, biases, width)
-                    y_q = mx.quantized_matmul(
-                        x, w_q.T, scales, biases, width=width, groups=groups
-                    )
-                    y_hat = x @ w_hat.T
-                    self.assertEqual(y_q.shape, y_hat.shape)
-                    self.assertLess((y_q - y_hat).abs().max(), 0.1)
+                    with self.subTest(shape=(M, N), groups=groups, width=width):
+                        x = mx.random.normal(shape=(1, N), key=k1)
+                        w = mx.random.normal(shape=(M, N), key=k2)
+                        w_q, scales, biases = mx.quantize(w, groups, width)
+                        w_hat = mx.dequantize(w_q, scales, biases, groups, width)
+                        y_q = mx.quantized_matmul(
+                            x, w_q.T, scales, biases, width=width, groups=groups
+                        )
+                        y_hat = x @ w_hat.T
+                        self.assertEqual(y_q.shape, y_hat.shape)
+                        self.assertLess((y_q - y_hat).abs().max(), 1e-3)


 if __name__ == "__main__":
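Aside: the bound asserted in test_quantize_dequantize above holds because round() is off by at most 0.5, so after scaling back the per-element error is at most s / 2. A quick scalar check with made-up values:

beta, s = -1.3, 0.21                         # made-up bias and scale for one group
for w in (-1.3, -0.7, 0.05, 1.0):
    w_hat = s * round((w - beta) / s) + beta
    assert abs(w - w_hat) <= s / 2 + 1e-12   # off by at most half a step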
@ -2215,3 +2215,21 @@ TEST_CASE("test linspace") {
|
||||
expected = array(std::initializer_list<float>{}, {0});
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test quantize dequantize") {
|
||||
auto x1 = ones({128, 1});
|
||||
auto x2 = expand_dims(arange(0, 128, float32), 0);
|
||||
auto x = x1 * x2;
|
||||
|
||||
for (int i = 2; i <= 8; i *= 2) {
|
||||
int el_per_int = 32 / i;
|
||||
auto [x_q, scales, biases] = quantize(x, 128, i);
|
||||
CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
|
||||
CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
|
||||
CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
|
||||
|
||||
auto x_hat = dequantize(x_q, scales, biases, 128, i);
|
||||
auto max_diff = max(abs(x - x_hat)).item<float>();
|
||||
CHECK(max_diff <= 127.0 / (1 << i));
|
||||
}
|
||||
}
|
||||
|