Improve names of quantization arguments (#235)

* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
Angelos Katharopoulos authored 2023-12-20 16:53:53 -08:00, committed by GitHub
parent 57fe918cf8
commit b3916cbf2b

11 changed files with 184 additions and 180 deletions
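
The rename changes the public keyword arguments, so existing call sites need updating. A minimal before/after sketch of an affected constructor call (the 512x512 layer size is illustrative, not taken from the diff):

    import mlx.nn as nn

    # Before this commit:
    #   layer = nn.QuantizedLinear(512, 512, groups=64, width=4)
    # After it, the same layer is constructed as:
    layer = nn.QuantizedLinear(512, 512, group_size=64, bits=4)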

@@ -26,12 +26,12 @@ class QuantizedLinear(Module):
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
-        bias (bool): If set to ``False`` then the layer will not use a bias.
-            (default: True).
-        groups (int): The group size to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 128)
-        width (int): The bit width to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 4)
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. (default: True).
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. (default: 64)
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. (default: 4)
     """

     def __init__(
@@ -39,14 +39,14 @@ class QuantizedLinear(Module):
         input_dims: int,
         output_dims: int,
         bias: bool = True,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
     ):
         super().__init__()

         # Quantization config
-        self.groups = groups
-        self.width = width
+        self.group_size = group_size
+        self.bits = bits

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -55,7 +55,7 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

         # And bias if needed
         if bias:
@@ -72,10 +72,10 @@ class QuantizedLinear(Module):

     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
-        in_dims *= 32 // self.width
+        in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
-            f"groups={self.groups}, width={self.width}"
+            f"group_size={self.group_size}, bits={self.bits}"
         )

     def __call__(self, x):
@@ -84,21 +84,21 @@ class QuantizedLinear(Module):
             self.weight.T,
             scales=self.scales,
             biases=self.biases,
-            groups=self.groups,
-            width=self.width,
+            group_size=self.group_size,
+            bits=self.bits,
         )
         if "bias" in self:
             x = x + self.bias
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a QuantizedLinear layer from the parameters of a provided
         linear layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, groups, width)
+        ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, groups, width
+            linear_layer.weight, group_size, bits
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
@@ -109,13 +109,13 @@ class QuantizedLinear(Module):
     def quantize_module(
         cls,
         model: Module,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
         linear_class_predicate=lambda m: isinstance(m, Linear),
     ):
         def _quantize_if_linear(m):
             if linear_class_predicate(m):
-                return cls.from_linear(m, groups, width)
+                return cls.from_linear(m, group_size, bits)
             else:
                 return m
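
The classmethod helpers keep the same signatures under the new names. A minimal usage sketch for quantizing an existing float model (the two-layer Sequential model is illustrative, not from the diff):

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

    # Quantize a single layer from its float parameters...
    ql = nn.QuantizedLinear.from_linear(model.layers[0], group_size=64, bits=4)

    # ...or replace every Linear in the model, in place.
    nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)

The 32 // self.bits factor in _extra_repr recovers the unpacked input width: quantized elements are packed into 32-bit words, so at bits=4 each stored word holds 32 / 4 = 8 weight elements.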