mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module
This commit is contained in:

committed by
GitHub

parent
4912ff3ec2
commit
57fe918cf8
@@ -6,48 +6,14 @@ import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
def select_bits(w, width, start):
|
||||
shift_left = 32 - (start + width)
|
||||
shift_right = shift_left + start
|
||||
return (w * (2**shift_left)) // (2**shift_right)
|
||||
|
||||
|
||||
def dequantize(w, scales, biases, width):
|
||||
w_full = mx.concatenate(
|
||||
[select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
|
||||
)
|
||||
w_full = w_full.reshape(len(w), scales.shape[-1], -1)
|
||||
w_full = scales[..., None] * w_full + biases[..., None]
|
||||
w_full = w_full.reshape(len(w), -1)
|
||||
|
||||
return w_full
|
||||
|
||||
|
||||
def quantize(w, width, groups):
|
||||
w = w.reshape(len(w), -1, groups)
|
||||
w_max = w.max(-1, keepdims=True)
|
||||
w_min = w.min(-1, keepdims=True)
|
||||
delta = (w_max - w_min) / (2**width - 1)
|
||||
|
||||
w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
|
||||
scales = delta.squeeze(-1)
|
||||
biases = w_min.squeeze(-1)
|
||||
|
||||
shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
|
||||
w_int = w_int.reshape(len(w), -1, 32 // width)
|
||||
w_int = w_int * shifts[None, None]
|
||||
packed_w = w_int.sum(-1)
|
||||
|
||||
return packed_w, scales, biases
|
||||
|
||||
|
||||
class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_quantize_dequantize(self):
|
||||
w = mx.random.normal(shape=(128, 128))
|
||||
w_q, scales, biases = quantize(w, 4, 64)
|
||||
w_hat = dequantize(w_q, scales, biases, 4)
|
||||
w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
|
||||
self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)
|
||||
for b in [2, 4, 8]:
|
||||
w_q, scales, biases = mx.quantize(w, 64, b)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, 64, b)
|
||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||
self.assertTrue((errors <= scales[..., None] / 2).all())
|
||||
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
@@ -62,14 +28,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
):
|
||||
x = mx.random.normal(shape=(M, K), key=k1)
|
||||
w = mx.random.normal(shape=(N, K), key=k2)
|
||||
w_q, scales, biases = quantize(w, width, groups)
|
||||
w_hat = dequantize(w_q, scales, biases, width)
|
||||
w_q, scales, biases = mx.quantize(w, groups, width)
|
||||
w_hat = mx.dequantize(
|
||||
w_q, scales, biases, groups, width
|
||||
)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q.T, scales, biases, width=width, groups=groups
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 0.1)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmm_shapes(self):
|
||||
key = mx.random.key(0)
|
||||
@@ -77,8 +45,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
groups = 64
|
||||
width = 4
|
||||
w = mx.random.normal(shape=(32, 128), key=k2)
|
||||
w_q, scales, biases = quantize(w, width, groups)
|
||||
w_hat = dequantize(w_q, scales, biases, width)
|
||||
w_q, scales, biases = mx.quantize(w, groups, width)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, groups, width)
|
||||
for s in [(3, 128), (2, 1, 7, 128)]:
|
||||
x = mx.random.normal(shape=(3, 128), key=k1)
|
||||
y_q = mx.quantized_matmul(
|
||||
@@ -86,7 +54,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 0.1)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmv(self):
|
||||
key = mx.random.key(0)
|
||||
@@ -95,17 +63,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
for width in [2, 4, 8]:
|
||||
for M in [512, 1024]:
|
||||
for N in [512, 1024]:
|
||||
# with self.subTest(shape=(M, N), groups=groups, width=width):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = quantize(w, width, groups)
|
||||
w_hat = dequantize(w_q, scales, biases, width)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q.T, scales, biases, width=width, groups=groups
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 0.1)
|
||||
with self.subTest(shape=(M, N), groups=groups, width=width):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, groups, width)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, groups, width)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q.T, scales, biases, width=width, groups=groups
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user