Add the quantization type option to quantizable layers
commit 4847199ec6
parent fb7be036af
@@ -39,6 +39,10 @@ class Embedding(Module):
         """
         return x @ self.weight.T
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
+    ):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(
+            self, group_size, bits, quantization_type
+        )
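A minimal usage sketch for the extended Embedding.to_quantized signature, assuming this patch is applied; the new keyword defaults to "affine", so existing callers are unchanged. The sizes below are only illustrative.

import mlx.nn as nn

emb = nn.Embedding(1000, 256)

# With the patch, quantization_type is forwarded to QuantizedEmbedding.from_embedding.
qemb = emb.to_quantized(group_size=64, bits=4, quantization_type="affine")
print(type(qemb).__name__)  # QuantizedEmbedding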
@@ -70,9 +70,11 @@ class Linear(Module):
         x = x @ self["weight"].T
         return x
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
+    ):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)
 
 
 class Bilinear(Module):
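The same pattern applies to Linear. A sketch, assuming the patched signature, that compares the quantized layer against the original on a random input:

import mlx.core as mx
import mlx.nn as nn

lin = nn.Linear(256, 128)

# quantization_type is passed straight through to QuantizedLinear.from_linear.
qlin = lin.to_quantized(group_size=64, bits=4, quantization_type="affine")

x = mx.random.normal((4, 256))
# The quantized output should track the full-precision one up to quantization error.
print(mx.abs(lin(x) - qlin(x)).max())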
@@ -12,6 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    quantization_type: str = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -39,7 +40,11 @@ def quantize(
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(
+                        group_size=group_size,
+                        bits=bits,
+                        quantization_type=quantization_type,
+                    )
                 elif isinstance(bool_or_params, dict):
                     return m.to_quantized(**bool_or_params)
                 else:
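A sketch of driving this through nn.quantize, assuming the patch is applied. Note that when class_predicate returns a dict, that dict is splatted into to_quantized as-is, so a per-layer quantization_type has to be included there explicitly; the override policy below is hypothetical.

import mlx.nn as nn

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))

# Global settings: the new keyword rides along with group_size and bits.
nn.quantize(model, group_size=64, bits=4, quantization_type="affine")

# Per-layer overrides via the predicate's dict branch.
def predicate(path, module):
    if isinstance(module, nn.Linear):
        return {"group_size": 32, "bits": 8, "quantization_type": "affine"}
    return False

model2 = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))
nn.quantize(model2, class_predicate=predicate)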
@@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        quantization_type: str = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
+        if quantization_type != "affine":
+            raise ValueError(f"Quantization type {quantization_type} not supported")
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
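Only "affine" is accepted by from_embedding for now. A sketch of the new guard; the "mxfp4" string below is just a placeholder for an unsupported type, not something defined by this change.

import mlx.nn as nn

emb = nn.Embedding(1024, 128)

try:
    nn.QuantizedEmbedding.from_embedding(
        emb, group_size=64, bits=4, quantization_type="mxfp4"
    )
except ValueError as e:
    print(e)  # Quantization type mxfp4 not supported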
@@ -222,12 +233,18 @@ class QuantizedLinear(Module):
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        quantization_type: str = "affine",
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+            linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
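A sketch of the QuantizedLinear path, assuming the patch is applied; the dequantize round-trip is just a sanity check that the stored scales and biases reconstruct the weight approximately.

import mlx.core as mx
import mlx.nn as nn

lin = nn.Linear(256, 64)
qlin = nn.QuantizedLinear.from_linear(
    lin, group_size=64, bits=4, quantization_type="affine"
)

# Reconstruct the weight from the packed representation and compare.
w_hat = mx.dequantize(qlin.weight, qlin.scales, qlin.biases, group_size=64, bits=4)
print(mx.abs(w_hat - lin.weight).max())  # small quantization error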