Add the quantization type option to quantizable layers

This commit is contained in:
Angelos Katharopoulos 2024-12-16 22:11:23 -08:00
parent fb7be036af
commit 4847199ec6
3 changed files with 31 additions and 8 deletions

View File

@@ -39,6 +39,10 @@ class Embedding(Module):
"""
return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)
return QuantizedEmbedding.from_embedding(
self, group_size, bits, quantization_type
)

View File

@@ -70,9 +70,11 @@ class Linear(Module):
x = x @ self["weight"].T
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)
class Bilinear(Module):

View File

@@ -12,6 +12,7 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
@@ -39,7 +40,11 @@ def quantize(
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
return m.to_quantized(
group_size=group_size,
bits=bits,
quantization_type=quantization_type,
)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
else:
@@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
if quantization_type != "affine":
raise ValueError(f"Quantization type {quantization_type} not supported")
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -222,12 +233,18 @@ class QuantizedLinear(Module):
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
linear_layer.weight, group_size, bits, quantization_type=quantization_type
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias