mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
Add NF4 quant
This commit is contained in:
@@ -115,18 +115,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
# [2, 4, 8], # bits
|
||||
[4], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
[mx.QuantizationMode.DEFAULT, mx.QuantizationMode.DEFAULT],
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
for group_size, bits, M, N, mode in tests:
|
||||
with self.subTest(
|
||||
shape=(M, N), group_size=group_size, bits=bits, mode=mode
|
||||
):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
w_q = mx.quantize(w, group_size, bits)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, True, group_size, bits
|
||||
x, w_q, scales, biases, True, group_size, bits, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
@@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[4], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
[mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
for group_size, bits, M, N, mode in tests:
|
||||
with self.subTest(
|
||||
shape=(M, N), group_size=group_size, bits=bits, mode=mode
|
||||
):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(N, M), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, False, group_size, bits
|
||||
x, w_q, scales, biases, False, group_size, bits, mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
@@ -171,37 +179,47 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
mx.eval(y)
|
||||
|
||||
def test_small_matrix(self):
|
||||
w = mx.random.normal(shape=(8, 256))
|
||||
w_q, scales, biases = mx.quantize(w)
|
||||
w_hat = mx.dequantize(w_q, scales, biases)
|
||||
for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
|
||||
with self.subTest(mode=mode):
|
||||
w = mx.random.normal(shape=(8, 256))
|
||||
w_q, scales, biases = mx.quantize(w, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, mode=mode)
|
||||
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 256))
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=True, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmm_t
|
||||
x = mx.random.normal(shape=(10, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
# Test qmm_t
|
||||
x = mx.random.normal(shape=(10, 256))
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=True, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 8))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 8))
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=False, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmm
|
||||
x = mx.random.normal(shape=(10, 8))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
# Test qmm
|
||||
x = mx.random.normal(shape=(10, 8))
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=False, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_non_multiples(self):
|
||||
w = mx.random.normal(shape=(33, 256))
|
||||
|
||||
Reference in New Issue
Block a user