* fix gguf

* comment
Awni Hannun 2024-07-18 07:35:35 -07:00 committed by GitHub
parent 2f83d6e4b7
commit df124e018a
2 changed files with 13 additions and 6 deletions


@@ -217,13 +217,11 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
};
while (gguf_get_tensor(ctx, &tensor)) {
std::string name(tensor.name, tensor.namelen);
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
tensor.type == GGUF_TYPE_Q8_0) {
gguf_load_quantized(array_map, tensor);
} else {
std::string name = std::string(tensor.name, tensor.namelen);
std::string name(tensor.name, tensor.namelen);
const auto& [data, dtype] = extract_tensor_data(&tensor);
array loaded_array = array(data, get_shape(tensor), dtype);
check_insert(array_map.insert({name, loaded_array}));
@@ -233,6 +231,15 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
}
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
bool exists;
{
std::ifstream f(file.c_str());
exists = f.good();
}
if (!exists) {
throw std::invalid_argument("[load_gguf] Failed to open " + file);
}
gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
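The hunk above makes a missing file fail fast: load_gguf now probes the path with an ifstream and throws std::invalid_argument naming the file, instead of falling through to gguf_open and reporting a generic gguf_init failure. A minimal, self-contained sketch of the same pattern, assuming only the standard library (the helper name open_gguf_checked is illustrative, not part of mlx):

#include <fstream>
#include <stdexcept>
#include <string>

void open_gguf_checked(const std::string& file) {
  bool exists;
  {
    // Scoped so the probe handle is closed again before the real
    // parser reopens the file.
    std::ifstream f(file.c_str());
    exists = f.good();  // true only if the stream opened successfully
  }
  if (!exists) {
    throw std::invalid_argument("Failed to open " + file);
  }
  // ... hand `file` to the actual parser (gguf_open in the diff above) ...
}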


@@ -9,7 +9,8 @@
namespace mlx::core {
void unpack_32_4(uint8_t* data, int8_t* dst) {
for (int64_t j = 0; j < 16; ++j) {
std::fill_n(dst, 16, 0);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
@@ -17,7 +18,7 @@ void unpack_32_4(uint8_t* data, int8_t* dst) {
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int64_t j = 0; j < 16; ++j) {
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
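unpack_32_4 repacks one Q4_0 block (2 scale bytes followed by 16 bytes holding 32 4-bit weights) by accumulating nibbles into dst with +=, so the std::fill_n(dst, 16, 0) added above is the functional fix: a reused, non-zero output buffer would otherwise corrupt the unpacked weights (the int64_t to int loop-index change is just cleanup). A standalone sketch of that logic, assuming the second loop stores its high nibbles at dst[j / 2 + 8] (that line is cut off in the hunk above):

#include <algorithm>
#include <cstdint>

// Sketch only: repack the 32 4-bit weights of one Q4_0 block into 16
// output bytes -- low nibbles of the source fill dst[0..7], high
// nibbles fill dst[8..15], two weights per output byte.
void unpack_block_sketch(const uint8_t* data, int8_t* dst) {
  std::fill_n(dst, 16, 0);  // dst is built up with +=, so it must start zeroed
  for (int j = 0; j < 16; ++j) {
    uint8_t lo = data[j + 2] & 0x0F;  // skip the 2 scale bytes
    uint8_t hi = data[j + 2] >> 4;
    if (j % 2 != 0) {  // odd j lands in the upper nibble of the output byte
      lo <<= 4;
      hi <<= 4;
    }
    dst[j / 2] += lo;
    dst[j / 2 + 8] += hi;  // assumed offset, see note above
  }
}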
@@ -134,7 +135,6 @@ void gguf_load_quantized(
array scales(allocator::malloc(sb_nbytes), shape, float16);
array biases(allocator::malloc(sb_nbytes), std::move(shape), float16);
if (tensor.type == GGUF_TYPE_Q4_0) {
extract_q4_0_data(tensor, weights, scales, biases);
} else if (tensor.type == GGUF_TYPE_Q4_1) {
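For reference when reading the scales/biases allocation above: the three quantized GGUF types the loader handles use standard ggml block layouts, each covering 32 weights per block. A small sketch of the sizes involved (the helper name and its exact form are illustrative, not mlx's code):

#include <cstdint>

// Standard ggml/GGUF block layouts, 32 weights per block:
//   Q4_0: fp16 scale + 16 bytes of packed 4-bit weights            = 18 bytes
//   Q4_1: fp16 scale + fp16 min + 16 bytes of packed 4-bit weights = 20 bytes
//   Q8_0: fp16 scale + 32 int8 weights                             = 34 bytes
constexpr int64_t weights_per_block = 32;

// One float16 scale (and one float16 bias) is kept per block, so each of
// the scales/biases buffers needs 2 bytes per block -- the kind of
// quantity sb_nbytes above stands for.
int64_t scale_or_bias_bytes(int64_t num_weights) {
  const int64_t num_blocks = num_weights / weights_per_block;
  return num_blocks * 2;  // 2 bytes per float16 element
}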