mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
5bit quants
This commit is contained in:
@@ -24,7 +24,7 @@ inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
|
||||
|
||||
template <typename T, int bits>
|
||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
assert(bits == 3 || bits == 6);
|
||||
static_assert(bits == 3 || bits == 5 || bits == 6);
|
||||
if (bits == 3) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||
@@ -84,7 +84,7 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
if (bits == 3 || bits == 5 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -141,7 +141,7 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
if (bits == 3 || bits == 5 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
|
||||
@@ -2555,6 +2555,7 @@ template <typename T, const int group_size, const int bits>
|
||||
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
|
||||
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
|
||||
} else if (bits == 5) {
|
||||
w += offset * bytes_per_pack;
|
||||
out[0] = (w[0] & 0x1f) * scale + bias;
|
||||
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
||||
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
||||
|
||||
Reference in New Issue
Block a user