mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Bug fix in quantize (#1054)
This commit is contained in:
parent
09f1777896
commit
8db7161c94
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -3299,24 +3299,22 @@ std::tuple<array, array, array> quantize(
|
||||
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array delta = maximum(
|
||||
array scales = maximum(
|
||||
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
|
||||
array(1e-7, w.dtype()),
|
||||
s);
|
||||
array scales = squeeze(delta, -1, s);
|
||||
array biases = squeeze(w_min, -1, s);
|
||||
|
||||
// making sure that 0 is represented exactly in the resulting quantization
|
||||
biases = multiply(round(divide(biases, scales, s), s), scales, s);
|
||||
array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
|
||||
|
||||
// Quantize and pack w
|
||||
packed_w =
|
||||
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
|
||||
packed_w = astype(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
|
||||
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
|
||||
return std::make_tuple(packed_w, scales, biases);
|
||||
return std::make_tuple(
|
||||
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
|
@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||
eps = 1e-6
|
||||
self.assertTrue((errors <= (scales[..., None] + eps)).all())
|
||||
self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
|
||||
|
||||
# test quantize/dequantize 0s
|
||||
a = mx.zeros((256, 512))
|
||||
|
Loading…
Reference in New Issue
Block a user