GGUF support (#350)
* Initial GGUF support for tensor fields.

Co-authored-by: Awni Hannun <awni@apple.com>
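Below is a minimal round-trip sketch of the two entry points this commit adds (load_gguf and save_gguf, declared in the ops.h hunk at the end of the diff). It is an orientation aid rather than code from the commit; the file name and array contents are illustrative, and it assumes the umbrella mlx/mlx.h header:

// Hypothetical usage of the new GGUF entry points (not part of this diff).
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  std::unordered_map<std::string, array> weights = {
      {"w", ones({4, 4}, float32)}};
  save_gguf("weights", weights); // ".gguf" is appended when missing
  auto loaded = load_gguf("weights.gguf");
  // Extra parentheses keep the comma out of the assert macro's argument list.
  assert((loaded.at("w").shape() == std::vector<int>{4, 4}));
  return 0;
}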
mlx/io/CMakeLists.txt
@@ -3,4 +3,31 @@ target_sources(
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
)

MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
  mlx PUBLIC
  $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
  $<INSTALL_INTERFACE:include/json>
)

MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
  GIT_REPOSITORY https://github.com/antirez/gguf-tools/
  GIT_TAG af7d88d808a7608a33723fba067036202910acb3
)
FetchContent_MakeAvailable(gguflib)
target_include_directories(
  mlx PUBLIC
  $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
  $<INSTALL_INTERFACE:include/gguflib>
)
add_library(
  gguflib SHARED
  ${gguflib_SOURCE_DIR}/fp16.c
  ${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
mlx/io/gguf.cpp (new file, 163 lines)
@@ -0,0 +1,163 @@
// Copyright © 2023 Apple Inc.

#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

// gguflib is a plain-C library, so give its declarations C linkage.
extern "C" {
#include <gguflib.h>
}

namespace mlx::core {

std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
  switch (dtype) {
    case float32:
      return GGUF_TYPE_F32;
    case float16:
      return GGUF_TYPE_F16;
    case int8:
      return GGUF_TYPE_I8;
    case int16:
      return GGUF_TYPE_I16;
    case int32:
      return GGUF_TYPE_I32;
    default:
      return {};
  }
}

std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
  switch (gguf_type) {
    case GGUF_TYPE_F32:
      return float32;
    case GGUF_TYPE_F16:
      return float16;
    case GGUF_TYPE_I8:
      return int8;
    case GGUF_TYPE_I16:
      return int16;
    case GGUF_TYPE_I32:
      return int32;
    default:
      return {};
  }
}

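The two mappings above are inverses on the five directly supported dtypes. A tiny sanity check, illustrative only (it assumes gguf.cpp's surrounding context plus <cassert>, and is not part of the commit):

// Illustrative check that the dtype <-> GGUF type mappings invert
// each other on the directly supported dtypes.
static void check_dtype_round_trip() {
  for (const Dtype& dt : {float32, float16, int8, int16, int32}) {
    std::optional<uint32_t> gt = dtype_to_gguf_tensor_type(dt);
    assert(gt.has_value());
    assert(gguf_type_to_dtype(gt.value()).value() == dt);
  }
}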
std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
  std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
  // If there's an equivalent type, we can simply copy.
  if (equivalent_dtype.has_value()) {
    allocator::Buffer buffer = allocator::malloc(tensor->bsize);
    memcpy(
        buffer.raw_ptr(),
        tensor->weights_data,
        tensor->num_weights * equivalent_dtype.value().size);
    return {buffer, equivalent_dtype.value()};
  }
  // Otherwise, we convert to float16.
  // TODO: Add other dequantization options.
  // Note: gguflib returns the dequantized fp16 weights as raw 16-bit words
  // (int16_t*); the bits are reinterpreted as float16 when the array is built.
  int16_t* data = gguf_tensor_to_f16(tensor);
  if (data == NULL) {
    throw std::runtime_error("[load_gguf] gguf_tensor_to_f16 failed");
  }
  const size_t new_size = tensor->num_weights * sizeof(int16_t);
  allocator::Buffer buffer = allocator::malloc(new_size);
  memcpy(buffer.raw_ptr(), data, new_size);
  free(data);
  return {buffer, float16};
}

std::unordered_map<std::string, array> load_gguf(
    const std::string& file,
    StreamOrDevice s) {
  std::unordered_map<std::string, array> result;
  gguf_ctx* ctx = gguf_open(file.c_str());
  if (!ctx) {
    throw std::runtime_error("[load_gguf] gguf_open failed");
  }
  gguf_skip_key_values_section(ctx);
  gguf_tensor tensor;
  while (gguf_get_tensor(ctx, &tensor)) {
    std::vector<int> shape;
    // The dimension order in GGML is the reverse of the order used in MLX.
    for (int i = tensor.ndim - 1; i >= 0; i--) {
      shape.push_back(tensor.dim[i]);
    }
    const auto& [data, dtype] = extract_tensor_data(&tensor);
    array loaded_array = array(data, shape, dtype);
    std::string name = std::string(tensor.name, tensor.namelen);
    result.insert({name, loaded_array});
  }
  gguf_close(ctx);
  return result;
}

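A usage sketch for load_gguf (not from the commit): load a checkpoint and print what came back. The path is a placeholder; operator<< for Dtype is the one mlx/utils.h provides, as used by save_gguf below, and the snippet assumes <iostream>:

// Hypothetical caller of load_gguf; "model.gguf" is a placeholder path.
auto weights = load_gguf("model.gguf");
for (const auto& [name, arr] : weights) {
  std::cout << name << ": dtype " << arr.dtype() << ", shape [";
  for (int d : arr.shape()) {
    std::cout << " " << d;
  }
  std::cout << " ]\n";
}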
void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
  // Add .gguf to the file name if it is not there.
  if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
    file += ".gguf";
  }
  gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
  if (!ctx) {
    throw std::runtime_error("[save_gguf] gguf_create failed");
  }

  // Tensor offsets are relative to the data section, so we start at offset 0.
  uint64_t tensor_offset = 0;

  // First, append the tensor info. No keys are added or removed between this
  // pass and the data pass below, so both iterate the map in the same order
  // and the recorded offsets line up with the appended data.
  for (auto& [key, arr] : a) {
    arr.eval();

    // Try to make it row contiguous.
    if (!arr.flags().row_contiguous) {
      arr = reshape(flatten(arr), arr.shape());
      arr.eval();
    }

    // It has to be row-major now, but check one more time in case
    // any of the above changes in the future.
    if (!arr.flags().row_contiguous) {
      throw std::invalid_argument(
          "[save_gguf] can only serialize row-major arrays");
    }

    tensor_offset += gguf_get_alignment_padding(ctx->alignment, tensor_offset);
    const std::optional<uint32_t> gguf_type =
        dtype_to_gguf_tensor_type(arr.dtype());
    if (!gguf_type.has_value()) {
      std::ostringstream msg;
      msg << "[save_gguf] dtype " << arr.dtype() << " is not supported";
      throw std::runtime_error(msg.str());
    }
    const char* tensorname = key.c_str();
    const uint64_t namelen = key.length();
    const uint32_t num_dim = arr.ndim();
    uint64_t dim[num_dim];
    // GGML stores dimensions in the reverse of MLX's order.
    for (int i = 0; i < num_dim; i++) {
      dim[i] = arr.shape()[num_dim - 1 - i];
    }
    if (!gguf_append_tensor_info(
            ctx,
            tensorname,
            namelen,
            num_dim,
            dim,
            gguf_type.value(),
            tensor_offset)) {
      throw std::runtime_error("[save_gguf] gguf_append_tensor_info failed");
    }
    tensor_offset += arr.nbytes();
  }

  // Then, append the tensor weights.
  for (const auto& [key, arr] : a) {
    if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
      throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
    }
  }
  gguf_close(ctx);
}

} // namespace mlx::core
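One more illustrative caller (not from the commit): save_gguf also accepts arrays that are not row-contiguous, since the flatten/reshape pass above copies them into row-major order before the tensor info is written. transpose and ones are existing mlx ops; the names and file path are placeholders:

// Hypothetical caller of save_gguf (not part of this diff).
auto t = transpose(ones({2, 3}, float32)); // non-row-contiguous view, shape [3, 2]
std::unordered_map<std::string, array> params = {{"t", t}};
save_gguf("params", params); // written to params.gguf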
mlx/io/safetensor.cpp
@@ -1,7 +1,32 @@
#include "mlx/io/safetensor.h"

// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <stack>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"

namespace mlx::core {

std::string dtype_to_safetensor_str(Dtype t) {
mlx/io/safetensor.h (deleted, 32 lines)
@@ -1,32 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <json.hpp>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

namespace mlx::core {

#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core
mlx/ops.h
@@ -1104,4 +1104,12 @@ void save_safetensors(
void save_safetensors(
    const std::string& file,
    std::unordered_map<std::string, array>);

/** Load array map from .gguf file format */
std::unordered_map<std::string, array> load_gguf(
    const std::string& file,
    StreamOrDevice s = {});

void save_gguf(std::string file, std::unordered_map<std::string, array> a);

} // namespace mlx::core