Bug fix in quantize (#1054)

This commit is contained in:
Angelos Katharopoulos
2024-04-29 20:55:04 -07:00
committed by GitHub
parent 09f1777896
commit 8db7161c94
2 changed files with 7 additions and 9 deletions

View File

@@ -3299,24 +3299,22 @@ std::tuple<array, array, array> quantize(
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array delta = maximum(
array scales = maximum(
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
array(1e-7, w.dtype()),
s);
array scales = squeeze(delta, -1, s);
array biases = squeeze(w_min, -1, s);
// Round the biases to a multiple of the scales, making sure that 0 is
// represented exactly in the resulting quantization
biases = multiply(round(divide(biases, scales, s), s), scales, s);
array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
// Quantize and pack w
packed_w =
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
packed_w = astype(
round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple(packed_w, scales, biases);
return std::make_tuple(
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
}
array dequantize(