mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 10:18:10 +08:00
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module
This commit is contained in:

committed by
GitHub

parent
4912ff3ec2
commit
57fe918cf8
@@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The result of the multiplication of ``x`` with ``w``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"quantize",
|
||||
&quantize,
|
||||
"w"_a,
|
||||
py::pos_only(),
|
||||
"groups"_a = 128,
|
||||
"width"_a = 4,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
|
||||
|
||||
Quantize the matrix ``w`` using ``width`` bits per element.
|
||||
|
||||
Note, every ``groups`` elements in a row of ``w`` are quantized
|
||||
together. Hence, number of columns of ``w`` should be divisible by
|
||||
``groups``. In particular, the rows of ``w`` are divided into groups of
|
||||
size ``groups`` which are quantized together.
|
||||
|
||||
.. warning::
|
||||
|
||||
``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
|
||||
|
||||
Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
|
||||
:math:`w_g` in a row of ``w`` we compute the quantized representation
|
||||
of each element :math:`\hat{w_i}` as follows
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{aligned}
|
||||
\alpha &= \max_i w_i \\
|
||||
\beta &= \min_i w_i \\
|
||||
s &= \frac{\alpha - \beta}{2^b - 1} \\
|
||||
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
|
||||
\end{aligned}
|
||||
|
||||
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
|
||||
and is packed in an unsigned 32-bit integer from the lower to upper
|
||||
bits. For instance, for 4-bit quantization we fit 8 elements in an
|
||||
unsigned 32 bit integer where the 1st element occupies the 4 least
|
||||
significant bits, the 2nd bits 4-7 etc.
|
||||
|
||||
In order to be able to dequantize the elements of ``w`` we also need to
|
||||
save :math:`s` and :math:`\beta` which are the returned ``scales`` and
|
||||
``biases`` respectively.
|
||||
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
groups (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: 128)
|
||||
width (int, optional): The bitwidth of the elements in ``w``.
|
||||
(default: 4)
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing
|
||||
|
||||
- w_q (array): The quantized version of ``w``
|
||||
- scales (array): The scale to multiply each element with, namely :math:`s`
|
||||
- biases (array): The biases to add to each element, namely :math:`\beta`
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"dequantize",
|
||||
&dequantize,
|
||||
"w"_a,
|
||||
py::pos_only(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"groups"_a = 128,
|
||||
"width"_a = 4,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Dequantize the matrix ``w`` using the provided ``scales`` and
|
||||
``biases`` and the ``groups`` and ``width`` configuration.
|
||||
|
||||
Formally, given the notation in :func:`quantize`, we compute
|
||||
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
|
||||
:math:`\beta` as follows
|
||||
|
||||
.. math::
|
||||
|
||||
w_i = s \hat{w_i} - \beta
|
||||
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
scales (array): The scales to use per ``groups`` elements of ``w``
|
||||
biases (array): The biases to use per ``groups`` elements of ``w``
|
||||
groups (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: 128)
|
||||
width (int, optional): The bitwidth of the elements in ``w``.
|
||||
(default: 4)
|
||||
|
||||
Returns:
|
||||
result (array): The dequantized version of w
|
||||
)pbdoc");
|
||||
}
|
||||
|
Reference in New Issue
Block a user