Adds C++ and nn quantization utilities (#230)

* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
This commit is contained in:
Angelos Katharopoulos
2023-12-20 14:17:38 -08:00
committed by GitHub
parent 4912ff3ec2
commit 57fe918cf8
12 changed files with 451 additions and 68 deletions

View File

@@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) {
Returns:
result (array): The result of the multiplication of ``x`` with ``w``.
)pbdoc");
m.def(
"quantize",
&quantize,
"w"_a,
py::pos_only(),
"groups"_a = 128,
"width"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
Quantize the matrix ``w`` using ``width`` bits per element.
Note, every ``groups`` elements in a row of ``w`` are quantized
together. Hence, number of columns of ``w`` should be divisible by
``groups``. In particular, the rows of ``w`` are divided into groups of
size ``groups`` which are quantized together.
.. warning::
``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
:math:`w_g` in a row of ``w`` we compute the quantized representation
of each element :math:`\hat{w_i}` as follows
.. math::
\begin{aligned}
\alpha &= \max_i w_i \\
\beta &= \min_i w_i \\
s &= \frac{\alpha - \beta}{2^b - 1} \\
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
\end{aligned}
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
and is packed in an unsigned 32-bit integer from the lower to upper
bits. For instance, for 4-bit quantization we fit 8 elements in an
unsigned 32 bit integer where the 1st element occupies the 4 least
significant bits, the 2nd bits 4-7 etc.
In order to be able to dequantize the elements of ``w`` we also need to
save :math:`s` and :math:`\beta` which are the returned ``scales`` and
``biases`` respectively.
Args:
w (array): Matrix to be quantized
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
Returns:
(tuple): A tuple containing
- w_q (array): The quantized version of ``w``
- scales (array): The scale to multiply each element with, namely :math:`s`
- biases (array): The biases to add to each element, namely :math:`\beta`
)pbdoc");
m.def(
"dequantize",
&dequantize,
"w"_a,
py::pos_only(),
"scales"_a,
"biases"_a,
"groups"_a = 128,
"width"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``groups`` and ``width`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
:math:`\beta` as follows
.. math::
w_i = s \hat{w_i} - \beta
Args:
w (array): Matrix to be quantized
scales (array): The scales to use per ``groups`` elements of ``w``
biases (array): The biases to use per ``groups`` elements of ``w``
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
Returns:
result (array): The dequantized version of w
)pbdoc");
}