mxfp4 quantize/dequantize + start of optional biases

Awni Hannun
2025-08-18 12:59:03 -07:00
committed by Awni Hannun
parent e04e17e3b6
commit 88c71d2b13
12 changed files with 638 additions and 274 deletions
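
A minimal usage sketch of the new mode, distilled from the tests in this commit: quantizing with mode="mxfp4" returns only (w_q, scales), while the default affine mode keeps its third output, the per-group biases.

    import mlx.core as mx

    # MXFP4: 4-bit elements in groups of 32, one shared scale per group, no bias.
    w = mx.random.normal(shape=(128, 512))
    w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")

    # The default affine mode still returns (w_q, scales, biases).
    w_q, scales, biases = mx.quantize(w, bits=4, group_size=32)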


@@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
        a_hat = mx.dequantize(w_q, scales, biases, gs, b)
        self.assertTrue(mx.all(a_hat == 0))

    def test_mxfp4_quantize_dequantize(self):
        lut = mx.array(
            [
                +0.0,
                +0.5,
                +1.0,
                +1.5,
                +2.0,
                +3.0,
                +4.0,
                +6.0,
                -0.0,
                -0.5,
                -1.0,
                -1.5,
                -2.0,
                -3.0,
                -4.0,
                -6.0,
            ]
        )
        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
        w = w.reshape(-1, 32)
        w[:, 0] = 6
        w = (w + 3e-6).astype(mx.bfloat16)

        # Invalid bits / group size
        with self.assertRaises(ValueError):
            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
        with self.assertRaises(ValueError):
            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")

        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
        with self.assertRaises(ValueError):
            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
        with self.assertRaises(ValueError):
            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")

        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))

        # test quantize/dequantize 0s
        a = mx.zeros((256, 512))
        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
        self.assertTrue(mx.all(w_hat == 0))

    def test_qmm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
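
For intuition about what this test encodes: each 4-bit code indexes the 16-entry lut above (a sign bit plus an e2m1 magnitude), and every 32-element group shares a single scale; setting w[:, 0] = 6 pins each group's maximum at the largest lut value, which presumably makes the round trip exact. Below is a reference dequantization sketch, assuming the returned scales act as plain multiplicative factors and ignoring w_q's bit packing.

    import numpy as np

    # The 16 representable MXFP4 values, matching lut in the test above.
    LUT = np.array([+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
                    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

    def dequantize_group(codes, scale):
        # codes: 32 integers in [0, 16); scale: the group's shared scale factor.
        return scale * LUT[codes]

    codes = np.random.randint(0, 16, size=32)
    w_group = dequantize_group(codes, scale=1.0)  # scale=1.0 spans [-6, 6]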
@@ -233,6 +283,71 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 2e-3)

    def test_mode_error_cases(self):
        w = mx.random.normal(shape=(256, 256))
        x = mx.random.normal(shape=(1, 256))

        # Invalid mode
        with self.assertRaises(ValueError):
            mx.quantize(w, mode="xyz")

        wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
        with self.assertRaises(ValueError):
            mx.quantized_matmul(
                x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
            )

        rhs_indices = mx.array(0)
        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x,
                wq,
                scales,
                biases,
                rhs_indices=rhs_indices,
                bits=4,
                group_size=32,
                mode="xyz",
            )

        # Only quantize floating point types
        with self.assertRaises(ValueError):
            mx.quantize(mx.zeros((128, 128), mx.int32))
        with self.assertRaises(ValueError):
            mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")

        # Must have bias for affine
        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, None, bits=4, group_size=32)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
            )

        # Must be floating point
        x = mx.zeros(shape=(256,), dtype=mx.int32)
        scales = mx.zeros(scales.shape, dtype=mx.int32)
        biases = mx.zeros(scales.shape, dtype=mx.int32)
        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, biases, bits=4, group_size=32)
        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
            )

    def test_throw(self):
        x = mx.random.normal(shape=(10, 512))
        w = mx.random.normal(shape=(32, 512))