mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
5bit quants
This commit is contained in:
@@ -24,7 +24,7 @@ inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
|
|||||||
|
|
||||||
template <typename T, int bits>
|
template <typename T, int bits>
|
||||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||||
assert(bits == 3 || bits == 6);
|
static_assert(bits == 3 || bits == 5 || bits == 6);
|
||||||
if (bits == 3) {
|
if (bits == 3) {
|
||||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||||
@@ -84,7 +84,7 @@ void _qmm(
|
|||||||
T scale = *scales_local++;
|
T scale = *scales_local++;
|
||||||
T bias = *biases_local++;
|
T bias = *biases_local++;
|
||||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||||
if (bits == 3 || bits == 5 || bits == 6) {
|
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||||
T wl[pack_factor];
|
T wl[pack_factor];
|
||||||
extract_bits<T, bits>(w_local, wl);
|
extract_bits<T, bits>(w_local, wl);
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
@@ -141,7 +141,7 @@ void _qmm_t(
|
|||||||
T bias = *biases_local++;
|
T bias = *biases_local++;
|
||||||
|
|
||||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||||
if (bits == 3 || bits == 5 || bits == 6) {
|
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||||
T wl[pack_factor];
|
T wl[pack_factor];
|
||||||
extract_bits<T, bits>(w_local, wl);
|
extract_bits<T, bits>(w_local, wl);
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
|
|||||||
@@ -2555,6 +2555,7 @@ template <typename T, const int group_size, const int bits>
|
|||||||
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
|
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
|
||||||
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
|
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
|
||||||
} else if (bits == 5) {
|
} else if (bits == 5) {
|
||||||
|
w += offset * bytes_per_pack;
|
||||||
out[0] = (w[0] & 0x1f) * scale + bias;
|
out[0] = (w[0] & 0x1f) * scale + bias;
|
||||||
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
|
||||||
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
|
||||||
|
|||||||
Reference in New Issue
Block a user