mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Improvements in the quantizer and dequantization kernel (#1061)
This commit is contained in:

committed by
GitHub

parent
7f7b9662ea
commit
17f57df797
@@ -205,13 +205,10 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const thread uint16_t* ws = (const thread uint16_t*)w;
|
||||
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
|
||||
result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
|
||||
result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
|
||||
result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
|
||||
U s[2] = {scale, scale / 16.0f};
|
||||
for (int i = 0; i < (values_per_thread / 2); i++) {
|
||||
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
|
||||
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,17 +241,10 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
U s[4] = {
|
||||
scale,
|
||||
scale / static_cast<U>(16.0f),
|
||||
scale / static_cast<U>(256.0f),
|
||||
scale / static_cast<U>(4096.0f)};
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
|
||||
w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
|
||||
w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
|
||||
w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
|
||||
U s[2] = {scale, scale / static_cast<U>(16.0f)};
|
||||
for (int i = 0; i < (N / 2); i++) {
|
||||
w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
|
||||
w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user