5bit quants

This commit is contained in:
Awni Hannun
2025-05-28 11:45:06 -07:00
parent 9754ea5f63
commit f82c7aa9b8
2 changed files with 4 additions and 3 deletions

View File

@@ -24,7 +24,7 @@ inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
static_assert(bits == 3 || bits == 5 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -84,7 +84,7 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 5 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
@@ -141,7 +141,7 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 5 || bits == 6) {
if constexpr (bits == 3 || bits == 5 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)

View File

@@ -2555,6 +2555,7 @@ template <typename T, const int group_size, const int bits>
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 5) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x1f) * scale + bias;
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;