Adds C++ and nn quantization utilities (#230)

* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
Angelos Katharopoulos 2023-12-20 14:17:38 -08:00 committed by GitHub
parent 4912ff3ec2
commit 57fe918cf8
12 changed files with 451 additions and 68 deletions


@ -26,3 +26,4 @@ Layers
RoPE
MultiHeadAttention
Sequential
QuantizedLinear


@ -34,6 +34,7 @@ Operations
conv2d
cos
cosh
dequantize
divide
equal
erf
@ -73,6 +74,8 @@ Operations
partition
pad
prod
quantize
quantized_matmul
reciprocal
reshape
round


@ -2649,4 +2649,111 @@ array quantized_matmul(
return out;
}
std::tuple<array, array, array> quantize(
const array& w,
int groups /* = 128 */,
int width /* = 4 */,
StreamOrDevice s /* = {} */) {
if (w.ndim() != 2) {
throw std::invalid_argument("[quantize] Only matrices supported for now");
}
if ((w.shape(0) % 32) != 0) {
throw std::invalid_argument(
"[quantize] All dimensions should be divisible by 32 for now");
}
if ((w.shape(-1) % groups) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << groups
<< ". However the provided matrix"
<< " has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
// Compute some constants used for the quantization
int n_bins = (1 << width) - 1; // 2**width - 1
int el_per_int = 32 / width;
array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
// Compute scales and biases
array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
array scales = squeeze(delta, -1, s);
array biases = squeeze(w_min, -1, s);
// Quantize and pack w
packed_w =
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple(packed_w, scales, biases);
}
array dequantize(
const array& w,
const array& scales,
const array& biases,
int groups /* = 128 */,
int width /* = 4 */,
StreamOrDevice s /* = {} */) {
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
throw std::invalid_argument("[dequantize] Only matrices supported for now");
}
if ((w.shape(0) % 32) != 0) {
throw std::invalid_argument(
"[dequantize] All dimensions should be divisible by 32 for now");
}
if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Compute some constants for the dequantization
int el_per_int = 32 / width;
if (w.shape(1) * el_per_int != scales.shape(1) * groups) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with groups=" << groups << " and width=" << width << ".";
throw std::invalid_argument(msg.str());
}
// Extract the pieces from the passed quantized matrix
std::vector<array> parts;
for (int start = 0; start < 32; start += width) {
// TODO: Implement bitwise operators for integral types
int shift_left = 32 - (start + width);
int shift_right = shift_left + start;
array p = multiply(w, array(1 << shift_left, uint32), s);
p = floor_divide(p, array(1 << shift_right, uint32), s);
p = expand_dims(p, -1, s);
parts.push_back(p);
}
array w_full = concatenate(parts, -1, s);
// Dequantize
w_full = reshape(w_full, {w.shape(0), -1, groups}, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, {w.shape(0), -1}, s);
return w_full;
}
} // namespace mlx::core
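As an aside (not part of the commit), here is a minimal plain-Python sketch of the bit-packing that quantize() and dequantize() above implement via multiplies and floor divisions (see the TODO about bitwise operators); the values are made up:

width = 4                                   # bits per element
vals = [3, 0, 15, 7, 1, 9, 12, 2]           # already-quantized 4-bit values

# Pack: element i lands in bits [i * width, (i + 1) * width) of one uint32.
packed = sum(v << (i * width) for i, v in enumerate(vals))

# Unpack with the same multiply / floor-divide trick used in dequantize(),
# emulating uint32 overflow with a modulo.
unpacked = []
for start in range(0, 32, width):
    shift_left = 32 - (start + width)
    shift_right = shift_left + start
    unpacked.append(packed * (1 << shift_left) % (1 << 32) // (1 << shift_right))

assert unpacked == vals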


@ -1041,4 +1041,20 @@ array quantized_matmul(
int width = 4,
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
const array& w,
int groups = 128,
int width = 4,
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
array dequantize(
const array& w,
const array& scales,
const array& biases,
int groups = 128,
int width = 4,
StreamOrDevice s = {});
} // namespace mlx::core


@ -38,6 +38,7 @@ from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
MultiHeadAttention,
TransformerEncoder,


@ -258,6 +258,44 @@ class Module(dict):
filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn))
def update_modules(self, modules: dict):
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
It is the equivalent of :meth:`Module.update` but for modules instead
of parameters and allows us to flexibly edit complex architectures by
programmatically swapping layers.
The passed-in modules dictionary need not be complete. Only the
provided locations will be updated.
Args:
modules (dict): A complete or partial dictionary of the module's
submodules.
"""
def apply(dst, modules):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if self.is_module(current_value) and self.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif isinstance(modules, list):
for i in range(len(dst)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
apply(self, modules)
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
"""Apply a function to all the modules in this instance (including this
instance).
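As an aside (not part of the diff), a minimal sketch of how update_modules() enables programmatic layer swapping, mirroring the leaf_modules()/tree_map() pattern that QuantizedLinear.quantize_module uses later in this commit; nn.Sequential and nn.Linear appear in the diff, while nn.ReLU and nn.GELU are assumed to be available activations:

import mlx.nn as nn
from mlx.nn.layers.base import Module
from mlx.utils import tree_map

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Swap every ReLU for a GELU without touching the rest of the architecture.
leaves = model.leaf_modules()
leaves = tree_map(
    lambda m: nn.GELU() if isinstance(m, nn.ReLU) else m,
    leaves,
    is_leaf=Module.is_module,
)
model.update_modules(leaves)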


@ -0,0 +1,124 @@
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map
class QuantizedLinear(Module):
"""Applies an affine transformation to the input using a quantized weight matrix.
It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
parameters are frozen and will not be included in any gradient computation,
but this will probably change in the future.
QuantizedLinear also provides two useful classmethods to convert linear
layers to QuantizedLinear layers.
- :meth:`from_linear` returns a QuantizedLinear layer that applies the same
linear transformation up to the quantization error.
- :meth:`quantize_module` swaps all the linear layers of the passed module
with QuantizedLinear ones.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool): If set to ``False`` then the layer will not use a bias.
(default: True).
groups (int): The group size to use for the quantized weight. See
:func:`~mlx.core.quantize`. (default: 64)
width (int): The bit width to use for the quantized weight. See
:func:`~mlx.core.quantize`. (default: 4)
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
groups: int = 64,
width: int = 4,
):
super().__init__()
# Quantization config
self.groups = groups
self.width = width
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims,))
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
def _extra_repr(self):
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.width
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
f"groups={self.groups}, width={self.width}"
)
def __call__(self, x):
x = mx.quantized_matmul(
x,
self.weight.T,
scales=self.scales,
biases=self.biases,
groups=self.groups,
width=self.width,
)
if "bias" in self:
x = x + self.bias
return x
@classmethod
def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
"""Create a QuantizedLinear layer from the parameters of a provided
linear layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, groups, width)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, groups, width
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
return ql
@classmethod
def quantize_module(
cls,
model: Module,
groups: int = 64,
width: int = 4,
linear_class_predicate=lambda m: isinstance(m, Linear),
):
def _quantize_if_linear(m):
if linear_class_predicate(m):
return cls.from_linear(m, groups, width)
else:
return m
leaves = model.leaf_modules()
leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
model.update_modules(leaves)
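As an aside (not part of the diff), a minimal usage sketch; the model below is made up and its dimensions are chosen to satisfy the divisible-by-32 and group-size constraints of mx.quantize:

import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))

# Replace every nn.Linear in place with a 4-bit QuantizedLinear (groups of 64).
nn.QuantizedLinear.quantize_module(model, groups=64, width=4)

x = mx.random.normal(shape=(8, 128))
y = model(x)  # the forward pass now goes through mx.quantized_matmul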


@ -445,13 +445,13 @@ class Adamax(Adam):
class Lion(Optimizer):
r"""Implementation of the Lion optimizer [1].
Since updates are computed through the sign operation, they tend to
have larger norm than for other optimizers such as SGD and Adam.
We recommend a learning rate that is 3-10x smaller than AdamW and a
weight decay 3-10x larger than AdamW to maintain the strength
(lr * wd). Our Lion implementation follows the original paper. In
detail,
[1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
@ -486,7 +486,7 @@ class Lion(Optimizer):
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the Lion parameter update and stores :math:`m`
in the optimizer state."""
lr = self.learning_rate
b1, b2 = self.betas


@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
-def tree_map(fn, tree, *rest):
+def tree_map(fn, tree, *rest, is_leaf=None):
    """Applies ``fn`` to the leaves of the python tree ``tree`` and
    returns a new collection with the results.
@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest):
    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
    than to :func:`map`.
+   The keyword argument ``is_leaf`` decides what constitutes a leaf from
+   ``tree`` similar to :func:`tree_flatten`.
    .. code-block:: python
        import mlx.nn as nn
@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest):
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree
+       is_leaf (Optional[Callable]): An optional callable that returns True if
+           the passed object is considered a leaf or False otherwise.
    Returns:
        A python tree with the new values returned by ``fn``.
    """
-   if isinstance(tree, list):
+   if is_leaf is not None and is_leaf(tree):
+       return fn(tree, *rest)
+   elif isinstance(tree, list):
        return [
-           tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+           tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+           for i, child in enumerate(tree)
        ]
    elif isinstance(tree, tuple):
        return tuple(
-           tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
+           tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+           for i, child in enumerate(tree)
        )
    elif isinstance(tree, dict):
        return {
-           k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
+           k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
+           for k, child in tree.items()
        }
    else:
        return fn(tree, *rest)
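As an aside (not part of the diff), a minimal sketch of what the new is_leaf argument changes; the toy tree below is made up:

from mlx.utils import tree_map

tree = {"a": [1, 2], "b": [3, 4]}

# Without is_leaf the function sees every integer ...
doubled = tree_map(lambda x: x * 2, tree)
# ... with is_leaf recursion stops at the lists and the function sees them whole.
summed = tree_map(sum, tree, is_leaf=lambda x: isinstance(x, list))

print(doubled)  # {'a': [2, 4], 'b': [6, 8]}
print(summed)   # {'a': 3, 'b': 7}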


@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) {
Returns:
result (array): The result of the multiplication of ``x`` with ``w``.
)pbdoc");
m.def(
"quantize",
&quantize,
"w"_a,
py::pos_only(),
"groups"_a = 128,
"width"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
Quantize the matrix ``w`` using ``width`` bits per element.
Note, every ``groups`` consecutive elements in a row of ``w`` are
quantized together. Hence, the number of columns of ``w`` should be
divisible by ``groups``.
.. warning::
``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
:math:`w_g` in a row of ``w`` we compute the quantized representation
of each element :math:`\hat{w_i}` as follows
.. math::
\begin{aligned}
\alpha &= \max_i w_i \\
\beta &= \min_i w_i \\
s &= \frac{\alpha - \beta}{2^b - 1} \\
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
\end{aligned}
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
and is packed in an unsigned 32-bit integer from the lower to upper
bits. For instance, for 4-bit quantization we fit 8 elements in an
unsigned 32-bit integer, where the 1st element occupies the 4 least
significant bits, the 2nd occupies bits 4-7, and so on.
In order to be able to dequantize the elements of ``w`` we also need to
save :math:`s` and :math:`\beta` which are the returned ``scales`` and
``biases`` respectively.
Args:
w (array): Matrix to be quantized
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
Returns:
(tuple): A tuple containing
- w_q (array): The quantized version of ``w``
- scales (array): The scale to multiply each element with, namely :math:`s`
- biases (array): The biases to add to each element, namely :math:`\beta`
)pbdoc");
m.def(
"dequantize",
&dequantize,
"w"_a,
py::pos_only(),
"scales"_a,
"biases"_a,
"groups"_a = 128,
"width"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``groups`` and ``width`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
:math:`\beta` as follows
.. math::
w_i = s \hat{w_i} + \beta
Args:
w (array): Matrix to be dequantized
scales (array): The scales to use per ``groups`` elements of ``w``
biases (array): The biases to use per ``groups`` elements of ``w``
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
Returns:
result (array): The dequantized version of w
)pbdoc");
}
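As an aside (not part of the diff), a worked instance of the formulas above for one group of g = 4 elements at width b = 2; the numbers are made up:

w = [0.0, 0.5, 1.0, 1.5]                      # one group of g = 4 elements
alpha, beta = max(w), min(w)                  # 1.5 and 0.0
s = (alpha - beta) / (2**2 - 1)               # scale s = 0.5 for b = 2 bits
w_hat = [round((x - beta) / s) for x in w]    # quantized values [0, 1, 2, 3]
w_rec = [s * q + beta for q in w_hat]         # dequantized: [0.0, 0.5, 1.0, 1.5]
assert w_rec == w                             # exact here because w lies on the grid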


@ -6,48 +6,14 @@ import mlx.core as mx
import mlx_tests
-def select_bits(w, width, start):
-    shift_left = 32 - (start + width)
-    shift_right = shift_left + start
-    return (w * (2**shift_left)) // (2**shift_right)
-def dequantize(w, scales, biases, width):
-    w_full = mx.concatenate(
-        [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
-    )
-    w_full = w_full.reshape(len(w), scales.shape[-1], -1)
-    w_full = scales[..., None] * w_full + biases[..., None]
-    w_full = w_full.reshape(len(w), -1)
-    return w_full
-def quantize(w, width, groups):
-    w = w.reshape(len(w), -1, groups)
-    w_max = w.max(-1, keepdims=True)
-    w_min = w.min(-1, keepdims=True)
-    delta = (w_max - w_min) / (2**width - 1)
-    w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
-    scales = delta.squeeze(-1)
-    biases = w_min.squeeze(-1)
-    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
-    w_int = w_int.reshape(len(w), -1, 32 // width)
-    w_int = w_int * shifts[None, None]
-    packed_w = w_int.sum(-1)
-    return packed_w, scales, biases
class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
        w = mx.random.normal(shape=(128, 128))
-        w_q, scales, biases = quantize(w, 4, 64)
-        w_hat = dequantize(w_q, scales, biases, 4)
-        w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
-        self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)
+        for b in [2, 4, 8]:
+            w_q, scales, biases = mx.quantize(w, 64, b)
+            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
+            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+            self.assertTrue((errors <= scales[..., None] / 2).all())
    def test_qmm(self):
        key = mx.random.key(0)
@ -62,14 +28,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
        ):
            x = mx.random.normal(shape=(M, K), key=k1)
            w = mx.random.normal(shape=(N, K), key=k2)
-            w_q, scales, biases = quantize(w, width, groups)
-            w_hat = dequantize(w_q, scales, biases, width)
+            w_q, scales, biases = mx.quantize(w, groups, width)
+            w_hat = mx.dequantize(
+                w_q, scales, biases, groups, width
+            )
            y_q = mx.quantized_matmul(
                x, w_q.T, scales, biases, width=width, groups=groups
            )
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
-            self.assertLess((y_q - y_hat).abs().max(), 0.1)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
    def test_qmm_shapes(self):
        key = mx.random.key(0)
@ -77,8 +45,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
        groups = 64
        width = 4
        w = mx.random.normal(shape=(32, 128), key=k2)
-        w_q, scales, biases = quantize(w, width, groups)
-        w_hat = dequantize(w_q, scales, biases, width)
+        w_q, scales, biases = mx.quantize(w, groups, width)
+        w_hat = mx.dequantize(w_q, scales, biases, groups, width)
        for s in [(3, 128), (2, 1, 7, 128)]:
            x = mx.random.normal(shape=(3, 128), key=k1)
            y_q = mx.quantized_matmul(
@ -86,7 +54,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
            )
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
-            self.assertLess((y_q - y_hat).abs().max(), 0.1)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
    def test_qmv(self):
        key = mx.random.key(0)
@ -95,17 +63,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
        for width in [2, 4, 8]:
            for M in [512, 1024]:
                for N in [512, 1024]:
-                    # with self.subTest(shape=(M, N), groups=groups, width=width):
-                    x = mx.random.normal(shape=(1, N), key=k1)
-                    w = mx.random.normal(shape=(M, N), key=k2)
-                    w_q, scales, biases = quantize(w, width, groups)
-                    w_hat = dequantize(w_q, scales, biases, width)
-                    y_q = mx.quantized_matmul(
-                        x, w_q.T, scales, biases, width=width, groups=groups
-                    )
-                    y_hat = x @ w_hat.T
-                    self.assertEqual(y_q.shape, y_hat.shape)
-                    self.assertLess((y_q - y_hat).abs().max(), 0.1)
+                    with self.subTest(shape=(M, N), groups=groups, width=width):
+                        x = mx.random.normal(shape=(1, N), key=k1)
+                        w = mx.random.normal(shape=(M, N), key=k2)
+                        w_q, scales, biases = mx.quantize(w, groups, width)
+                        w_hat = mx.dequantize(w_q, scales, biases, groups, width)
+                        y_q = mx.quantized_matmul(
+                            x, w_q.T, scales, biases, width=width, groups=groups
+                        )
+                        y_hat = x @ w_hat.T
+                        self.assertEqual(y_q.shape, y_hat.shape)
+                        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
if __name__ == "__main__":


@ -2215,3 +2215,21 @@ TEST_CASE("test linspace") {
expected = array(std::initializer_list<float>{}, {0}); expected = array(std::initializer_list<float>{}, {0});
CHECK(array_equal(x, expected).item<bool>()); CHECK(array_equal(x, expected).item<bool>());
} }
TEST_CASE("test quantize dequantize") {
auto x1 = ones({128, 1});
auto x2 = expand_dims(arange(0, 128, float32), 0);
auto x = x1 * x2;
for (int i = 2; i <= 8; i *= 2) {
int el_per_int = 32 / i;
auto [x_q, scales, biases] = quantize(x, 128, i);
CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
auto x_hat = dequantize(x_q, scales, biases, 128, i);
auto max_diff = max(abs(x - x_hat)).item<float>();
CHECK(max_diff <= 127.0 / (1 << i));
}
}
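As an aside (not part of the diff), a quick back-of-the-envelope check of the bound asserted above: each row of x is 0..127 and is quantized as a single group of 128 elements, so the bin size is 127 / (2**i - 1) and the worst-case rounding error is half a bin, which stays below 127 / 2**i for every tested width (assuming exact arithmetic):

for i in (2, 4, 8):
    half_bin = 127.0 / (2**i - 1) / 2
    assert half_bin <= 127.0 / (1 << i)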