Add 3-bit packed quants

Angelos Katharopoulos
2024-12-17 10:08:47 -08:00
parent 14420949d2
commit d75a509234
4 changed files with 112 additions and 29 deletions


@@ -181,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        quantization_type: str = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.quantization_type = quantization_type

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
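Note: the constructor now threads a quantization_type string (default "affine") through to mx.quantize. A minimal construction sketch against this branch; the layer sizes are illustrative, and bits=3 assumes the 3-bit packing this commit adds is accepted by the default type:

import mlx.core as mx
import mlx.nn as nn

# Hypothetical usage of the new keyword; "affine" is the default shown above.
layer = nn.QuantizedLinear(
    512, 512, bias=True, group_size=64, bits=3, quantization_type="affine"
)
y = layer(mx.random.normal((1, 512)))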
@@ -195,7 +197,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, quantization_type=quantization_type
+        )

         # And bias if needed
         if bias:
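The quantize call goes multi-line only to pass the type through; the three-tuple return (packed weights, scales, biases) is unchanged in this hunk. A round-trip sketch, assuming the stock mx.dequantize signature (this diff does not show a quantization_type keyword for dequantize, so it is omitted):

import mlx.core as mx

w = mx.random.normal((256, 256))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")
w_hat = mx.dequantize(wq, scales, biases, group_size=64, bits=4)
print(mx.max(mx.abs(w - w_hat)))  # small quantization error, not exactly zero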
@@ -223,10 +227,11 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases", None),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            quantization_type=self.quantization_type,
         )
         if "bias" in self:
             x = x + self["bias"]
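The switch from self["biases"] to self.get("biases", None) makes the biases entry optional in the forward pass; a plausible reading is that some quantization types on this branch produce scales only, so the quantized matmul must tolerate biases=None. That reading is inferred from the diff, not confirmed by it.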
@@ -242,7 +247,7 @@ class QuantizedLinear(Module):
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
         ql.weight, ql.scales, ql.biases = mx.quantize(
             linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
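For converting an existing float layer, the body above implies from_linear's signature also gained a quantization_type parameter (it is used here, but its declaration sits above this hunk). A usage sketch under that assumption:

import mlx.core as mx
import mlx.nn as nn

linear = nn.Linear(512, 1024, bias=False)
# bits=3 exercises the packed 3-bit path this commit adds; quantization_type
# is assumed to default to "affine" as in the constructor.
qlayer = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=3)
y = qlayer(mx.random.normal((4, 512)))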