# Copyright © 2023 Apple Inc.

import unittest
from itertools import product

import mlx.core as mx
import mlx_tests


class TestQuantized(mlx_tests.MLXTestCase):
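    # Round-tripping through quantize/dequantize should leave every element
    # within one quantization step, i.e. within its group's scale.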
    def test_quantize_dequantize(self):
        w = mx.random.normal(shape=(128, 512))
        for gs in [32, 64, 128]:
            for b in [2, 4, 8]:
                with self.subTest(gs=gs, b=b):
                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                    w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                    errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                    eps = 1e-6
                    self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())

        # test quantize/dequantize 0s
        a = mx.zeros((256, 512))
        for gs in [32, 64, 128]:
            for b in [2, 4, 8]:
                w_q, scales, biases = mx.quantize(a, gs, b)
                a_hat = mx.dequantize(w_q, scales, biases, gs, b)
                self.assertTrue(mx.all(a_hat == 0))
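
    # quantized_matmul should match a dense matmul against the dequantized
    # weights over a grid of group sizes, bit widths, shapes, and transposes.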
    def test_qmm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [8, 32, 33, 64],  # M
            [512, 1024],  # N
            [512, 1024],  # K
            [True, False],  # transposed
        )
        for group_size, bits, M, N, K, transposed in tests:
            with self.subTest(
                shape=(M, N, K),
                group_size=group_size,
                bits=bits,
                transposed=transposed,
            ):
                x = mx.random.normal(shape=(M, K), key=k1)
                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )
                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
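
    # The VJP of quantized_matmul with respect to x is itself a
    # quantized_matmul of the cotangent with the transpose flag flipped.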
    def test_qmm_vjp(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 1024
        K = 512

        x = mx.random.normal(shape=(2, M, K), key=k1)
        c = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

            expected_out = mx.quantized_matmul(
                c, w_q, scales, biases, not transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(vjp_out[0], expected_out))
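
    # Inputs with extra leading (batch) dimensions should go through
    # quantized_matmul exactly like the equivalent dense matmul.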
    def test_qmm_shapes(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        group_size = 64
        bits = 4

        w = mx.random.normal(shape=(32, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        w = mx.random.normal(shape=(256, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
            y_hat = x @ w_hat
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
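
    # Quantized matrix-vector product: a single-row x against transposed weights.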
    def test_qmv(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [512, 1024],  # M
            [512, 1024],  # N
        )
        for group_size, bits, M, N in tests:
            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
                x = mx.random.normal(shape=(1, N), key=k1)
                w = mx.random.normal(shape=(M, N), key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, True, group_size, bits
                )
                y_hat = x @ w_hat.T
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
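
    # Quantized vector-matrix product: a single-row x against non-transposed weights.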
    def test_qvm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [512, 1024],  # M
            [512, 1024],  # N
        )
        for group_size, bits, M, N in tests:
            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
                x = mx.random.normal(shape=(1, N), key=k1)
                w = mx.random.normal(shape=(N, M), key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
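
    # Mismatched weight/scale/bias layouts should raise ValueError, while a
    # correctly laid-out call still evaluates.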
    def test_throw(self):
        x = mx.random.normal(shape=(10, 512))
        w = mx.random.normal(shape=(32, 512))
        w_q, scales, biases = mx.quantize(w)

        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales.T, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales, biases, False)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales.T, biases.T)
        y = mx.quantized_matmul(x, w_q, scales, biases, True)
        mx.eval(y)
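
    # Weight matrices with very few rows should still give correct results
    # for all four variants (qmv, qmm_t, qvm, qmm).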
    def test_small_matrix(self):
        w = mx.random.normal(shape=(8, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 8))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 8))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
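
    # Sizes that are not multiples of the usual alignment (33, 3, and 99
    # rows; 129 batch rows) should still produce correct results.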
    def test_non_multiples(self):
        w = mx.random.normal(shape=(33, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Smaller than 8
        w = mx.random.normal(shape=(3, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test with larger than 128 unaligned sizes
        w = mx.random.normal(shape=(99, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        x = mx.random.normal(shape=(129, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
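
    # gather_qmm with lhs/rhs index arrays should match gather_mm applied to
    # the dequantized weights across a range of batch and indexing patterns.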
    def test_gather_qmm(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def test_shape(
            M,
            N,
            K,
            dtype=mx.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
            transpose=True,
            group_size=64,
            bits=4,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                dtype=dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
                transpose=transpose,
                group_size=group_size,
                bits=bits,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)
                if rhs_indices is not None:
                    rhs_indices = mx.array(rhs_indices)

                c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
                c2 = mx.gather_qmm(
                    x,
                    qw,
                    s,
                    b,
                    lhs_indices,
                    rhs_indices,
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                )

                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
        )

        for kwargs in inputs:
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
            test_shape(1, 256, 32, transpose=False, **kwargs)
            test_shape(32, 32, 512, **kwargs)
            test_shape(1, 32, 512, **kwargs)
            test_shape(32, 512, 32, transpose=False, **kwargs)
            test_shape(1, 512, 32, transpose=False, **kwargs)
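
    # Values and gradients of gather_qmm should match those of gather_mm
    # applied to the dequantized weights.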
    def test_gather_matmul_grad(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        x = mx.random.normal((4, 2, 32, 256))
        w = mx.random.normal((4, 1, 32, 256))
        w_hat, qw, s, b = quantize(w)

        def f_ref(x, w, i1, i2):
            return mx.gather_mm(x, w, i1, i2).sum()

        def f_test(x, qw, s, b, i1, i2):
            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))

        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))


if __name__ == "__main__":
    unittest.main()