mxfp4 works

This commit is contained in:
Awni Hannun
2025-08-19 07:49:56 -07:00
committed by Awni Hannun
parent 88c71d2b13
commit 9807ba0267
12 changed files with 2420 additions and 257 deletions

View File

@@ -218,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_mxfp4_qmv(self):
    """Compare mxfp4 ``quantized_matmul`` with ``transpose=True`` to a reference.

    For every (M, N, B) combination a weight is quantized in ``"mxfp4"``
    mode with ``group_size=32``, dequantized again, and the quantized
    matmul result is checked against ``x @ swapaxes(w_hat, -1, -2)``.
    ``B == 0`` selects a 2-D weight of shape ``(M, N)`` paired with an
    input of shape ``(3, 1, N)``; otherwise both carry a leading ``B``.
    """
    x_key, w_key = mx.random.split(mx.random.key(0))

    out_dims = [256, 512, 67]  # M
    in_dims = [64, 128]  # N
    batch_sizes = [0, 1, 3, 8]  # B

    for M, N, B in product(out_dims, in_dims, batch_sizes):
        with self.subTest(shape=(B, M, N), group_size=32):
            if B == 0:
                # 2-D weight, input still carries a leading dimension of 3.
                x_shape, w_shape = (3, 1, N), (M, N)
            else:
                x_shape, w_shape = (B, 1, N), (B, M, N)

            x = mx.random.normal(shape=x_shape, key=x_key)
            w = mx.random.normal(shape=w_shape, key=w_key)

            # mxfp4 quantization returns only (weights, scales).
            w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
            w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")

            y_q = mx.quantized_matmul(
                x,
                w_q,
                scales,
                transpose=True,
                group_size=32,
                mode="mxfp4",
            )
            # Reference result computed from the dequantized weight.
            y_hat = x @ mx.swapaxes(w_hat, -1, -2)

            self.assertEqual(y_q.shape, y_hat.shape)
            max_abs_err = (y_q - y_hat).abs().max()
            self.assertLess(max_abs_err, 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -283,6 +311,38 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_mxfp4_qvm(self):
    """Compare mxfp4 ``quantized_matmul`` with ``transpose=False`` to a reference.

    For every (M, N, B) case a weight of shape ``(N, M)`` (or
    ``(B, N, M)`` when ``B != 0``) is quantized in ``"mxfp4"`` mode with
    ``group_size=32``, dequantized again, and the quantized matmul result
    is checked against ``x @ w_hat``.
    """
    x_key, w_key = mx.random.split(mx.random.key(0))

    cases = [
        *product(
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B
        ),
        # Add a splitk
        (128, 16384, 0),
    ]

    for M, N, B in cases:
        with self.subTest(shape=(B, M, N)):
            if B == 0:
                x_shape, w_shape = (1, N), (N, M)
            else:
                x_shape, w_shape = (B, 1, N), (B, N, M)

            x = mx.random.normal(shape=x_shape, key=x_key)
            w = mx.random.normal(shape=w_shape, key=w_key)

            # mxfp4 quantization returns only (weights, scales).
            w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
            w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")

            y_q = mx.quantized_matmul(
                x,
                w_q,
                scales,
                transpose=False,
                group_size=32,
                mode="mxfp4",
            )
            # Reference result computed from the dequantized weight.
            y_hat = x @ w_hat

            self.assertEqual(y_q.shape, y_hat.shape)
            max_abs_err = (y_q - y_hat).abs().max()
            self.assertLess(max_abs_err, 1e-3)
def test_mode_error_cases(self):
w = mx.random.normal(shape=(256, 256))
x = mx.random.normal(shape=(1, 256))
@@ -475,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -494,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=True,
group_size=64,
bits=4,
mode="affine",
):
with self.subTest(
M=M,
@@ -507,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
):
x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
w = mx.random.normal(
shape=batch_B + ((N, K) if transpose else (K, N))
).astype(dtype)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)
if lhs_indices is not None:
lhs_indices = mx.array(lhs_indices)
@@ -530,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
inputs = (
@@ -575,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
"group_size": 32,
"mode": "mxfp4",
},
)
for kwargs in inputs:
@@ -618,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@@ -640,19 +719,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [
# L, K, D, E, I, transpose
(32, 512, 512, 4, 2, True),
(32, 512, 544, 4, 2, True),
(133, 512, 512, 4, 2, True),
(133, 512, 555, 4, 2, True),
(133, 512, 512, 4, 2, True),
(64, 512, 512, 4, 2, False),
(64, 512, 544, 4, 2, False),
(133, 512, 512, 4, 2, False),
(133, 512, 544, 4, 2, False),
(133, 512, 555, 4, 2, False),
(64, 512, 512, 4, 2, False),
(32, 512, 512, 4, 2, True, "affine"),
(32, 512, 544, 4, 2, True, "mxfp4"),
(133, 512, 512, 4, 2, True, "affine"),
(133, 512, 555, 4, 2, True, "affine"),
(133, 512, 512, 4, 2, True, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
(64, 512, 544, 4, 2, False, "mxfp4"),
(133, 512, 512, 4, 2, False, "affine"),
(133, 512, 544, 4, 2, False, "affine"),
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]
for L, K, D, E, I, transpose in parameters:
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
@@ -661,14 +742,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)