mxfp4 works

Author: Awni Hannun (committed by Awni Hannun)
Date: 2025-08-19 07:49:56 -07:00
Commit: 9807ba0267 (parent: 88c71d2b13)
12 changed files with 2420 additions and 257 deletions

python/mlx/nn/layers/quantized.py

@@ -98,11 +98,11 @@ class QuantizedEmbedding(Module):
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
         if mode == "affine":
             self.scales, self.biases = scales_biases
         else:
-            self.scales = scales_biases
+            (self.scales,) = scales_biases

         self.num_embeddings = num_embeddings
         self.dims = dims
@@ -155,12 +155,16 @@ class QuantizedEmbedding(Module):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
ql.weight, *scales_biases = mx.quantize(
embedding_layer.weight,
group_size,
bits,
mode=mode,
)
if mode == "affine":
ql.scales, ql.biases = scales_biases
else:
(ql.scales,) = scales_biases
return ql
@@ -210,11 +214,11 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
         if mode == "affine":
             self.scales, self.biases = scales_biases
         else:
-            self.scales = scales_biases
+            (self.scales,) = scales_biases

         # And bias if needed
         if bias:
@@ -257,7 +261,7 @@ class QuantizedLinear(Module):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, scales_biases = mx.quantize(
ql.weight, *scales_biases = mx.quantize(
linear_layer.weight,
group_size,
bits,
@@ -266,7 +270,7 @@ class QuantizedLinear(Module):
         if mode == "affine":
             ql.scales, ql.biases = scales_biases
         else:
-            ql.scales = scales_biases
+            (ql.scales,) = scales_biases

         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
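
The hunks above all hinge on one detail of `mx.quantize`: with `mode="affine"` it returns packed weights, scales, and biases, while with `mode="mxfp4"` it returns only weights and scales. A minimal sketch of the unpacking pattern, assuming an MLX build that has the `mode` keyword introduced in this commit:

```python
import mlx.core as mx

w = mx.random.normal(shape=(64, 64))

# Affine quantization: three outputs.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

# mxfp4 quantization: no biases, so only two outputs.
w_q4, scales4 = mx.quantize(w, group_size=32, mode="mxfp4")

# Star-unpacking covers both cases with a single call site,
# which is the pattern QuantizedEmbedding and QuantizedLinear use above.
w_q, *scales_biases = mx.quantize(w, group_size=32, mode="mxfp4")
(scales,) = scales_biases
```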

python/tests/test_nn.py

@@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+        nn.quantize(m, group_size=32, mode="mxfp4")
+        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
+        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
+        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
+        self.assertTrue(isinstance(m.layers[2].scales, mx.array))

     def test_quantize_freeze(self):
         lin = nn.Linear(512, 512)
         qlin = lin.to_quantized()
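
For module-level conversion, the new assertions boil down to `nn.quantize` accepting `mode="mxfp4"` and swapping in the quantized layer classes. A rough sketch of that flow with an illustrative model (the layer sizes here are assumptions, not taken from the test):

```python
import mlx.core as mx
import mlx.nn as nn

# Illustrative model; nn.quantize replaces Embedding/Linear layers in place.
model = nn.Sequential(nn.Embedding(100, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(model, group_size=32, mode="mxfp4")

assert isinstance(model.layers[0], nn.QuantizedEmbedding)
assert isinstance(model.layers[2], nn.QuantizedLinear)
assert isinstance(model.layers[2].scales, mx.array)  # scales only; mxfp4 has no biases
```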

python/tests/test_quantized.py

@@ -218,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)

+    def test_mxfp4_qmv(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [256, 512, 67],  # M
+            [64, 128],  # N
+            [0, 1, 3, 8],  # B
+        )
+        for M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=32):
+                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
+                w_shape = (M, N) if B == 0 else (B, M, N)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
+                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    transpose=True,
+                    group_size=32,
+                    mode="mxfp4",
+                )
+                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_qvm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
@@ -283,6 +311,38 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertEqual(y_q.shape, y_hat.shape)
                 self.assertLess((y_q - y_hat).abs().max(), 2e-3)

+    def test_mxfp4_qvm(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [32, 128, 256],  # M
+            [128, 256, 67],  # N
+            [0, 1, 3, 8],  # B
+        )
+        # Add a splitk
+        tests = list(tests)
+        tests.append((128, 16384, 0))
+        for M, N, B in tests:
+            with self.subTest(shape=(B, M, N)):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
+                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
+                y_q = mx.quantized_matmul(
+                    x,
+                    w_q,
+                    scales,
+                    transpose=False,
+                    group_size=32,
+                    mode="mxfp4",
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_mode_error_cases(self):
         w = mx.random.normal(shape=(256, 256))
         x = mx.random.normal(shape=(1, 256))
@@ -475,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_gather_qmm(self):
-        def quantize(w, transpose=True, group_size=64, bits=4):
-            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
-            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+        def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
+            if mode == "affine":
+                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+            else:
+                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+                b = None
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
             if transpose:
                 w_hat = w_hat.swapaxes(-1, -2)
             return w_hat, qw, s, b
@@ -494,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
             transpose=True,
             group_size=64,
             bits=4,
+            mode="affine",
         ):
with self.subTest(
M=M,
@@ -507,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 transpose=transpose,
                 group_size=group_size,
                 bits=bits,
+                mode=mode,
             ):
                 x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                 w = mx.random.normal(
                     shape=batch_B + ((N, K) if transpose else (K, N))
                 ).astype(dtype)
-                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
+                w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)

                 if lhs_indices is not None:
                     lhs_indices = mx.array(lhs_indices)
@@ -530,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
                     transpose=transpose,
                     group_size=group_size,
                     bits=bits,
+                    mode=mode,
                 )

                 self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

         inputs = (
@@ -575,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
"group_size": 32,
"mode": "mxfp4",
},
)
for kwargs in inputs:
@@ -618,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
             self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

     def test_gather_qmm_sorted(self):
-        def quantize(w, transpose=True, group_size=64, bits=4):
-            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
-            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+        def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
+            if mode == "affine":
+                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+            else:
+                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
+                b = None
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
             if transpose:
                 w_hat = w_hat.swapaxes(-1, -2)
             return w_hat, qw, s, b
@@ -640,19 +719,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
         parameters = [
             # L, K, D, E, I, transpose
-            (32, 512, 512, 4, 2, True),
-            (32, 512, 544, 4, 2, True),
-            (133, 512, 512, 4, 2, True),
-            (133, 512, 555, 4, 2, True),
-            (133, 512, 512, 4, 2, True),
-            (64, 512, 512, 4, 2, False),
-            (64, 512, 544, 4, 2, False),
-            (133, 512, 512, 4, 2, False),
-            (133, 512, 544, 4, 2, False),
-            (133, 512, 555, 4, 2, False),
-            (64, 512, 512, 4, 2, False),
+            (32, 512, 512, 4, 2, True, "affine"),
+            (32, 512, 544, 4, 2, True, "mxfp4"),
+            (133, 512, 512, 4, 2, True, "affine"),
+            (133, 512, 555, 4, 2, True, "affine"),
+            (133, 512, 512, 4, 2, True, "affine"),
+            (64, 512, 512, 4, 2, False, "affine"),
+            (64, 512, 544, 4, 2, False, "mxfp4"),
+            (133, 512, 512, 4, 2, False, "affine"),
+            (133, 512, 544, 4, 2, False, "affine"),
+            (133, 512, 555, 4, 2, False, "affine"),
+            (64, 512, 512, 4, 2, False, "affine"),
         ]
-        for L, K, D, E, I, transpose in parameters:
+        for L, K, D, E, I, transpose, mode in parameters:
+            if mode == "mxfp4":
+                group_size = 32
             K, D = (K, D) if transpose else (D, K)
             ishape = (L, I)
             xshape = (L, 1, 1, K)
@@ -661,14 +742,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
             indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
             x = mx.random.normal(xshape) / K**0.5
             w = mx.random.normal(wshape) / K**0.5
-            w, *wq = quantize(w, transpose=transpose)
+            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)

             y1 = mx.gather_mm(x, w, rhs_indices=indices)
-            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
+            y2 = mx.gather_qmm(
+                x,
+                *wq,
+                group_size=group_size,
+                mode=mode,
+                transpose=transpose,
+                rhs_indices=indices
+            )

             xs, idx, inv_order = gather_sort(x, indices)
             y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
             y4 = mx.gather_qmm(
-                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
+                xs,
+                *wq,
+                group_size=group_size,
+                mode=mode,
+                rhs_indices=idx,
+                transpose=transpose,
+                sorted_indices=True
             )
             y3 = scatter_unsort(y3, inv_order, indices.shape)
             y4 = scatter_unsort(y4, inv_order, indices.shape)
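
The gather_qmm changes are the same plumbing: `group_size` and `mode` are threaded through to the quantized gather kernels, with `None` standing in for the biases when the mode has none. A standalone sketch of the mxfp4 path (the expert and token counts here are illustrative assumptions, not values from the tests):

```python
import mlx.core as mx

E, K, D, L = 4, 64, 64, 8  # experts, in-dim, out-dim, tokens (illustrative)
x = mx.random.normal(shape=(L, 1, 1, K)) / K**0.5
w = mx.random.normal(shape=(E, D, K)) / K**0.5
indices = (mx.random.uniform(shape=(L, 1)) * E).astype(mx.uint32)

w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")

# Reference path on dequantized weights vs. the fused quantized gather.
y_ref = mx.gather_mm(x, mx.swapaxes(w_hat, -1, -2), rhs_indices=indices)
y_q = mx.gather_qmm(
    x, w_q, scales, transpose=True, group_size=32, mode="mxfp4", rhs_indices=indices
)
print((y_q - y_ref).abs().max())  # the tests expect close agreement here
```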