Improve names of quantization arguments (#235)

* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
This commit is contained in:
Angelos Katharopoulos
2023-12-20 16:53:53 -08:00
committed by GitHub
parent 57fe918cf8
commit b3916cbf2b
11 changed files with 184 additions and 180 deletions

View File

@@ -26,12 +26,12 @@ class QuantizedLinear(Module):
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool): If set to ``False`` then the layer will not use a bias.
(default: True).
groups (int): The group size to use for the quantized weight. See
:func:`~mlx.core.quantize`. (default: 128)
width (int): The bit width to use for the quantized weight. See
:func:`~mlx.core.quantize`. (default: 4)
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. (default: True).
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. (default: 64)
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. (default: 4)
"""
def __init__(
@@ -39,14 +39,14 @@ class QuantizedLinear(Module):
input_dims: int,
output_dims: int,
bias: bool = True,
groups: int = 64,
width: int = 4,
group_size: int = 64,
bits: int = 4,
):
super().__init__()
# Quantization config
self.groups = groups
self.width = width
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@@ -55,7 +55,7 @@ class QuantizedLinear(Module):
high=scale,
shape=(output_dims, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
@@ -72,10 +72,10 @@ class QuantizedLinear(Module):
def _extra_repr(self):
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.width
in_dims *= 32 // self.bits
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
f"groups={self.groups}, width={self.width}"
f"group_size={self.group_size}, bits={self.bits}"
)
def __call__(self, x):
@@ -84,21 +84,21 @@ class QuantizedLinear(Module):
self.weight.T,
scales=self.scales,
biases=self.biases,
groups=self.groups,
width=self.width,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + self.bias
return x
@classmethod
def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
"""Create a QuantizedLinear layer from the parameters of a provided
linear layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, groups, width)
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, groups, width
linear_layer.weight, group_size, bits
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
@@ -109,13 +109,13 @@ class QuantizedLinear(Module):
def quantize_module(
cls,
model: Module,
groups: int = 64,
width: int = 4,
group_size: int = 64,
bits: int = 4,
linear_class_predicate=lambda m: isinstance(m, Linear),
):
def _quantize_if_linear(m):
if linear_class_predicate(m):
return cls.from_linear(m, groups, width)
return cls.from_linear(m, group_size, bits)
else:
return m

View File

@@ -3011,26 +3011,27 @@ void init_ops(py::module_& m) {
py::pos_only(),
"scales"_a,
"biases"_a,
"groups"_a = 128,
"width"_a = 4,
"group_size"_a = 64,
"bits"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``groups`` of
elements. Each element in ``w`` takes ``width`` bits and is packed in an
quantization uses one floating point scale and bias per ``group_size`` of
elements. Each element in ``w`` takes ``bits`` bits and is packed in an
unsigned 32 bit integer.
Args:
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``groups`` elements of ``w``
biases (array): The biases to use per ``groups`` elements of ``w``
groups (int): The size of the group in ``w`` that shares a scale and
bias. (default: 128)
width (int): The bitwidth of the elements in ``w``. (default: 4)
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: 64)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: 4)
Returns:
result (array): The result of the multiplication of ``x`` with ``w``.
@@ -3040,19 +3041,19 @@ void init_ops(py::module_& m) {
&quantize,
"w"_a,
py::pos_only(),
"groups"_a = 128,
"width"_a = 4,
"group_size"_a = 64,
"bits"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
Quantize the matrix ``w`` using ``width`` bits per element.
Quantize the matrix ``w`` using ``bits`` bits per element.
Note, every ``groups`` elements in a row of ``w`` are quantized
Note, every ``group_size`` elements in a row of ``w`` are quantized
together. Hence, number of columns of ``w`` should be divisible by
``groups``. In particular, the rows of ``w`` are divided into groups of
size ``groups`` which are quantized together.
``group_size``. In particular, the rows of ``w`` are divided into groups of
size ``group_size`` which are quantized together.
.. warning::
@@ -3083,10 +3084,10 @@ void init_ops(py::module_& m) {
Args:
w (array): Matrix to be quantized
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 64)
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: 4)
Returns:
(tuple): A tuple containing
@@ -3102,15 +3103,15 @@ void init_ops(py::module_& m) {
py::pos_only(),
"scales"_a,
"biases"_a,
"groups"_a = 128,
"width"_a = 4,
"group_size"_a = 64,
"bits"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``groups`` and ``width`` configuration.
``biases`` and the ``group_size`` and ``bits`` configuration.
Formally, given the notation in :func:`quantize`, we compute
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
@@ -3122,14 +3123,14 @@ void init_ops(py::module_& m) {
Args:
w (array): Matrix to be quantized
scales (array): The scales to use per ``groups`` elements of ``w``
biases (array): The biases to use per ``groups`` elements of ``w``
groups (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 128)
width (int, optional): The bitwidth of the elements in ``w``.
(default: 4)
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 64)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: 4)
Returns:
result (array): The dequantized version of w
result (array): The dequantized version of ``w``
)pbdoc");
}

View File

@@ -18,22 +18,22 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for groups in [128, 64]:
for width in [2, 4, 8]:
for group_size in [128, 64]:
for bits in [2, 4, 8]:
for M in [8, 32, 33, 64]:
for N in [512, 1024]:
for K in [512, 1024]:
with self.subTest(
shape=(M, N, K), groups=groups, width=width
shape=(M, N, K), group_size=group_size, bits=bits
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K), key=k2)
w_q, scales, biases = mx.quantize(w, groups, width)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(
w_q, scales, biases, groups, width
w_q, scales, biases, group_size, bits
)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
x, w_q.T, scales, biases, group_size, bits
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
@@ -42,16 +42,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm_shapes(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
groups = 64
width = 4
group_size = 64
bits = 4
w = mx.random.normal(shape=(32, 128), key=k2)
w_q, scales, biases = mx.quantize(w, groups, width)
w_hat = mx.dequantize(w_q, scales, biases, groups, width)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
for s in [(3, 128), (2, 1, 7, 128)]:
x = mx.random.normal(shape=(3, 128), key=k1)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
)
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
@@ -59,17 +57,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmv(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for groups in [128, 64]:
for width in [2, 4, 8]:
for group_size in [128, 64]:
for bits in [2, 4, 8]:
for M in [512, 1024]:
for N in [512, 1024]:
with self.subTest(shape=(M, N), groups=groups, width=width):
with self.subTest(
shape=(M, N), group_size=group_size, bits=bits
):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = mx.quantize(w, groups, width)
w_hat = mx.dequantize(w_q, scales, biases, groups, width)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
x, w_q.T, scales, biases, group_size, bits
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)