Improvements in the quantizer and dequantization kernel (#1061)

This commit is contained in:
Angelos Katharopoulos
2024-05-01 18:19:11 -07:00
committed by GitHub
parent 7f7b9662ea
commit 17f57df797
3 changed files with 25 additions and 27 deletions

View File

@@ -205,13 +205,10 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
}
else if (bits == 4) {
const thread uint16_t* ws = (const thread uint16_t*)w;
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
}
@@ -244,17 +241,10 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
U s[4] = {
scale,
scale / static_cast<U>(16.0f),
scale / static_cast<U>(256.0f),
scale / static_cast<U>(4096.0f)};
for (int i = 0; i < (N / 4); i++) {
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
}
}