Add NF4 quant

This commit is contained in:
Alex Barron
2024-06-21 10:55:42 -07:00
parent af9079cc1f
commit 152092957c
12 changed files with 530 additions and 212 deletions

View File

@@ -164,12 +164,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@@ -210,18 +212,26 @@ class QuantizedLinear(Module):
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
if "bias" in self:
x = x + self["bias"]
return x
# if we pass mode to both then we can propagate it to the thing
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
linear_layer.weight, group_size=group_size, bits=bits, mode=mode
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias

View File

@@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) {
array: An array of the same type as ``a`` rounded to the
given number of decimals.
)pbdoc");
nb::enum_<QuantizationMode>(m, "QuantizationMode")
.value("DEFAULT", QuantizationMode::DEFAULT)
.value("NF4", QuantizationMode::NF4)
.export_values();
m.def(
"quantized_matmul",
&quantized_matmul,
@@ -3626,10 +3630,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: quantized_array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The mode of quantization: see QuantizationMode (default: ``QuantizationMode.DEFAULT``)
Returns:
array: The result of the multiplication of ``x`` with ``w``.
@@ -3658,10 +3664,11 @@ void init_ops(nb::module_& m) {
nb::arg(),
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> quantized_array"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
Returns:
tuple: A tuple containing
@@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) {
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
@@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
Returns:
array: The dequantized version of ``w``
)pbdoc");
m.def(
"gater_qmm",
"gather_qmm",
&gather_qmm,
nb::arg(),
nb::arg(),
@@ -3759,6 +3771,7 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
@@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
Returns:
array: The result of the multiplication of ``x`` with ``w``

View File

@@ -115,18 +115,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
# [2, 4, 8], # bits
[4], # bits
[512, 1024], # M
[512, 1024], # N
[mx.QuantizationMode.DEFAULT, mx.QuantizationMode.DEFAULT],
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
for group_size, bits, M, N, mode in tests:
with self.subTest(
shape=(M, N), group_size=group_size, bits=bits, mode=mode
):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
w_q = mx.quantize(w, group_size, bits)
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, True, group_size, bits
x, w_q, scales, biases, True, group_size, bits, mode=mode
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
@@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[4], # bits
[512, 1024], # M
[512, 1024], # N
[mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
for group_size, bits, M, N, mode in tests:
with self.subTest(
shape=(M, N), group_size=group_size, bits=bits, mode=mode
):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(N, M), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
x, w_q, scales, biases, False, group_size, bits, mode
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
@@ -171,37 +179,47 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.eval(y)
def test_small_matrix(self):
w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
with self.subTest(mode=mode):
w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases, mode=mode)
# Test qmv
x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmv
x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=True, mode=mode
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=True, mode=mode
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmv
x = mx.random.normal(shape=(1, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmv
x = mx.random.normal(shape=(1, 8))
y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=False, mode=mode
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(10, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(10, 8))
y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=False, mode=mode
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_non_multiples(self):
w = mx.random.normal(shape=(33, 256))