From f82c7aa9b884715380363849d0b32a0d7e7465c7 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 28 May 2025 11:45:06 -0700
Subject: [PATCH] 5bit quants

---
 mlx/backend/cpu/quantized.cpp         | 6 +++---
 mlx/backend/metal/kernels/quantized.h | 1 +
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index 85d6a7a47..ee8e56cc0 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -24,7 +24,7 @@ inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {

 template <typename T, int bits>
 void extract_bits(const uint8_t* w_in, T* w_out) {
-  assert(bits == 3 || bits == 6);
+  static_assert(bits == 3 || bits == 5 || bits == 6);
   if (bits == 3) {
     w_out[0] = static_cast<T>(w_in[0] & 0x7);
     w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -84,7 +84,7 @@ void _qmm(
         T scale = *scales_local++;
         T bias = *biases_local++;
         for (int ng = 0; ng < packs_in_group; ng++) {
-          if (bits == 3 || bits == 5 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -141,7 +141,7 @@ void _qmm_t(
         T bias = *biases_local++;

         for (int kw = 0; kw < packs_in_group; kw++) {
-          if (bits == 3 || bits == 5 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 7d9feff7f..fea6f1460 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2555,6 +2555,7 @@ template
     out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
     out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
   } else if (bits == 5) {
+    w += offset * bytes_per_pack;
     out[0] = (w[0] & 0x1f) * scale + bias;
     out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
     out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;