# Copyright © 2023 Apple Inc.

import unittest
from itertools import product

import mlx.core as mx
import mlx_tests


class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
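        # mx.quantize packs groups of `group_size` values along the last
        # axis with one scale and one bias per group, so each element's
        # round-trip error should be bounded by its group's scale.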
        w = mx.random.normal(shape=(128, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 4, 5, 6, 8]:
                with self.subTest(gs=gs, b=b):
                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                    w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                    errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                    eps = 1e-6
                    self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())

        # test quantize/dequantize 0s
        a = mx.zeros((256, 512))
        for gs in [32, 64, 128]:
            for b in [2, 3, 4, 5, 6, 8]:
                w_q, scales, biases = mx.quantize(a, gs, b)
                a_hat = mx.dequantize(w_q, scales, biases, gs, b)
                self.assertTrue(mx.all(a_hat == 0))

    def test_qmm(self):
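        # quantized_matmul should agree with a dense matmul against the
        # dequantized weights across group sizes, bit widths, and shapes.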
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [8, 32, 33, 64],  # M
            [128, 256],  # N
            [128, 256],  # K
            [True, False],  # transposed
        )
        for group_size, bits, M, N, K, transposed in tests:
            with self.subTest(
                shape=(M, N, K),
                group_size=group_size,
                bits=bits,
                transposed=transposed,
            ):
                x = mx.random.normal(shape=(M, K), key=k1)
                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )
                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qmm_vjp(self):
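        # Since qmm is linear in x, its VJP wrt x is the cotangent pushed
        # through the same weights with the transpose flag flipped.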
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 1024
        K = 512

        x = mx.random.normal(shape=(2, M, K), key=k1)
        c = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

            expected_out = mx.quantized_matmul(
                c, w_q, scales, biases, not transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(vjp_out[0], expected_out))

    def test_qmm_jvp(self):
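        # qmm is linear in x, so its JVP is the same map applied to the
        # tangent: jvp(fn)(x_tan) == fn(x_tan).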
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 128
        K = 128

        x = mx.random.normal(shape=(2, M, K), key=k1)
        # The tangent must match the primal's shape (N == K here).
        x_tan = mx.ones(shape=(2, M, K))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))

            expected_out = mx.quantized_matmul(
                x_tan, w_q, scales, biases, transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(jvp_out[0], expected_out))

    def test_qmm_shapes(self):
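        # Any number of leading batch dimensions on x should be accepted
        # and carried through to the output shape.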
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        group_size = 64
        bits = 4
        w = mx.random.normal(shape=(32, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        w = mx.random.normal(shape=(256, 256), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        for s in [(3, 256), (2, 1, 7, 256)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
            y_hat = x @ w_hat
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qmv(self):
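        # A single row in x exercises the quantized matrix-vector (qmv)
        # path; B == 0 covers unbatched weights with batched inputs.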
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 5, 6, 8],  # bits
            [256, 512, 67],  # M
            [64, 128],  # N
            [0, 1, 3, 8],  # B
        )
        for group_size, bits, M, N, B in tests:
            if group_size > N:
                continue
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
                w_shape = (M, N) if B == 0 else (B, M, N)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, True, group_size, bits
                )
                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm(self):
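        # A single row in x with transpose=False exercises the quantized
        # vector-matrix (qvm) path.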
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 5, 6, 8],  # bits
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                if M < group_size:
                    continue
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm_splitk(self):
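        # A large inner dimension (N = 16384) is intended to hit the
        # split-k variant of the qvm kernel; the inputs are scaled down
        # and the tolerance loosened to absorb the longer accumulation.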
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],  # group_size
            [2, 4, 8],  # bits
            [128],  # M
            [16384],  # N
            [1, 3],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
                w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 2e-3)

    def test_throw(self):
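        # Mismatched shapes between x, w_q, scales, and biases must
        # raise ValueError; a well-formed call must evaluate cleanly.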
        x = mx.random.normal(shape=(10, 512))
        w = mx.random.normal(shape=(32, 512))
        w_q, scales, biases = mx.quantize(w)

        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q.T, scales.T, biases)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales, biases, False)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, w_q, scales.T, biases.T)
        y = mx.quantized_matmul(x, w_q, scales, biases, True)
        mx.eval(y)

    def test_small_matrix(self):
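        # Weights with only 8 output rows are smaller than a typical
        # kernel tile; all four paths (qmv, qmm_t, qvm, qmm) should
        # still match the dequantized reference.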
        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
            with self.subTest(w_shape=w_shape):
                w = mx.random.normal(shape=w_shape)
                w_q, scales, biases = mx.quantize(w)
                w_hat = mx.dequantize(w_q, scales, biases)

                # Test qmv
                for shape in [(3, 1, 256), (3, 4, 256)]:
                    x = mx.random.normal(shape=shape)
                    y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
                    y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                    self.assertEqual(y_q.shape, y_hat.shape)
                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qmm_t
                x = mx.random.normal(shape=(3, 10, 256))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qvm
                x = mx.random.normal(shape=(3, 1, 8))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

                # Test qmm
                x = mx.random.normal(shape=(3, 10, 8))
                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_non_multiples(self):
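        # Sizes that are not multiples of the usual alignments
        # (33- and 3-row weights, then 99 x 256 weights with 129 input
        # rows) should still produce correct results on every path.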
        w = mx.random.normal(shape=(33, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Smaller than 8
        w = mx.random.normal(shape=(3, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test with larger than 128 unaligned sizes
        w = mx.random.normal(shape=(99, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        x = mx.random.normal(shape=(129, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_gather_qmm(self):
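        # gather_qmm should match gather_mm applied to the dequantized
        # weights for various batch shapes and index combinations.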
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def test_shape(
            M,
            N,
            K,
            dtype=mx.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
            transpose=True,
            group_size=64,
            bits=4,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                dtype=dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
                transpose=transpose,
                group_size=group_size,
                bits=bits,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)
                if rhs_indices is not None:
                    rhs_indices = mx.array(rhs_indices)

                c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
                c2 = mx.gather_qmm(
                    x,
                    qw,
                    s,
                    b,
                    lhs_indices,
                    rhs_indices,
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                )

                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
        )

        for kwargs in inputs:
            test_shape(1, 32, 128, **kwargs)
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
            test_shape(1, 256, 32, transpose=False, **kwargs)
            test_shape(32, 32, 512, **kwargs)
            test_shape(1, 32, 512, **kwargs)
            test_shape(32, 512, 32, transpose=False, **kwargs)
            test_shape(1, 512, 32, transpose=False, **kwargs)

    def test_gather_matmul_grad(self):
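        # Gradients wrt x through gather_qmm should match gradients
        # through gather_mm on the dequantized weights.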
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        x = mx.random.normal((4, 2, 32, 256))
        w = mx.random.normal((4, 1, 32, 256))
        w_hat, qw, s, b = quantize(w)

        def f_ref(x, w, i1, i2):
            return mx.gather_mm(x, w, i1, i2).sum()

        def f_test(x, qw, s, b, i1, i2):
            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))

        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

    def test_gather_qmm_sorted(self):
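        # sorted_indices=True lets the kernel assume rhs_indices are in
        # ascending order. gather_sort permutes x and the indices into
        # that order; scatter_unsort inverts the permutation afterwards.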
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def gather_sort(x, indices):
            N, M = indices.shape
            indices = indices.flatten()
            order = mx.argsort(indices)
            inv_order = mx.argsort(order)
            return x.flatten(0, -3)[order // M], indices[order], inv_order

        def scatter_unsort(x, inv_order, shape=None):
            x = x[inv_order]
            if shape is not None:
                x = mx.unflatten(x, 0, shape)
            return x

        parameters = [
            # L, K, D, E, I, transpose
            (128, 1024, 1024, 32, 4, True),
            (128, 1024, 544, 32, 4, True),
            (433, 1024, 1024, 32, 4, True),
            (433, 1024, 555, 32, 4, True),
            (433, 2048, 1024, 32, 4, True),
            (128, 1024, 1024, 32, 4, False),
            (128, 1024, 544, 32, 4, False),
            (433, 1024, 1024, 32, 4, False),
            (433, 1024, 544, 32, 4, False),
            (433, 1024, 555, 32, 4, False),
            (433, 2048, 1024, 32, 4, False),
        ]
        for L, K, D, E, I, transpose in parameters:
            K, D = (K, D) if transpose else (D, K)
            ishape = (L, I)
            xshape = (L, 1, 1, K)
            wshape = (E, D, K) if transpose else (E, K, D)

            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
            x = mx.random.normal(xshape) / K**0.5
            w = mx.random.normal(wshape) / K**0.5
            w, *wq = quantize(w, transpose=transpose)

            y1 = mx.gather_mm(x, w, rhs_indices=indices)
            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
            xs, idx, inv_order = gather_sort(x, indices)
            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
            y4 = mx.gather_qmm(
                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
            )
            y3 = scatter_unsort(y3, inv_order, indices.shape)
            y4 = scatter_unsort(y4, inv_order, indices.shape)

            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))

    def test_vjp_scales_biases(self):
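        # Check the analytic gradients wrt scales and biases against a
        # central finite-difference estimate at a few sample indices.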
        mx.random.seed(0)
        x = mx.random.normal(shape=(2, 2, 512))
        w = mx.random.normal(shape=(512, 512))
        wq, s, b = mx.quantize(w, bits=4, group_size=64)

        def mm(sb, x, wq):
            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()

        params = (s, b)
        dparams = mx.grad(mm)((s, b), x, wq)

        eps = 8e-3
        # numerical grad check with a few indices
        indices = [(0, 0), (11, 4), (22, 7)]
        for idx in indices:
            for p in [0, 1]:
                params[p][idx] += eps
                out_up = mm(params, x, wq)
                params[p][idx] -= 2 * eps
                out_down = mm(params, x, wq)
                params[p][idx] += eps
                num_ds = (out_up - out_down) / (2 * eps)
                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()