Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
committed by GitHub
parent 57fe918cf8
commit b3916cbf2b
```diff
@@ -26,12 +26,12 @@ class QuantizedLinear(Module):
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
-        bias (bool): If set to ``False`` then the layer will not use a bias.
-            (default: True).
-        groups (int): The group size to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 128)
-        width (int): The bit width to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 4)
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. (default: True).
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. (default: 64)
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. (default: 4)
     """

     def __init__(
```
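As a quick illustration of the renamed arguments, a minimal sketch (not part of this commit; the layer sizes are made up):

```python
import mlx.core as mx
import mlx.nn as nn

# The quantization arguments are now group_size and bits,
# with defaults group_size=64 and bits=4.
layer = nn.QuantizedLinear(512, 1024, bias=True, group_size=64, bits=4)

x = mx.random.normal((2, 512))
y = layer(x)  # shape (2, 1024)
```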
```diff
@@ -39,14 +39,14 @@ class QuantizedLinear(Module):
         input_dims: int,
         output_dims: int,
         bias: bool = True,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
     ):
         super().__init__()

         # Quantization config
-        self.groups = groups
-        self.width = width
+        self.group_size = group_size
+        self.bits = bits

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
```
```diff
@@ -55,7 +55,7 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

         # And bias if needed
         if bias:
```
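For context, `mx.quantize` returns the packed weight along with one scale and one bias per group of `group_size` consecutive values. A round-trip sketch, assuming `mx.dequantize` takes the same `group_size`/`bits` parameters:

```python
import mlx.core as mx

w = mx.random.normal((1024, 1024))

# Pack the weights; scales and biases hold one entry per group
# of group_size consecutive values in each row.
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)

# Approximate reconstruction of the original weights.
w_hat = mx.dequantize(wq, scales, biases, group_size=64, bits=4)
```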
```diff
@@ -72,10 +72,10 @@ class QuantizedLinear(Module):

     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
-        in_dims *= 32 // self.width
+        in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
-            f"groups={self.groups}, width={self.width}"
+            f"group_size={self.group_size}, bits={self.bits}"
         )

     def __call__(self, x):
```
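The `32 // self.bits` factor reflects the storage layout: each 32-bit word packs `32 / bits` quantized values, so `_extra_repr` multiplies the stored width back up to the logical input dimension. A small worked example of the arithmetic:

```python
# With bits=4, one 32-bit word holds 32 // 4 = 8 quantized values.
bits = 4
input_dims = 512
stored_cols = input_dims // (32 // bits)    # 64 columns actually stored
logical_cols = stored_cols * (32 // bits)   # 512, as reported by _extra_repr
```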
```diff
@@ -84,21 +84,21 @@ class QuantizedLinear(Module):
             self.weight.T,
             scales=self.scales,
             biases=self.biases,
-            groups=self.groups,
-            width=self.width,
+            group_size=self.group_size,
+            bits=self.bits,
         )
         if "bias" in self:
             x = x + self.bias
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a QuantizedLinear layer from the parameters of a provided
         linear layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, groups, width)
+        ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, groups, width
+            linear_layer.weight, group_size, bits
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
```
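Converting an existing float layer with the renamed arguments might look like this (a minimal sketch; the layer sizes are illustrative):

```python
import mlx.nn as nn

# Quantize the weights of an existing layer instead of
# constructing a QuantizedLinear from scratch.
linear = nn.Linear(512, 1024)
qlayer = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=4)
```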
```diff
@@ -109,13 +109,13 @@ class QuantizedLinear(Module):
     def quantize_module(
         cls,
         model: Module,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
         linear_class_predicate=lambda m: isinstance(m, Linear),
     ):
         def _quantize_if_linear(m):
             if linear_class_predicate(m):
-                return cls.from_linear(m, groups, width)
+                return cls.from_linear(m, group_size, bits)
             else:
                 return m
```
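And quantizing a whole model, where `linear_class_predicate` selects which modules to replace (a sketch assuming `quantize_module` updates the model in place; the `Sequential` model here is made up):

```python
import mlx.nn as nn

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

# Every module matching linear_class_predicate (by default, any
# nn.Linear) is replaced with a QuantizedLinear via from_linear.
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)
```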