Mirror of https://github.com/ml-explore/mlx.git
(synced 2025-07-23 18:11:17 +08:00)
Add the quantization type option to quantizable layers
This commit is contained in:
parent: fb7be036af
commit: 4847199ec6
@ -39,6 +39,10 @@ class Embedding(Module):
|
|||||||
"""
|
"""
|
||||||
return x @ self.weight.T
|
return x @ self.weight.T
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
def to_quantized(
|
||||||
|
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
|
||||||
|
):
|
||||||
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
||||||
return QuantizedEmbedding.from_embedding(self, group_size, bits)
|
return QuantizedEmbedding.from_embedding(
|
||||||
|
self, group_size, bits, quantization_type
|
||||||
|
)
|
||||||
|
@ -70,9 +70,11 @@ class Linear(Module):
|
|||||||
x = x @ self["weight"].T
|
x = x @ self["weight"].T
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
def to_quantized(
|
||||||
|
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
|
||||||
|
):
|
||||||
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
|
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
|
||||||
return QuantizedLinear.from_linear(self, group_size, bits)
|
return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)
|
||||||
|
|
||||||
|
|
||||||
class Bilinear(Module):
|
class Bilinear(Module):
|
||||||
|
@ -12,6 +12,7 @@ def quantize(
|
|||||||
model: Module,
|
model: Module,
|
||||||
group_size: int = 64,
|
group_size: int = 64,
|
||||||
bits: int = 4,
|
bits: int = 4,
|
||||||
|
quantization_type: str = "affine",
|
||||||
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
|
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
|
||||||
):
|
):
|
||||||
"""Quantize the sub-modules of a module according to a predicate.
|
"""Quantize the sub-modules of a module according to a predicate.
|
||||||
@ -39,7 +40,11 @@ def quantize(
|
|||||||
if bool_or_params := class_predicate(path, m):
|
if bool_or_params := class_predicate(path, m):
|
||||||
if hasattr(m, "to_quantized"):
|
if hasattr(m, "to_quantized"):
|
||||||
if isinstance(bool_or_params, bool):
|
if isinstance(bool_or_params, bool):
|
||||||
return m.to_quantized(group_size=group_size, bits=bits)
|
return m.to_quantized(
|
||||||
|
group_size=group_size,
|
||||||
|
bits=bits,
|
||||||
|
quantization_type=quantization_type,
|
||||||
|
)
|
||||||
elif isinstance(bool_or_params, dict):
|
elif isinstance(bool_or_params, dict):
|
||||||
return m.to_quantized(**bool_or_params)
|
return m.to_quantized(**bool_or_params)
|
||||||
else:
|
else:
|
||||||
@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_embedding(
|
def from_embedding(
|
||||||
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
|
cls,
|
||||||
|
embedding_layer: Module,
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 4,
|
||||||
|
quantization_type: str = "affine",
|
||||||
):
|
):
|
||||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||||
|
if quantization_type != "affine":
|
||||||
|
raise ValueError(f"Quantization type {quantization_type} not supported")
|
||||||
embedding_dims, dims = embedding_layer.weight.shape
|
embedding_dims, dims = embedding_layer.weight.shape
|
||||||
ql = cls(embedding_dims, dims, group_size, bits)
|
ql = cls(embedding_dims, dims, group_size, bits)
|
||||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||||
@ -222,12 +233,18 @@ class QuantizedLinear(Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
def from_linear(
|
||||||
|
cls,
|
||||||
|
linear_layer: Module,
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 4,
|
||||||
|
quantization_type: str = "affine",
|
||||||
|
):
|
||||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||||
output_dims, input_dims = linear_layer.weight.shape
|
output_dims, input_dims = linear_layer.weight.shape
|
||||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||||
linear_layer.weight, group_size, bits
|
linear_layer.weight, group_size, bits, quantization_type=quantization_type
|
||||||
)
|
)
|
||||||
if "bias" in linear_layer:
|
if "bias" in linear_layer:
|
||||||
ql.bias = linear_layer.bias
|
ql.bias = linear_layer.bias
|
||||||
|
Loading…
Reference in New Issue
Block a user