mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 09:18:11 +08:00
add trellis quant mode
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
@@ -39,6 +40,12 @@ class Embedding(Module):
|
||||
"""
|
||||
return x @ self.weight.T
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
def to_quantized(
|
||||
self,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
||||
return QuantizedEmbedding.from_embedding(self, group_size, bits)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
|
||||
|
||||
class Identity(Module):
|
||||
@@ -70,9 +71,15 @@ class Linear(Module):
|
||||
x = x @ self["weight"].T
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
def to_quantized(
|
||||
self,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
|
||||
return QuantizedLinear.from_linear(self, group_size, bits)
|
||||
return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)
|
||||
|
||||
|
||||
class Bilinear(Module):
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
from mlx.utils import tree_map_with_path
|
||||
|
||||
|
||||
@@ -12,7 +13,9 @@ def quantize(
|
||||
model: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Quantize the sub-modules of a module according to a predicate.
|
||||
|
||||
@@ -21,7 +24,7 @@ def quantize(
|
||||
will be quantized. Note also, the module is updated in-place.
|
||||
|
||||
Args:
|
||||
model (mlx.nn.Module): The model whose leaf modules may be quantized.
|
||||
model (mlx.nn.Module):, mode: Literal["affine", "trellis"] = "affine" The model whose leaf modules may be quantized.
|
||||
group_size (int): The quantization group size (see
|
||||
:func:`mlx.core.quantize`). Default: ``64``.
|
||||
bits (int): The number of bits per parameter (see
|
||||
@@ -36,12 +39,15 @@ def quantize(
|
||||
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
|
||||
|
||||
def _maybe_quantize(path, m):
|
||||
print(path)
|
||||
if bool_or_params := class_predicate(path, m):
|
||||
if hasattr(m, "to_quantized"):
|
||||
if isinstance(bool_or_params, bool):
|
||||
return m.to_quantized(group_size=group_size, bits=bits)
|
||||
return m.to_quantized(
|
||||
group_size=group_size, bits=bits, mode=mode, fake=fake
|
||||
)
|
||||
elif isinstance(bool_or_params, dict):
|
||||
return m.to_quantized(**bool_or_params)
|
||||
return m.to_quantized(**bool_or_params, fake=fake)
|
||||
else:
|
||||
raise ValueError(
|
||||
"``class_predicate`` must return a bool"
|
||||
@@ -131,7 +137,11 @@ class QuantizedEmbedding(Module):
|
||||
|
||||
@classmethod
|
||||
def from_embedding(
|
||||
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
|
||||
cls,
|
||||
embedding_layer: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
):
|
||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||
embedding_dims, dims = embedding_layer.weight.shape
|
||||
@@ -170,12 +180,14 @@ class QuantizedLinear(Module):
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Quantization config
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
self.mode = mode
|
||||
|
||||
# Initialize the quantized weight
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
@@ -216,19 +228,40 @@ class QuantizedLinear(Module):
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
mode=self.mode,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + self["bias"]
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
||||
def from_linear(
|
||||
cls,
|
||||
linear_layer: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||
output_dims, input_dims = linear_layer.weight.shape
|
||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, group_size, bits
|
||||
)
|
||||
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
|
||||
if mode == "trellis":
|
||||
if fake:
|
||||
ql.weight = mx.zeros(
|
||||
(output_dims, input_dims // 32 * bits), dtype=mx.uint32
|
||||
)
|
||||
ql.scales = mx.array(0.0)
|
||||
ql.biases = mx.array(0.0)
|
||||
else:
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, bits=bits, mode="trellis"
|
||||
)
|
||||
else:
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, group_size, bits, mode="affine"
|
||||
)
|
||||
|
||||
if "bias" in linear_layer:
|
||||
ql.bias = linear_layer.bias
|
||||
|
||||
|
||||
@@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||
quantization uses one floating point scale and bias per ``group_size`` of
|
||||
@@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
mode (str, optional): The mode to use for quantization.
|
||||
Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``.
|
||||
@@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
nb::kw_only(),
|
||||
"mode"_a = "affine",
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
R"pbdoc(
|
||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||
|
||||
@@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
mode (str): The quantization mode to use. Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
@@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform quantized matrix multiplication with matrix-level gather.
|
||||
|
||||
@@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
mode (str, optional): The mode to use for quantization.
|
||||
Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
|
||||
@@ -10,6 +10,9 @@ import mlx_tests
|
||||
class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_quantize_dequantize(self):
|
||||
w = mx.random.normal(shape=(128, 512))
|
||||
w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
|
||||
print(w_q, scales, biases)
|
||||
|
||||
for gs in [32, 64, 128]:
|
||||
for b in [2, 3, 6, 4, 8]:
|
||||
with self.subTest(gs=gs, b=b):
|
||||
|
||||
Reference in New Issue
Block a user