From 066336b60e9169357ce798d0abe4e9845a091bde Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 3 Apr 2025 10:56:12 -0700 Subject: [PATCH] load q4_k from gguf --- mlx/io/gguf.cpp | 2 +- mlx/io/gguf_quants.cpp | 72 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index ed6ea11dd..f6d085efa 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -219,7 +219,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) { while (gguf_get_tensor(ctx, &tensor)) { if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 || - tensor.type == GGUF_TYPE_Q8_0) { + tensor.type == GGUF_TYPE_Q4_K || tensor.type == GGUF_TYPE_Q8_0) { gguf_load_quantized(array_map, tensor); } else { std::string name(tensor.name, tensor.namelen); diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 148ed6c47..31be0c244 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -70,6 +70,65 @@ void extract_q4_1_data( } } +// Extracts (weight, scales, biases) from Q4_K tensors. +// Data layout is: +// * |FP16 s_of_scales | + +// * |FP16 s_of_mins | + +// * |16 6 bit integers d,m pairs, one per sub-block of 32 ele | + +// * |256 x 4bit weights| +void extract_q4_k_data( + const gguf_tensor& tensor, + array& weights_arr, + array& scales_arr, + array& biases_arr) { + auto data = static_cast<const uint8_t*>(tensor.weights_data); + auto weights = weights_arr.data<uint8_t>(); + auto scales = scales_arr.data<float16_t>(); + auto biases = biases_arr.data<float16_t>(); + for (int64_t g = 0; g < scales_arr.size() / 8; ++g) { + auto scales_scale = *((float16_t*)data); + auto mins_scale = *((float16_t*)data + 1); + data += 4; + + /* Scale scales/mins. 
*/ + for (int j = 0; j < 8; j++) { + uint8_t d, m; + if (j < 4) { + d = data[j] & 63; + m = data[j + 4] & 63; + } else { + d = (data[j + 4] & 0xF) | ((data[j - 4] >> 6) << 4); + m = (data[j + 4] >> 4) | ((data[j - 0] >> 6) << 4); + } + scales[g * 8 + j] = d * scales_scale; + biases[g * 8 + j] = -(m * mins_scale); + } + data += 12; + for (int i = 0; i < 8; i += 2) { + std::fill_n(weights, 32, 0); + + // First 32 weights are in the lower bits + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[j] & 0x0F); + if (j % 2 != 0) { + x <<= 4; + } + weights[j / 2] += x; + } + // Last 32 weights are in the higher bits + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[j] >> 4); + if (j % 2 != 0) { + x <<= 4; + } + weights[16 + j / 2] += x; + } + weights += 32; + data += 32; + } + } +} + // Extracts (weight, scales, biases) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. void extract_q8_0_data( @@ -100,11 +159,12 @@ void gguf_load_quantized( std::unordered_map<std::string, array>& a, const gguf_tensor& tensor) { - uint64_t weights_per_byte; - if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) { - weights_per_byte = 2; + int bits; + if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 || + tensor.type == GGUF_TYPE_Q4_K) { + bits = 4; } else { // tensor.type == GGUF_TYPE_Q8_0 + bits = 8; - weights_per_byte = 1; } std::string name(tensor.name, tensor.namelen); @@ -119,7 +179,7 @@ void gguf_load_quantized( } auto weights_shape = shape; - weights_shape.back() /= (weights_per_byte * 4); + weights_shape.back() = weights_shape.back() * bits / 32; auto w_nbytes = uint32.size() * std::accumulate(weights_shape.begin(), weights_shape.end(), @@ -139,6 +199,8 @@ void gguf_load_quantized( extract_q4_0_data(tensor, weights, scales, biases); } else if (tensor.type == GGUF_TYPE_Q4_1) { extract_q4_1_data(tensor, weights, scales, biases); + } else if (tensor.type == GGUF_TYPE_Q4_K) { + extract_q4_k_data(tensor, weights, scales, 
biases); } else if (tensor.type == GGUF_TYPE_Q8_0) { extract_q8_0_data(tensor, weights, scales, biases); }