mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Bug fix in quantize (#1054)
This commit is contained in:
parent
09f1777896
commit
8db7161c94
14
mlx/ops.cpp
14
mlx/ops.cpp
@@ -3299,24 +3299,22 @@ std::tuple<array, array, array> quantize(
|
|||||||
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
|
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
|
||||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||||
array delta = maximum(
|
array scales = maximum(
|
||||||
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
|
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
|
||||||
array(1e-7, w.dtype()),
|
array(1e-7, w.dtype()),
|
||||||
s);
|
s);
|
||||||
array scales = squeeze(delta, -1, s);
|
|
||||||
array biases = squeeze(w_min, -1, s);
|
|
||||||
|
|
||||||
// making sure that 0 is represented exactly in the resulting quantization
|
// making sure that 0 is represented exactly in the resulting quantization
|
||||||
biases = multiply(round(divide(biases, scales, s), s), scales, s);
|
array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
|
||||||
|
|
||||||
// Quantize and pack w
|
// Quantize and pack w
|
||||||
packed_w =
|
packed_w = astype(
|
||||||
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
|
round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
|
||||||
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
|
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
|
||||||
packed_w = sum(
|
packed_w = sum(
|
||||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||||
|
|
||||||
return std::make_tuple(packed_w, scales, biases);
|
return std::make_tuple(
|
||||||
|
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
|
||||||
}
|
}
|
||||||
|
|
||||||
array dequantize(
|
array dequantize(
|
||||||
|
@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
self.assertTrue((errors <= (scales[..., None] + eps)).all())
|
self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
|
||||||
|
|
||||||
# test quantize/dequantize 0s
|
# test quantize/dequantize 0s
|
||||||
a = mx.zeros((256, 512))
|
a = mx.zeros((256, 512))
|
||||||
|
Loading…
Reference in New Issue
Block a user