From 4847199ec6efdbf02aa5914223c55777cd42f936 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 16 Dec 2024 22:11:23 -0800
Subject: [PATCH] Add the quantization type option to quantizable layers

---
 python/mlx/nn/layers/embedding.py |  8 ++++++--
 python/mlx/nn/layers/linear.py    |  6 ++++--
 python/mlx/nn/layers/quantized.py | 25 +++++++++++++++++++++----
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py
index 1e15a59cc..85bd27a5c 100644
--- a/python/mlx/nn/layers/embedding.py
+++ b/python/mlx/nn/layers/embedding.py
@@ -39,6 +39,10 @@ class Embedding(Module):
         """
         return x @ self.weight.T
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
+    ):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(
+            self, group_size, bits, quantization_type
+        )
diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 63caa911c..b037c4797 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -70,9 +70,11 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
+    ):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)
 
 
 class Bilinear(Module):
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 823a0084f..76d30b1de 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -12,6 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    quantization_type: str = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -39,7 +40,11 @@ def quantize(
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(
+                        group_size=group_size,
+                        bits=bits,
+                        quantization_type=quantization_type,
+                    )
                 elif isinstance(bool_or_params, dict):
                     return m.to_quantized(**bool_or_params)
                 else:
@@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        quantization_type: str = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        if quantization_type != "affine":
+            raise ValueError(f"Quantization type {quantization_type} not supported")
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -222,12 +233,18 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        quantization_type: str = "affine",
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+            linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
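
Usage sketch (an illustration, not part of the patch): after this change the
quantization type threads from `nn.quantize` down through each layer's
`to_quantized`. The model below is hypothetical; "affine" remains the default,
and `QuantizedEmbedding.from_embedding` rejects any other value with a
ValueError. Note also that a `class_predicate` returning a dict bypasses the
top-level `quantization_type`, since the dict is unpacked directly into
`to_quantized(**params)`.

    import mlx.nn as nn

    # Hypothetical model; feature dims are multiples of group_size=64.
    model = nn.Sequential(nn.Embedding(1000, 512), nn.Linear(512, 512))

    # quantization_type is forwarded to every sub-module that defines
    # to_quantized(), i.e. Embedding and Linear here.
    nn.quantize(model, group_size=64, bits=4, quantization_type="affine")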