fast cuda kernel for mx/nv quantization

This commit is contained in:
Awni Hannun
2025-10-21 11:49:58 -07:00
parent c00ccf7404
commit c961a3a557
9 changed files with 492 additions and 161 deletions

View File

@@ -61,13 +61,18 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
with self.assertRaises(ValueError):
mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
# Invalid output type
with self.assertRaises(ValueError):
mx.dequantize(
w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
)
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))