This commit is contained in:
Awni Hannun
2025-07-10 14:33:44 -07:00
parent e4a3be4411
commit 73bb93318f

View File

@@ -172,30 +172,30 @@ __global__ void affine_dequantize(
if constexpr (bits == 3) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x7) * scale + bias;
out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;
out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;
out[2] = static_cast<T>(((w[0] & 0xc0) >> 6) + static_cast<T>((w[1] & 0x1) << 2)) * scale + bias;
out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;
out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;
out[5] = static_cast<T>(((w[1] & 0x80) >> 7) + static_cast<T>((w[2] & 0x3) << 1)) * scale + bias;
out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;
} else if constexpr (bits == 5) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x1f) * scale + bias;
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;
out[1] = static_cast<T>(((w[0] & 0xe0) >> 5) + static_cast<T>((w[1] & 0x3) << 3)) * scale + bias;
out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = static_cast<T>(((w[1] & 0x80) >> 7) + static_cast<T>((w[2] & 0xf) << 1)) * scale + bias;
out[4] = static_cast<T>(((w[2] & 0xf0) >> 4) + static_cast<T>((w[3] & 0x1) << 4)) * scale + bias;
out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = static_cast<T>(((w[3] & 0xc0) >> 6) + static_cast<T>((w[4] & 0x7) << 2)) * scale + bias;
out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;
} else if constexpr (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;
out[1] = static_cast<T>(((w[0] >> 6) & 0x03) + static_cast<T>((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = static_cast<T>(((w[1] >> 4) & 0x0f) + static_cast<T>((w[2] & 0x03) << 4)) * scale + bias;
out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
@@ -208,7 +208,7 @@ __global__ void affine_dequantize(
} else if (bits == 8) {
d = val;
}
out[i] = scale * d + bias;
out[i] = scale * static_cast<T>(d) + bias;
}
}
}