Batched Quantized Matmul + Fast Small QMV (#1503)

* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
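
This commit extends mx.quantized_matmul to accept batched (3-D) quantized weights, which is exactly what the test changes below exercise. A minimal usage sketch, not taken from the diff itself, assuming mlx.core is importable as mx; shapes are chosen to satisfy the default group_size of 64:

import mlx.core as mx

# A batch of 3 independent (M, N) = (64, 128) weight matrices.
w = mx.random.normal(shape=(3, 64, 128))
w_q, scales, biases = mx.quantize(w)  # defaults: group_size=64, bits=4

# One activation row vector per batch element.
x = mx.random.normal(shape=(3, 1, 128))

# With transpose=True this computes x @ w.T per batch element.
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (3, 1, 64)
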
Author:    Alex Barron
Date:      2024-10-21 16:23:17 -07:00
Committed: GitHub
Parent:    58a855682c
Commit:    d15fa13daf

9 changed files with 866 additions and 761 deletions


@@ -117,19 +117,24 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
-            [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # M
+            [64, 128, 512, 1024],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(M, N), key=k2)
+        for group_size, bits, M, N, B in tests:
+            if group_size > N:
+                continue
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
+                w_shape = (M, N) if B == 0 else (B, M, N)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
                     x, w_q, scales, biases, True, group_size, bits
                 )
-                y_hat = x @ w_hat.T
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)
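
The B == 0 branch above checks broadcasting: a 2-D quantized weight shared across a 3-D activation batch. A short sketch of that case, validated against the dequantized reference the same way the test does (mx.allclose here stands in for the test's max-error assertion):

import mlx.core as mx

# Unbatched weights, broadcast over a batch of 3 activations.
w = mx.random.normal(shape=(64, 128))
w_q, scales, biases = mx.quantize(w)
x = mx.random.normal(shape=(3, 1, 128))

y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)

# Reference result via dequantization and a last-two-axes transpose.
w_hat = mx.dequantize(w_q, scales, biases)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
print(y_q.shape, mx.allclose(y_q, y_hat, atol=1e-3).item())
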
@@ -140,12 +145,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
-            [512, 1024],  # N
+            [512, 1024, 67],  # N
+            [0, 1, 3, 8],  # B
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
-                x = mx.random.normal(shape=(1, N), key=k1)
-                w = mx.random.normal(shape=(N, M), key=k2)
+        for group_size, bits, M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
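
The hunk above covers the non-transposed path (qvm, and qmm for wider x), where w is stored as (N, M) and the product is x @ w with no transpose. A sketch of the batched case under the same assumed defaults as before:

import mlx.core as mx

# Non-transposed layout: each (N, M) = (128, 64) matrix maps 128 inputs to 64 outputs.
w = mx.random.normal(shape=(3, 128, 64))
w_q, scales, biases = mx.quantize(w)
x = mx.random.normal(shape=(3, 1, 128))

# transpose=False computes x @ w per batch element.
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
print(y.shape)  # (3, 1, 64)
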
@@ -172,37 +180,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.eval(y)
 
     def test_small_matrix(self):
-        w = mx.random.normal(shape=(8, 256))
-        w_q, scales, biases = mx.quantize(w)
-        w_hat = mx.dequantize(w_q, scales, biases)
+        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
+            with self.subTest(w_shape=w_shape):
+                w = mx.random.normal(shape=(w_shape))
+                w_q, scales, biases = mx.quantize(w)
+                w_hat = mx.dequantize(w_q, scales, biases)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(3, 1, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm_t
-        x = mx.random.normal(shape=(10, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm_t
+                x = mx.random.normal(shape=(3, 10, 256))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qvm
+                x = mx.random.normal(shape=(3, 1, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm
-        x = mx.random.normal(shape=(10, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm
+                x = mx.random.normal(shape=(3, 10, 8))
+                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))
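
The "fast qmv for small dims" item in the commit message targets weights with a small output dimension, like the (8, 256) matrices in test_small_matrix above. A minimal sketch of the shapes that hit this regime; which kernel is actually dispatched is an internal detail of the commit:

import mlx.core as mx

# Only 8 output rows: the small-dimension qmv case.
w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w)

x = mx.random.normal(shape=(1, 256))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (1, 8)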