Affine quant always in fp32 (#1925)

* do affine quant in fp32

* static cast
This commit is contained in:
Alex Barron
2025-03-05 01:50:19 +00:00
committed by GitHub
parent 3835a428c5
commit fd0d63ba5b
3 changed files with 39 additions and 33 deletions

View File

@@ -543,8 +543,8 @@ void quantize(
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
float n_bins = (1 << bits) - 1;
float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
@@ -554,32 +554,30 @@ void quantize(
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
float w_min = std::numeric_limits<float>::infinity();
float w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
w_max = std::max(w_max, (float)w[w_idx + j]);
w_min = std::min(w_min, (float)w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
float scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
float bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
@@ -590,6 +588,8 @@ void quantize(
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
scales[i] = static_cast<T>(scale);
biases[i] = static_cast<T>(bias);
}
}