Allow arbitrary first dimension in quantization kernels. (#458)

* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
This commit is contained in:
Angelos Katharopoulos
2024-01-16 00:46:21 -08:00
committed by GitHub
parent f44c132f4a
commit c15fe3e61b
6 changed files with 206 additions and 66 deletions

View File

@@ -9,7 +9,7 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 128))
w = mx.random.normal(shape=(128, 512))
for b in [2, 4, 8]:
w_q, scales, biases = mx.quantize(w, 64, b)
w_hat = mx.dequantize(w_q, scales, biases, 64, b)
@@ -131,6 +131,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
y = mx.quantized_matmul(x, w_q, scales, biases, True)
mx.eval(y)
def test_small_matrix(self):
w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
# Test qmv
x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmv
x = mx.random.normal(shape=(1, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(10, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
if __name__ == "__main__":
unittest.main()