Bug fix in quantize (#1054)

This commit is contained in:
Angelos Katharopoulos 2024-04-29 20:55:04 -07:00 committed by GitHub
parent 09f1777896
commit 8db7161c94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 9 deletions

View File

@ -3299,24 +3299,22 @@ std::tuple<array, array, array> quantize(
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s); reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array delta = maximum( array scales = maximum(
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s), divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
array(1e-7, w.dtype()), array(1e-7, w.dtype()),
s); s);
array scales = squeeze(delta, -1, s);
array biases = squeeze(w_min, -1, s);
// making sure that 0 is represented exactly in the resulting quantization // making sure that 0 is represented exactly in the resulting quantization
biases = multiply(round(divide(biases, scales, s), s), scales, s); array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
// Quantize and pack w // Quantize and pack w
packed_w = packed_w = astype(
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32); round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s); packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
packed_w = sum( packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple(packed_w, scales, biases); return std::make_tuple(
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
} }
array dequantize( array dequantize(

View File

@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat = mx.dequantize(w_q, scales, biases, gs, b) w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1) errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6 eps = 1e-6
self.assertTrue((errors <= (scales[..., None] + eps)).all()) self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
# test quantize/dequantize 0s # test quantize/dequantize 0s
a = mx.zeros((256, 512)) a = mx.zeros((256, 512))