add trellis quant mode

Alex Barron
2025-03-18 18:52:22 -07:00
parent e9e268336b
commit d7acf59fd0
16 changed files with 852 additions and 108 deletions


@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 import math
+from typing import Literal

 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -39,6 +40,12 @@ class Embedding(Module):
         """
         return x @ self.weight.T

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
+        fake: bool = False,
+    ):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
         return QuantizedEmbedding.from_embedding(self, group_size, bits)
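
In this hunk the new arguments are accepted but not yet forwarded: ``QuantizedEmbedding.from_embedding`` is still called with only ``group_size`` and ``bits``, so embeddings stay on the affine path. A minimal sketch of the call under this commit (the vocabulary and feature sizes are made up):

    import mlx.nn as nn

    emb = nn.Embedding(1000, 64)
    # mode/fake are accepted here, but this commit still routes embeddings
    # through the affine quantizer regardless of the requested mode.
    qemb = emb.to_quantized(group_size=64, bits=4, mode="affine")
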


@@ -1,11 +1,12 @@
 # Copyright © 2023 Apple Inc.

 import math
-from typing import Any
+from typing import Any, Literal

 import mlx.core as mx
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.viterbi import quantize as trellis_quantize


 class Identity(Module):
@@ -70,9 +71,15 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
+        fake: bool = False,
+    ):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)


 class Bilinear(Module):
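
Unlike the embedding hunk, ``Linear.to_quantized`` does forward the new arguments to ``QuantizedLinear.from_linear``. A minimal usage sketch under this commit's API (the layer sizes are made up):

    import mlx.nn as nn

    layer = nn.Linear(512, 128)
    # Affine quantization: the default, matching the old behavior.
    qlayer = layer.to_quantized(group_size=64, bits=4)
    # Trellis quantization, new in this commit; group_size appears to be
    # unused on this path (see from_linear below).
    tlayer = layer.to_quantized(bits=2, mode="trellis")
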


@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.

 import math
-from typing import Callable, Optional, Union
+from typing import Callable, Literal, Optional, Union

 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.viterbi import quantize as trellis_quantize
 from mlx.utils import tree_map_with_path
@@ -12,7 +13,9 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    mode: Literal["affine", "trellis"] = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
+    fake: bool = False,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -21,8 +24,10 @@ def quantize(
     will be quantized. Note also, the module is updated in-place.

     Args:
         model (mlx.nn.Module): The model whose leaf modules may be quantized.
         group_size (int): The quantization group size (see
             :func:`mlx.core.quantize`). Default: ``64``.
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
+        mode (str): The quantization mode to use, either ``"affine"`` or
+            ``"trellis"``. Default: ``"affine"``.
@@ -36,12 +41,14 @@ def quantize(
     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))

     def _maybe_quantize(path, m):
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(
+                        group_size=group_size, bits=bits, mode=mode, fake=fake
+                    )
                 elif isinstance(bool_or_params, dict):
-                    return m.to_quantized(**bool_or_params)
+                    return m.to_quantized(**bool_or_params, fake=fake)
             else:
                 raise ValueError(
                     "``class_predicate`` must return a bool"
@@ -131,7 +138,11 @@ class QuantizedEmbedding(Module):

     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -170,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -216,19 +229,40 @@
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
+        fake: bool = False,
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
-        )
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode)
+        if mode == "trellis":
+            if fake:
+                ql.weight = mx.zeros(
+                    (output_dims, input_dims // 32 * bits), dtype=mx.uint32
+                )
+                ql.scales = mx.array(0.0)
+                ql.biases = mx.array(0.0)
+            else:
+                ql.weight, ql.scales, ql.biases = mx.quantize(
+                    linear_layer.weight, bits=bits, mode="trellis"
+                )
+        else:
+            ql.weight, ql.scales, ql.biases = mx.quantize(
+                linear_layer.weight, group_size, bits, mode="affine"
+            )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
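
The ``fake`` branch skips the trellis search entirely and only allocates placeholder buffers in the packed layout: a row of ``input_dims`` elements at ``bits`` bits each occupies ``input_dims * bits / 32`` uint32 words, which is the ``input_dims // 32 * bits`` above. A quick check of that arithmetic (the layer sizes are made up):

    import mlx.core as mx

    output_dims, input_dims, bits = 128, 512, 2
    # 512 elements * 2 bits = 1024 bits = 32 uint32 words per output row.
    fake_weight = mx.zeros((output_dims, input_dims // 32 * bits), dtype=mx.uint32)
    print(fake_weight.shape)  # (128, 32)
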


@@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'] = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The mode to use for quantization.
+            Default: ``affine``.

         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
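
From Python, the new argument threads straight through the binding. A hedged round-trip sketch under this commit's bindings (the shapes are made up; ``mode`` must match how ``w`` was packed):

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 512))
    w = mx.random.normal(shape=(128, 512))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode="affine"
    )
    print(y.shape)  # (2, 128)
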
@@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
       "group_size"_a = 64,
       "bits"_a = 4,
       nb::kw_only(),
+      "mode"_a = "affine",
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, *, mode: Literal['affine', 'trellis'] = 'affine', stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. Default: ``4``.
+          mode (str, optional): The quantization mode to use. Default: ``affine``.

         Returns:
           tuple: A tuple containing
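
For the affine mode (which predates this commit) the returned scales and biases let ``dequantize`` reconstruct an approximation of ``w``. A quick round-trip sketch:

    import mlx.core as mx

    w = mx.random.normal(shape=(128, 512))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
    print(mx.abs(w - w_hat).max())  # small reconstruction error
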
@@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'] = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
@@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The mode to use for quantization.
+            Default: ``affine``.

         Returns:
           array: The result of the multiplication of ``x`` with ``w``
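
``gather_qmm`` is the gathered (e.g. mixture-of-experts) variant, so a single ``mode`` applies to all gathered weight matrices. A hedged sketch of the usual routing pattern (all shapes and indices are made up):

    import mlx.core as mx

    w = mx.random.normal(shape=(2, 128, 512))  # two "expert" weight matrices
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    x = mx.random.normal(shape=(4, 1, 512))
    rhs_indices = mx.array([0, 1, 1, 0])  # pick an expert per input
    y = mx.gather_qmm(
        x, w_q, scales, biases, rhs_indices=rhs_indices,
        transpose=True, group_size=64, bits=4, mode="affine"
    )
    print(y.shape)  # (4, 1, 128)
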


@@ -10,6 +10,8 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
+        # Smoke test: trellis quantization should run end-to-end.
+        w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
         for gs in [32, 64, 128]:
             for b in [2, 3, 6, 4, 8]:
                 with self.subTest(gs=gs, b=b):
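
Once the trellis packing settles, the smoke test could assert on shapes instead of just running. A hedged sketch: the expected width follows the ``input_dims // 32 * bits`` packing used in ``from_linear``, and it is an assumption that ``mx.quantize`` uses the same layout and a uint32 dtype here:

    import mlx.core as mx

    w = mx.random.normal(shape=(128, 512))
    w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
    # Assumed packing: 512 elements * 2 bits / 32 bits per word = 32 words.
    assert w_q.dtype == mx.uint32
    assert w_q.shape == (128, 512 // 32 * 2)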