# mlx/python/tests/test_quantized.py
# Copyright © 2023 Apple Inc.

import unittest
from itertools import product

import mlx.core as mx
import mlx_tests


class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
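        """Quantize/dequantize round trip: per-element error is bounded by the
        group scale, and all-zero inputs reconstruct exactly."""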
        w = mx.random.normal(shape=(128, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 5, 6, 4, 8]:
                with self.subTest(gs=gs, b=b):
                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                    w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                    errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                    eps = 1e-6
                    self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())

        # test quantize/dequantize 0s
        a = mx.zeros((256, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 4, 5, 6, 8]:
                w_q, scales, biases = mx.quantize(a, gs, b)
                a_hat = mx.dequantize(w_q, scales, biases, gs, b)
                self.assertTrue(mx.all(a_hat == 0))

    def test_qmm(self):
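        """quantized_matmul matches a matmul against the dequantized weights
        over a grid of shapes, bit widths, and group sizes."""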
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [8, 32, 33, 64],  # M
            [128, 256],  # N
            [128, 256],  # K
            [True, False],  # transposed
        )
        for group_size, bits, M, N, K, transposed in tests:
            with self.subTest(
                shape=(M, N, K),
                group_size=group_size,
                bits=bits,
                transposed=transposed,
            ):
                x = mx.random.normal(shape=(M, K), key=k1)
                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )
                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qmm_vjp(self):
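        """The VJP of quantized_matmul w.r.t. x is a quantized_matmul of the
        cotangent with the transpose flag flipped."""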
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 1024
        K = 512

        x = mx.random.normal(shape=(2, M, K), key=k1)
        c = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

            expected_out = mx.quantized_matmul(
                c, w_q, scales, biases, not transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(vjp_out[0], expected_out))

    def test_qmm_jvp(self):
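        """The JVP of quantized_matmul is the same quantized_matmul applied to
        the tangent."""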
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 128
        K = 128

        x = mx.random.normal(shape=(2, M, K), key=k1)
        x_tan = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))

            expected_out = mx.quantized_matmul(
                x_tan, w_q, scales, biases, transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(jvp_out[0], expected_out))

    def test_qmm_shapes(self):
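        """quantized_matmul accepts activations with arbitrary leading batch
        dimensions, with and without transposed weights."""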
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        group_size = 64
        bits = 4

        w = mx.random.normal(shape=(32, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        w = mx.random.normal(shape=(256, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
            y_hat = x @ w_hat
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qmv(self):
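        """Quantized matrix-vector products (transposed weights, single-row x)
        match the dequantized reference, including batched weights."""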
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 5, 6, 8],  # bits
            [256, 512, 67],  # M
            [64, 128],  # N
            [0, 1, 3, 8],  # B
        )
        for group_size, bits, M, N, B in tests:
            if group_size > N:
                continue
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
                w_shape = (M, N) if B == 0 else (B, M, N)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, True, group_size, bits
                )
                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm(self):
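        """Quantized vector-matrix products (non-transposed weights) match the
        dequantized reference."""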
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 5, 6, 8],  # bits
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                if M < group_size:
                    continue
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm_splitk(self):
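        """qvm with a very large contraction dimension (N = 16384), which
        exercises the split-k path."""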
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [128],  # M
            [16384],  # N
            [1, 3],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
                w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 2e-3)

    def test_throw(self):
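        """Shape-mismatched quantized weights, scales, or biases raise
        ValueError, while a consistent call succeeds."""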
        x = mx.random.normal(shape=(10, 512))
        w = mx.random.normal(shape=(32, 512))
        w_q, scales, biases = mx.quantize(w)

        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales.T, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales, biases, False)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales.T, biases.T)

        y = mx.quantized_matmul(x, w_q, scales, biases, True)
        mx.eval(y)

    def test_small_matrix(self):
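        """Weights with only 8 output rows work for qmv, qmm_t, qvm, and qmm."""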
        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
            with self.subTest(w_shape=w_shape):
                w = mx.random.normal(shape=w_shape)
                w_q, scales, biases = mx.quantize(w)
                w_hat = mx.dequantize(w_q, scales, biases)

                # Test qmv
                for shape in [(3, 1, 256), (3, 4, 256)]:
                    x = mx.random.normal(shape=shape)
                    y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
                    y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                    self.assertEqual(y_q.shape, y_hat.shape)
                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qmm_t
                x = mx.random.normal(shape=(3, 10, 256))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qvm
                x = mx.random.normal(shape=(3, 1, 8))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qmm
                x = mx.random.normal(shape=(3, 10, 8))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_non_multiples(self):
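        """Weight sizes that are not multiples of the kernel tile sizes
        (33, 3, and 99 rows) still produce correct results."""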
        w = mx.random.normal(shape=(33, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Smaller than 8
        w = mx.random.normal(shape=(3, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test with larger than 128 unaligned sizes
        w = mx.random.normal(shape=(99, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        x = mx.random.normal(shape=(129, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_gather_qmm(self):
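        """gather_qmm agrees with gather_mm on the dequantized weights for
        various batch shapes and index combinations."""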
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def test_shape(
            M,
            N,
            K,
            dtype=mx.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
            transpose=True,
            group_size=64,
            bits=4,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                dtype=dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
                transpose=transpose,
                group_size=group_size,
                bits=bits,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)
                if rhs_indices is not None:
                    rhs_indices = mx.array(rhs_indices)

                c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
                c2 = mx.gather_qmm(
                    x,
                    qw,
                    s,
                    b,
                    lhs_indices,
                    rhs_indices,
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                )

                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
        )

        for kwargs in inputs:
            test_shape(1, 32, 128, **kwargs)
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
            test_shape(1, 256, 32, transpose=False, **kwargs)
            test_shape(32, 32, 512, **kwargs)
            test_shape(1, 32, 512, **kwargs)
            test_shape(32, 512, 32, transpose=False, **kwargs)
            test_shape(1, 512, 32, transpose=False, **kwargs)

    def test_gather_matmul_grad(self):
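        """Values and gradients of gather_qmm match gather_mm on the
        dequantized weights."""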
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        x = mx.random.normal((4, 2, 32, 256))
        w = mx.random.normal((4, 1, 32, 256))
        w_hat, qw, s, b = quantize(w)

        def f_ref(x, w, i1, i2):
            return mx.gather_mm(x, w, i1, i2).sum()

        def f_test(x, qw, s, b, i1, i2):
            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))

        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

    def test_gather_qmm_sorted(self):
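        """With pre-sorted indices, sorted_indices=True reproduces the
        unsorted gather_mm/gather_qmm results after unsorting the output."""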
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def gather_sort(x, indices):
            N, M = indices.shape
            indices = indices.flatten()
            order = mx.argsort(indices)
            inv_order = mx.argsort(order)
            return x.flatten(0, -3)[order // M], indices[order], inv_order

        def scatter_unsort(x, inv_order, shape=None):
            x = x[inv_order]
            if shape is not None:
                x = mx.unflatten(x, 0, shape)
            return x

        parameters = [
            # L, K, D, E, I, transpose
            (128, 1024, 1024, 32, 4, True),
            (128, 1024, 544, 32, 4, True),
            (433, 1024, 1024, 32, 4, True),
            (433, 1024, 555, 32, 4, True),
            (433, 2048, 1024, 32, 4, True),
            (128, 1024, 1024, 32, 4, False),
            (128, 1024, 544, 32, 4, False),
            (433, 1024, 1024, 32, 4, False),
            (433, 1024, 544, 32, 4, False),
            (433, 1024, 555, 32, 4, False),
            (433, 2048, 1024, 32, 4, False),
        ]
        for L, K, D, E, I, transpose in parameters:
            K, D = (K, D) if transpose else (D, K)
            ishape = (L, I)
            xshape = (L, 1, 1, K)
            wshape = (E, D, K) if transpose else (E, K, D)

            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
            x = mx.random.normal(xshape) / K**0.5
            w = mx.random.normal(wshape) / K**0.5
            w, *wq = quantize(w, transpose=transpose)

            y1 = mx.gather_mm(x, w, rhs_indices=indices)
            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)

            xs, idx, inv_order = gather_sort(x, indices)
            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
            y4 = mx.gather_qmm(
                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
            )
            y3 = scatter_unsort(y3, inv_order, indices.shape)
            y4 = scatter_unsort(y4, inv_order, indices.shape)

            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))

    def test_vjp_scales_biases(self):
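        """Gradients w.r.t. scales and biases agree with central finite
        differences at a few sampled indices."""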
        mx.random.seed(0)
        x = mx.random.normal(shape=(2, 2, 512))
        w = mx.random.normal(shape=(512, 512))
        wq, s, b = mx.quantize(w, bits=4, group_size=64)

        def mm(sb, x, wq):
            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()

        params = (s, b)
        dparams = mx.grad(mm)((s, b), x, wq)

        eps = 8e-3
        # numerical grad check with a few indices
        indices = [(0, 0), (11, 4), (22, 7)]
        for idx in indices:
            for p in [0, 1]:
                params[p][idx] += eps
                out_up = mm(params, x, wq)
                params[p][idx] -= 2 * eps
                out_down = mm(params, x, wq)
                params[p][idx] += eps
                num_ds = (out_up - out_down) / (2 * eps)
                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()