mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-03 01:48:12 +08:00)
Add NF4 quant
@@ -164,12 +164,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -210,18 +212,26 @@ class QuantizedLinear(Module):
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x

+    # if we pass mode to both then we can propagate it to the thing
     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+            linear_layer.weight, group_size=group_size, bits=bits, mode=mode
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
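
Not part of the diff: a minimal usage sketch of the new ``mode`` argument on ``QuantizedLinear.from_linear``, assuming a build of MLX with this patch applied (shapes and values are illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    # Convert a float Linear layer into a 4-bit, NF4-quantized one.
    linear = nn.Linear(256, 512)
    ql = nn.QuantizedLinear.from_linear(
        linear, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
    )
    y = ql(mx.random.normal(shape=(1, 256)))  # forward pass reuses ql.mode
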
@@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) {
           array: An array of the same type as ``a`` rounded to the
             given number of decimals.
       )pbdoc");
+  nb::enum_<QuantizationMode>(m, "QuantizationMode")
+      .value("DEFAULT", QuantizationMode::DEFAULT)
+      .value("NF4", QuantizationMode::NF4)
+      .export_values();
   m.def(
       "quantized_matmul",
       &quantized_matmul,
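
The ``nb::enum_`` block above is what makes ``mx.QuantizationMode`` visible from Python; because of ``.export_values()``, nanobind also re-exports the members at module scope, so both spellings below should resolve (a sketch, not code from this commit):

    import mlx.core as mx

    mode = mx.QuantizationMode.NF4  # scoped enum member
    alias = mx.NF4                  # module-level alias from .export_values()
    assert mode == alias
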
@@ -3626,10 +3630,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = QuantizationMode::DEFAULT,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: quantized_array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. (default: ``64``)
           bits (int, optional): The number of bits occupied by each element in
             ``w``. (default: ``4``)
+          mode (QuantizationMode, optional): The quantization mode; see ``QuantizationMode``. (default: ``QuantizationMode.DEFAULT``)

         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
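
Combining the binding and docstring above, a hedged end-to-end sketch of quantizing and multiplying with an explicit mode; it follows the positional calling convention used by the tests further down rather than the new signature string:

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 512))
    w = mx.random.normal(shape=(1024, 512))

    w_q, scales, biases = mx.quantize(
        w, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
    )
    y = mx.quantized_matmul(
        x, w_q, scales, biases, True, 64, 4, mode=mx.QuantizationMode.NF4
    )  # y has shape (2, 1024)
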
@@ -3658,10 +3664,11 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = QuantizationMode::DEFAULT,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> quantized_array"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) {
             scale and bias. (default: ``64``)
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. (default: ``4``)
+          mode (QuantizationMode, optional): The quantization mode used to
+            encode ``w``. (default: ``QuantizationMode.DEFAULT``)

         Returns:
           tuple: A tuple containing
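
The diff does not show the NF4 kernel itself. For background only: NF4 ("NormalFloat4", introduced in the QLoRA paper) scales each group by its absolute maximum and snaps values to 16 fixed, non-uniform levels placed at quantiles of a standard normal distribution. A rough NumPy reference follows; the level values are the commonly cited ones from the bitsandbytes implementation and are illustrative, not this commit's code:

    import numpy as np

    # 16 NF4 levels: normal-distribution quantiles rescaled to [-1, 1].
    NF4_LEVELS = np.array([
        -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
        0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
    ])

    def nf4_quantize_group(w):
        """Quantize one 1-D group: absmax scale, then nearest-level lookup."""
        scale = max(np.abs(w).max(), 1e-12)  # guard all-zero groups
        idx = np.abs(w[:, None] / scale - NF4_LEVELS[None, :]).argmin(axis=1)
        return idx.astype(np.uint8), scale

    def nf4_dequantize_group(idx, scale):
        """Invert the lookup; reconstruction is lossy."""
        return NF4_LEVELS[idx] * scale

Unlike the affine DEFAULT mode, NF4 needs only a per-group scale (no bias), though the bindings in this diff appear to keep the shared scales/biases plumbing for both modes.
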
@@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) {
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = QuantizationMode::DEFAULT,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) {
             scale and bias. (default: ``64``)
           bits (int, optional): The number of bits occupied by each element in
             ``w``. (default: ``4``)
+          mode (QuantizationMode, optional): The quantization mode that was
+            used to encode ``w``. (default: ``QuantizationMode.DEFAULT``)

         Returns:
           array: The dequantized version of ``w``
       )pbdoc");
   m.def(
-      "gater_qmm",
+      "gather_qmm",
       &gather_qmm,
       nb::arg(),
       nb::arg(),
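
A quick round-trip check against the ``dequantize`` binding above (a sketch; the keyword names follow the ``_a`` arguments declared in the binding):

    import mlx.core as mx

    w = mx.random.normal(shape=(64, 128))
    w_q, scales, biases = mx.quantize(
        w, group_size=64, bits=4, mode=mx.QuantizationMode.DEFAULT
    )
    w_hat = mx.dequantize(
        w_q, scales, biases, group_size=64, bits=4, mode=mx.QuantizationMode.DEFAULT
    )
    # Reconstruction is lossy but should stay close for 4-bit groups.
    print((w - w_hat).abs().max())
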
@@ -3759,6 +3771,7 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = QuantizationMode::DEFAULT,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. (default: ``64``)
           bits (int, optional): The number of bits occupied by each element in
             ``w``. (default: ``4``)
+          mode (QuantizationMode, optional): The quantization mode used to
+            encode ``w``. (default: ``QuantizationMode.DEFAULT``)

         Returns:
           array: The result of the multiplication of ``x`` with ``w``

@@ -115,18 +115,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32], # group_size
-            [2, 4, 8], # bits
+            # [2, 4, 8], # bits
+            [4], # bits
             [512, 1024], # M
             [512, 1024], # N
+            [mx.QuantizationMode.DEFAULT, mx.QuantizationMode.DEFAULT],
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+        for group_size, bits, M, N, mode in tests:
+            with self.subTest(
+                shape=(M, N), group_size=group_size, bits=bits, mode=mode
+            ):
                 x = mx.random.normal(shape=(1, N), key=k1)
                 w = mx.random.normal(shape=(M, N), key=k2)
-                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                w_q = mx.quantize(w, group_size, bits)
+                w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
                 y_q = mx.quantized_matmul(
-                    x, w_q, scales, biases, True, group_size, bits
+                    x, w_q, scales, biases, True, group_size, bits, mode=mode
                 )
                 y_hat = x @ w_hat.T
                 self.assertEqual(y_q.shape, y_hat.shape)
@@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32], # group_size
-            [2, 4, 8], # bits
+            [4], # bits
             [512, 1024], # M
             [512, 1024], # N
+            [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+        for group_size, bits, M, N, mode in tests:
+            with self.subTest(
+                shape=(M, N), group_size=group_size, bits=bits, mode=mode
+            ):
                 x = mx.random.normal(shape=(1, N), key=k1)
                 w = mx.random.normal(shape=(N, M), key=k2)
-                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
                 y_q = mx.quantized_matmul(
-                    x, w_q, scales, biases, False, group_size, bits
+                    x, w_q, scales, biases, False, group_size, bits, mode
                 )
                 y_hat = x @ w_hat
                 self.assertEqual(y_q.shape, y_hat.shape)
@@ -171,37 +179,47 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.eval(y)

     def test_small_matrix(self):
-        w = mx.random.normal(shape=(8, 256))
-        w_q, scales, biases = mx.quantize(w)
-        w_hat = mx.dequantize(w_q, scales, biases)
+        for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
+            with self.subTest(mode=mode):
+                w = mx.random.normal(shape=(8, 256))
+                w_q, scales, biases = mx.quantize(w, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, mode=mode)

-        # Test qmv
-        x = mx.random.normal(shape=(1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(1, 256))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=True, mode=mode
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmm_t
-        x = mx.random.normal(shape=(10, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm_t
+                x = mx.random.normal(shape=(10, 256))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=True, mode=mode
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmv
-        x = mx.random.normal(shape=(1, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(1, 8))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=False, mode=mode
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-        # Test qmm
-        x = mx.random.normal(shape=(10, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm
+                x = mx.random.normal(shape=(10, 8))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=False, mode=mode
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))