Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
This commit is contained in:
Awni Hannun
2025-08-28 06:45:26 -07:00
committed by GitHub
parent 7ef8a6f2d5
commit 70560b6bd5
28 changed files with 3635 additions and 757 deletions

View File

@@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
self.assertTrue(mx.all(a_hat == 0))
def test_mxfp4_quantize_dequantize(self):
    """Round-trip check for mxfp4: values drawn from the fp4 codebook must
    survive quantize -> dequantize, and unsupported bits / group_size
    combinations must raise ValueError."""
    # The 16 representable fp4 values: positive magnitudes, then negatives.
    codebook = mx.array(
        [
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ]
    )
    weights = codebook[mx.random.randint(0, 16, shape=(128, 512))]
    weights = weights.reshape(-1, 32)
    # Force each group's first entry to the largest codebook magnitude
    # (presumably to pin the shared per-group scale — confirm).
    weights[:, 0] = 6
    # Small offset before the bfloat16 cast; NOTE(review): likely breaks
    # exact rounding ties — confirm against the quantizer's rounding rule.
    weights = (weights + 3e-6).astype(mx.bfloat16)

    # mxfp4 only accepts bits=4 with group_size=32.
    with self.assertRaises(ValueError):
        mx.quantize(weights, bits=3, group_size=32, mode="mxfp4")
    with self.assertRaises(ValueError):
        mx.quantize(weights, group_size=64, bits=4, mode="mxfp4")

    w_q, scales = mx.quantize(weights, group_size=32, bits=4, mode="mxfp4")

    # Same invalid combinations must be rejected on the dequantize side.
    with self.assertRaises(ValueError):
        mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
    with self.assertRaises(ValueError):
        mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")

    w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
    self.assertTrue(mx.allclose(weights, w_hat, rtol=1e-5, atol=1e-5))

    # An all-zero matrix must round-trip to exactly zero.
    zeros_in = mx.zeros((256, 512))
    z_q, z_scales = mx.quantize(zeros_in, group_size=32, bits=4, mode="mxfp4")
    z_hat = mx.dequantize(z_q, z_scales, group_size=32, bits=4, mode="mxfp4")
    self.assertTrue(mx.all(z_hat == 0))
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -168,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_mxfp4_qmv(self):
    """mxfp4 quantized matmul (transpose=True, vector activations) should
    agree with a plain matmul against the dequantized weights."""
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    # Cases are (M, N, B): output rows, reduction dim, batch (0 = unbatched w).
    cases = product(
        [256, 512, 67],  # M
        [64, 128],       # N
        [0, 1, 3, 8],    # B
    )
    for M, N, B in cases:
        with self.subTest(shape=(B, M, N), group_size=32):
            if B == 0:
                # Unbatched weights against a small batch of activations.
                x_shape, w_shape = (3, 1, N), (M, N)
            else:
                x_shape, w_shape = (B, 1, N), (B, M, N)
            x = mx.random.normal(shape=x_shape, key=k1)
            w = mx.random.normal(shape=w_shape, key=k2)
            w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
            w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
            y_q = mx.quantized_matmul(
                x,
                w_q,
                scales,
                transpose=True,
                group_size=32,
                mode="mxfp4",
            )
            # Reference result from the dequantized weights.
            y_ref = x @ mx.swapaxes(w_hat, -1, -2)
            self.assertEqual(y_q.shape, y_ref.shape)
            self.assertLess((y_q - y_ref).abs().max(), 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -233,6 +311,103 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_mxfp4_qvm(self):
    """mxfp4 quantized matmul (transpose=False) should agree with a plain
    matmul against the dequantized weights, including a split-k shape."""
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    # Cases are (M, N, B): output cols, reduction dim, batch (0 = unbatched w).
    cases = list(
        product(
            [32, 128, 256],   # M
            [128, 256, 67],   # N
            [0, 1, 3, 8],     # B
        )
    )
    # One very large reduction dim to exercise the split-k path.
    cases.append((128, 16384, 0))
    for M, N, B in cases:
        with self.subTest(shape=(B, M, N)):
            if B == 0:
                x_shape, w_shape = (1, N), (N, M)
            else:
                x_shape, w_shape = (B, 1, N), (B, N, M)
            x = mx.random.normal(shape=x_shape, key=k1)
            w = mx.random.normal(shape=w_shape, key=k2)
            w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
            w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
            y_q = mx.quantized_matmul(
                x,
                w_q,
                scales,
                transpose=False,
                group_size=32,
                mode="mxfp4",
            )
            # Reference result from the dequantized weights.
            y_ref = x @ w_hat
            self.assertEqual(y_q.shape, y_ref.shape)
            self.assertLess((y_q - y_ref).abs().max(), 2e-3)
def test_mode_error_cases(self):
    """Quantization entry points must reject unknown modes, non-float
    inputs, missing affine biases, and integer scales/biases."""
    w = mx.random.normal(shape=(256, 256))
    x = mx.random.normal(shape=(1, 256))

    # An unrecognized mode string is rejected by every entry point.
    with self.assertRaises(ValueError):
        mx.quantize(w, mode="xyz")
    wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
    with self.assertRaises(ValueError):
        mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
    with self.assertRaises(ValueError):
        mx.quantized_matmul(
            x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
        )
    rhs_indices = mx.array(0)
    with self.assertRaises(ValueError):
        mx.gather_qmm(
            x,
            wq,
            scales,
            biases,
            rhs_indices=rhs_indices,
            bits=4,
            group_size=32,
            mode="xyz",
        )

    # Only floating point inputs can be quantized, in either mode.
    for extra in ({}, {"mode": "mxfp4"}):
        with self.assertRaises(ValueError):
            mx.quantize(mx.zeros((128, 128), mx.int32), **extra)

    # Affine mode requires a biases array.
    with self.assertRaises(ValueError):
        mx.dequantize(wq, scales, None, bits=4, group_size=32)
    with self.assertRaises(ValueError):
        mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
    with self.assertRaises(ValueError):
        mx.gather_qmm(
            x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
        )

    # Activations, scales, and biases must all be floating point.
    x = mx.zeros(shape=(256,), dtype=mx.int32)
    scales = mx.zeros(scales.shape, dtype=mx.int32)
    biases = mx.zeros(scales.shape, dtype=mx.int32)
    with self.assertRaises(ValueError):
        mx.dequantize(wq, scales, biases, bits=4, group_size=32)
    with self.assertRaises(ValueError):
        mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
    with self.assertRaises(ValueError):
        mx.gather_qmm(
            x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
        )
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))
@@ -360,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -379,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=True,
group_size=64,
bits=4,
mode="affine",
):
with self.subTest(
M=M,
@@ -392,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
):
x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
w = mx.random.normal(
shape=batch_B + ((N, K) if transpose else (K, N))
).astype(dtype)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)
if lhs_indices is not None:
lhs_indices = mx.array(lhs_indices)
@@ -415,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
inputs = (
@@ -460,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
"group_size": 32,
"mode": "mxfp4",
},
)
for kwargs in inputs:
@@ -503,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -525,19 +719,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [
# L, K, D, E, I, transpose
(32, 512, 512, 4, 2, True),
(32, 512, 544, 4, 2, True),
(133, 512, 512, 4, 2, True),
(133, 512, 555, 4, 2, True),
(133, 512, 512, 4, 2, True),
(64, 512, 512, 4, 2, False),
(64, 512, 544, 4, 2, False),
(133, 512, 512, 4, 2, False),
(133, 512, 544, 4, 2, False),
(133, 512, 555, 4, 2, False),
(64, 512, 512, 4, 2, False),
(32, 512, 512, 4, 2, True, "affine"),
(32, 512, 544, 4, 2, True, "mxfp4"),
(133, 512, 512, 4, 2, True, "affine"),
(133, 512, 555, 4, 2, True, "affine"),
(133, 512, 512, 4, 2, True, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
(64, 512, 544, 4, 2, False, "mxfp4"),
(133, 512, 512, 4, 2, False, "affine"),
(133, 512, 544, 4, 2, False, "affine"),
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]
for L, K, D, E, I, transpose in parameters:
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
@@ -546,14 +744,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)