metal kernels

This commit is contained in:
Awni Hannun
2025-10-24 08:47:24 -07:00
parent 6959732915
commit 6286e528e4
13 changed files with 319 additions and 100 deletions

View File

@@ -92,7 +92,6 @@ class TestQuantized(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
with self.assertRaises(ValueError):
@@ -102,7 +101,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))
# test quantize/dequantize 0s
a = mx.zeros((256, 512))