Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
Author:       Awni Hannun
Date:         2025-08-28 06:45:26 -07:00
Committed by: GitHub
Parent:       7ef8a6f2d5
Commit:       70560b6bd5

28 changed files with 3635 additions and 757 deletions
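The commit keeps affine quantization as the default and adds a `mode` string that selects the
quantization method. A minimal usage sketch, not part of the diff; the `"mxfp4"` mode string and
the 32-element group size follow the commit message and are assumptions about the core API:

import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

# Unchanged default: affine quantization with per-group scales and biases.
nn.quantize(model, group_size=64, bits=4)

# New: choose the quantization method with the keyword-only `mode` argument,
# e.g. the mxfp4 format added in this commit (assumed to use 32-element groups).
model2 = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))
nn.quantize(model2, group_size=32, bits=4, mode="mxfp4")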


@@ -39,6 +39,6 @@ class Embedding(Module):
         """
         return x @ self.weight.T
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)


@@ -70,9 +70,9 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x
 
-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode)
 
 
 class Bilinear(Module):
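Both `to_quantized` overrides simply forward `mode` to the corresponding `from_*` constructor, so
layers can also be converted one at a time. A short sketch; the mxfp4 settings are assumptions:

import mlx.nn as nn

lin = nn.Linear(256, 256)
emb = nn.Embedding(1000, 256)

qlin = lin.to_quantized(group_size=64, bits=4)                # affine, as before
qemb = emb.to_quantized(group_size=32, bits=4, mode="mxfp4")  # assumed mxfp4 settings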


@@ -12,6 +12,8 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    *,
+    mode: str = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -26,6 +28,8 @@ def quantize(
             :func:`mlx.core.quantize`). Default: ``64``.
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
             dict of params for `to_quantized` if it should be quantized and
@@ -39,7 +43,7 @@ def quantize(
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(group_size=group_size, bits=bits, mode=mode)
                 elif isinstance(bool_or_params, dict):
                     return m.to_quantized(**bool_or_params)
                 else:
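As the docstring and branch above show, a dict returned by `class_predicate` is splatted into
`to_quantized`, so per-module settings, including `mode`, can differ from the top-level arguments.
A sketch of that pattern; the model, module names, and per-layer choices are purely illustrative:

import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

def predicate(path, module):
    # Quantize only Linear layers; keep an output head (if any) in 8-bit affine
    # and use the new mode elsewhere. Illustrative policy, not from the diff.
    if not isinstance(module, nn.Linear):
        return False
    if "lm_head" in path:
        return {"group_size": 64, "bits": 8, "mode": "affine"}
    return {"group_size": 32, "bits": 4, "mode": "mxfp4"}

nn.quantize(model, class_predicate=predicate)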
@@ -72,6 +76,8 @@ class QuantizedEmbedding(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """
 
     def __init__(
@@ -80,17 +86,23 @@ class QuantizedEmbedding(Module):
         dims: int,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            (self.scales,) = scales_biases
         self.num_embeddings = num_embeddings
         self.dims = dims
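The star-unpacking above reflects the mode-dependent return value of `mx.quantize`: three arrays
(weight, scales, biases) in affine mode, but only (weight, scales) otherwise, so non-affine layers
never hold a `biases` member. A standalone sketch of that behaviour; the mxfp4 group size of 32 is
an assumption:

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))

# Affine mode: three outputs, and dequantize takes both scales and biases.
wq, scales, biases = mx.quantize(w, 64, 4, mode="affine")
w_hat = mx.dequantize(wq, scales=scales, biases=biases, group_size=64, bits=4, mode="affine")

# Non-affine modes (e.g. mxfp4): no biases are produced, and dequantize is
# called with biases=None, matching the layer code above and below.
wq4, scales4 = mx.quantize(w, 32, 4, mode="mxfp4")
w_hat4 = mx.dequantize(wq4, scales=scales4, biases=None, group_size=32, bits=4, mode="mxfp4")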
@@ -98,12 +110,14 @@ class QuantizedEmbedding(Module):
         self.freeze()
 
     def __call__(self, x):
+        biases = self.get("biases")
         return mx.dequantize(
             self["weight"][x],
             scales=self["scales"][x],
-            biases=self["biases"][x],
+            biases=biases[x] if biases is not None else None,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
 
     def as_linear(self, x):
@@ -117,28 +131,40 @@ class QuantizedEmbedding(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
 
     def _extra_repr(self):
         return (
             f"{self.num_embeddings}, {self.dims}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
-        ql = cls(embedding_dims, dims, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            embedding_layer.weight, group_size, bits
+        ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
+        ql.weight, *scales_biases = mx.quantize(
+            embedding_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
         )
+        if mode == "affine":
+            ql.scales, ql.biases = scales_biases
+        else:
+            (ql.scales,) = scales_biases
         return ql
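For reference, converting an existing embedding and reusing its weights as an output projection,
grounded in the `from_embedding` and `as_linear` methods shown above (sizes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(32000, 512)
qemb = nn.QuantizedEmbedding.from_embedding(emb, group_size=64, bits=4, mode="affine")

tokens = mx.array([1, 5, 42])
hidden = qemb(tokens)            # dequantized embedding lookup
logits = qemb.as_linear(hidden)  # the same quantized weight used as a projection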
@@ -161,6 +187,8 @@ class QuantizedLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """
 
     def __init__(
@@ -170,12 +198,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -184,7 +214,11 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            (self.scales,) = scales_biases
 
         # And bias if needed
         if bias:
@@ -198,7 +232,7 @@ class QuantizedLinear(Module):
         in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )
 
     def __call__(self, x):
@@ -206,23 +240,38 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x
 
     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
+        ql.weight, *scales_biases = mx.quantize(
+            linear_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
         )
+        if mode == "affine":
+            ql.scales, ql.biases = scales_biases
+        else:
+            (ql.scales,) = scales_biases
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
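A closing sketch of the classmethod in use; the check on `biases` follows from the mode-dependent
unpacking above, and the mxfp4 settings remain assumptions:

import mlx.nn as nn

lin = nn.Linear(1024, 1024, bias=True)

ql_affine = nn.QuantizedLinear.from_linear(lin, group_size=64, bits=4)
assert "biases" in ql_affine      # affine mode stores per-group biases
assert "bias" in ql_affine        # the original layer bias is carried over

ql_mx = nn.QuantizedLinear.from_linear(lin, group_size=32, bits=4, mode="mxfp4")
assert "biases" not in ql_mx      # non-affine modes store only scales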