3 and 6 bit quantization (#1613)

* Support 3 and 6 bit quantization
Author: Alex Barron
Date: 2024-11-22 10:22:13 -08:00
Committed by: GitHub
Parent: 0c5eea226b
Commit: c79f6a4a8c
12 changed files with 633 additions and 419 deletions
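For reference, a minimal sketch of how the new bit widths look through the public quantization API. It assumes the existing mx.quantize / mx.dequantize signatures; the array shape, group sizes, and the printed error metric are illustrative choices, not code from this commit.

import mlx.core as mx

# Round-trip a matrix through affine quantization at every supported width,
# including the 3- and 6-bit modes added by this commit (assumed usage).
x = mx.random.uniform(shape=(4, 1024))

for bits in (2, 3, 4, 6, 8):
    for group_size in (32, 64, 128):
        # w_q holds the packed quantized values; scales and biases are per-group.
        w_q, scales, biases = mx.quantize(x, group_size=group_size, bits=bits)
        x_hat = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
        # Reconstruction error should shrink as bits grow; print it for inspection.
        print(bits, group_size, mx.max(mx.abs(x - x_hat)).item())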

@@ -549,18 +549,6 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
 
-    def test_affine_quantize(self):
-        mx.random.seed(7)
-        x = mx.random.uniform(shape=(4, 1024))
-        for bits in (2, 4, 8):
-            for group_size in (32, 64, 128):
-                with self.subTest(bits=bits, group_size=group_size):
-                    w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size)
-                    w_p = mx.fast.affine_quantize(
-                        x, scales, biases, bits=bits, group_size=group_size
-                    )
-                    self.assertTrue(mx.allclose(w, w_p))
-
     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_basic(self):
        mx.random.seed(7)
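The deleted test compared mx.fast.affine_quantize against mx.quantize. A rough sketch of an equivalent sanity check using only the public ops, here exercising the new 3- and 6-bit widths; the shapes, group size, and comparison are assumptions for illustration, not code from this commit.

import mlx.core as mx

# Compare quantized_matmul against a plain matmul with dequantized weights
# (assumed shapes; transpose=True treats w as (out_features, in_features)).
x = mx.random.normal(shape=(2, 512))
w = mx.random.normal(shape=(256, 512))

for bits in (3, 6):
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=bits)
    y_q = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=bits
    )
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=bits)
    y_ref = x @ w_hat.T
    # The two paths should agree up to floating-point accumulation order.
    print(bits, mx.max(mx.abs(y_q - y_ref)).item())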