From 8db7161c94767d324fbf61e8bf23049efe04a6b7 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 29 Apr 2024 20:55:04 -0700 Subject: [PATCH] Bug fix in quantize (#1054) --- mlx/ops.cpp | 14 ++++++-------- python/tests/test_quantized.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8ad7fc425..8c56601e2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3299,24 +3299,22 @@ std::tuple quantize( reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array delta = maximum( + array scales = maximum( divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s), array(1e-7, w.dtype()), s); - array scales = squeeze(delta, -1, s); - array biases = squeeze(w_min, -1, s); - // making sure that 0 is represented exactly in the resulting quantization - biases = multiply(round(divide(biases, scales, s), s), scales, s); + array biases = multiply(round(divide(w_min, scales, s), s), scales, s); // Quantize and pack w - packed_w = - astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32); + packed_w = astype( + round(divide(subtract(packed_w, biases, s), scales, s), s), uint32); packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s); packed_w = sum( multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); - return std::make_tuple(packed_w, scales, biases); + return std::make_tuple( + packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s)); } array dequantize( diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 6e30bac5e..32026c321 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase): w_hat = mx.dequantize(w_q, scales, biases, gs, b) errors = (w - w_hat).abs().reshape(*scales.shape, -1) eps = 1e-6 - self.assertTrue((errors <= (scales[..., None] + eps)).all()) + self.assertTrue((2 * errors <= (scales[..., None] + eps)).all()) # test quantize/dequantize 0s a = mx.zeros((256, 512))