This commit is contained in:
Alex Barron
2024-12-02 16:19:29 -08:00
parent 9d40e521d7
commit 890fdd1ef0
4 changed files with 59 additions and 26 deletions

View File

@@ -10,7 +10,7 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
for gs in [32, 64, 128]:
for gs in [16, 32, 64, 128]:
for b in [2, 3, 6, 4, 8]:
with self.subTest(gs=gs, b=b):
w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
@@ -115,7 +115,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[128, 64, 32, 16], # group_size
[2, 3, 4, 6, 8], # bits
[512, 1024, 67], # M
[64, 128, 512, 1024], # N
@@ -205,39 +205,64 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.eval(y)
def test_small_matrix(self):
for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
# We are going to need some way of doing this when we're loading a block scale / bias
# For 6 bit scales/biases we'll have to load them from uint16s I guess?
bits = 8
group_size = 16
# for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
for w_shape in [(32, 4096)]:
with self.subTest(w_shape=w_shape):
w = mx.random.normal(shape=(w_shape))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
w_q, scales, biases = mx.quantize(w, bits=bits, group_size=group_size)
w_hat = mx.dequantize(
w_q, scales, biases, bits=bits, group_size=group_size
)
# Test qmv
x = mx.random.normal(shape=(3, 1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
x = mx.random.normal(shape=(3, 1, 4096))
y_q = mx.quantized_matmul(
x,
w_q,
scales,
biases,
transpose=True,
bits=bits,
group_size=group_size,
)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(3, 10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
x = mx.random.normal(shape=(3, 10, 4096))
y_q = mx.quantized_matmul(
x,
w_q,
scales,
biases,
transpose=True,
bits=bits,
group_size=group_size,
)
print("y_q", y_q)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
print("y_hat", y_hat)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qvm
x = mx.random.normal(shape=(3, 1, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# x = mx.random.normal(shape=(3, 1, 8))
# y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False, bits=bits, group_size=group_size)
# y_hat = x @ w_hat
# self.assertEqual(y_q.shape, y_hat.shape)
# self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(3, 10, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# x = mx.random.normal(shape=(3, 10, 8))
# y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False, bits=bits, group_size=group_size)
# y_hat = x @ w_hat
# self.assertEqual(y_q.shape, y_hat.shape)
# self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_non_multiples(self):
w = mx.random.normal(shape=(33, 256))