Always perform affine quantization in fp32 (#1925)

* Do affine quantization in fp32

* static cast
This commit is contained in:
Alex Barron
2025-03-05 01:50:19 +00:00
committed by GitHub
parent 3835a428c5
commit fd0d63ba5b
3 changed files with 39 additions and 33 deletions

View File

@@ -827,14 +827,17 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto wshape = w.shape();
wshape.back() = -1;
array zero(0, w.dtype());
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array zero(0, float32);
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
array eps(1e-7, float32);
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
w_max = astype(w_max, float32, s);
w_min = astype(w_min, float32, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
@@ -845,6 +848,9 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
scales = astype(scales, w.dtype(), s);
biases = astype(biases, w.dtype(), s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),