From b7f905787e7217015c7449b00bcaba171c0f1686 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Wed, 10 Jan 2024 16:22:48 -0500
Subject: [PATCH] GGUF support (#350)

* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun
---
 CMakeLists.txt            |  11 +--
 docs/src/python/ops.rst   |   1 +
 mlx/io/CMakeLists.txt     |  27 +++++++
 mlx/io/gguf.cpp           | 163 ++++++++++++++++++++++++++++++++++++++
 mlx/io/safetensor.cpp     |  29 ++++++-
 mlx/io/safetensor.h       |  32 --------
 mlx/ops.h                 |   8 ++
 python/src/load.cpp       |  22 +++++
 python/src/load.h         |   5 ++
 python/src/ops.cpp        |  34 +++++++-
 python/tests/test_load.py |  40 +++++++++-
 tests/load_tests.cpp      |  45 +++++++++--
 12 files changed, 362 insertions(+), 55 deletions(-)
 create mode 100644 mlx/io/gguf.cpp
 delete mode 100644 mlx/io/safetensor.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5e8afa3c5..3732ca58a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.24)
 
-project(mlx LANGUAGES CXX)
+project(mlx LANGUAGES C CXX)
 
 # ----------------------------- Setup -----------------------------
 set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -98,15 +98,6 @@ elseif (MLX_BUILD_METAL)
     ${QUARTZ_LIB})
 endif()
 
-MESSAGE(STATUS "Downloading json")
-FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
-FetchContent_MakeAvailable(json)
-target_include_directories(
-  mlx PUBLIC
-  $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
-  $<INSTALL_INTERFACE:include/json>
-)
-
 find_library(ACCELERATE_LIBRARY Accelerate)
 if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
   message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 952f12c1e..3dcd3660d 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -89,6 +89,7 @@ Operations
    save
    savez
    savez_compressed
+   save_gguf
    save_safetensors
    sigmoid
    sign
diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt
index f3e27b96a..b3c25b61c 100644
--- a/mlx/io/CMakeLists.txt
+++ b/mlx/io/CMakeLists.txt
@@ -3,4 +3,31 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
 )
+
+MESSAGE(STATUS "Downloading json")
+FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PUBLIC
+  $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
+  $<INSTALL_INTERFACE:include/json>
+)
+
+MESSAGE(STATUS "Downloading gguflib")
+FetchContent_Declare(gguflib
+  GIT_REPOSITORY https://github.com/antirez/gguf-tools/
+  GIT_TAG af7d88d808a7608a33723fba067036202910acb3
+)
+FetchContent_MakeAvailable(gguflib)
+target_include_directories(
+  mlx PUBLIC
+  $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
+  $<INSTALL_INTERFACE:include/gguflib>
+)
+add_library(
+  gguflib SHARED
+  ${gguflib_SOURCE_DIR}/fp16.c
+  ${gguflib_SOURCE_DIR}/gguflib.c)
+target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
new file mode 100644
index 000000000..7de0ad611
--- /dev/null
+++ b/mlx/io/gguf.cpp
@@ -0,0 +1,163 @@
+// Copyright © 2023 Apple Inc.
+
+#include "mlx/ops.h"
+#include "mlx/primitives.h"
+#include "mlx/utils.h"
+
+extern "C" {
+#include <gguflib.h>
+}
+
+namespace mlx::core {
+
+std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
+  switch (dtype) {
+    case float32:
+      return GGUF_TYPE_F32;
+    case float16:
+      return GGUF_TYPE_F16;
+    case int8:
+      return GGUF_TYPE_I8;
+    case int16:
+      return GGUF_TYPE_I16;
+    case int32:
+      return GGUF_TYPE_I32;
+    default:
+      return {};
+  }
+}
+
+std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
+  switch (gguf_type) {
+    case GGUF_TYPE_F32:
+      return float32;
+    case GGUF_TYPE_F16:
+      return float16;
+    case GGUF_TYPE_I8:
+      return int8;
+    case GGUF_TYPE_I16:
+      return int16;
+    case GGUF_TYPE_I32:
+      return int32;
+    default:
+      return {};
+  }
+}
+
+std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
+  std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
+  // If there's an equivalent type, we can simply copy.
+  if (equivalent_dtype.has_value()) {
+    allocator::Buffer buffer = allocator::malloc(tensor->bsize);
+    memcpy(
+        buffer.raw_ptr(),
+        tensor->weights_data,
+        tensor->num_weights * equivalent_dtype.value().size);
+    return {buffer, equivalent_dtype.value()};
+  }
+  // Otherwise, we convert to float16.
+  // TODO: Add other dequantization options.
+  int16_t* data = gguf_tensor_to_f16(tensor);
+  if (data == NULL) {
+    throw std::runtime_error("[load_gguf] gguf_tensor_to_f16 failed");
+  }
+  const size_t new_size = tensor->num_weights * sizeof(int16_t);
+  allocator::Buffer buffer = allocator::malloc(new_size);
+  memcpy(buffer.raw_ptr(), data, new_size);
+  free(data);
+  return {buffer, float16};
+}
+
+std::unordered_map<std::string, array> load_gguf(
+    const std::string& file,
+    StreamOrDevice s) {
+  std::unordered_map<std::string, array> result;
+  gguf_ctx* ctx = gguf_open(file.c_str());
+  if (!ctx) {
+    throw std::runtime_error("[load_gguf] gguf_open failed");
+  }
+  gguf_skip_key_values_section(ctx);
+  gguf_tensor tensor;
+  while (gguf_get_tensor(ctx, &tensor)) {
+    std::vector<int> shape;
+    // The dimension order in GGML is the reverse of the order used in MLX.
+    for (int i = tensor.ndim - 1; i >= 0; i--) {
+      shape.push_back(tensor.dim[i]);
+    }
+    const auto& [data, dtype] = extract_tensor_data(&tensor);
+    array loaded_array = array(data, shape, dtype);
+    std::string name = std::string(tensor.name, tensor.namelen);
+    result.insert({name, loaded_array});
+  }
+  gguf_close(ctx);
+  return result;
+}
+
+void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
+  // Add .gguf to the file name if it is not there
+  if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
+    file += ".gguf";
+  }
+  gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
+  if (!ctx) {
+    throw std::runtime_error("[save_gguf] gguf_create failed");
+  }
+
+  // Tensor offsets are relative to the data section, so we start at offset 0.
+  uint64_t tensor_offset = 0;
+
+  // First, append the tensor info
+  for (auto& [key, arr] : a) {
+    arr.eval();
+
+    // Try to make it row contiguous
+    if (!arr.flags().row_contiguous) {
+      arr = reshape(flatten(arr), arr.shape());
+      arr.eval();
+    }
+
+    // Has to be row-major now, but check one more time in case
+    // any of the above changes in the future
+    if (!arr.flags().row_contiguous) {
+      throw std::invalid_argument(
+          "[save_gguf] can only serialize row-major arrays");
+    }
+
+    tensor_offset += gguf_get_alignment_padding(ctx->alignment, tensor_offset);
+    const std::optional<uint32_t> gguf_type =
+        dtype_to_gguf_tensor_type(arr.dtype());
+    if (!gguf_type.has_value()) {
+      std::ostringstream msg;
+      msg << "[save_gguf] dtype " << arr.dtype() << " is not supported";
+      throw std::runtime_error(msg.str());
+    }
+    const char* tensorname = key.c_str();
+    const uint64_t namelen = key.length();
+    const uint32_t num_dim = arr.ndim();
+    uint64_t dim[num_dim];
+    for (int i = 0; i < num_dim; i++) {
+      dim[i] = arr.shape()[num_dim - 1 - i];
+    }
+    if (!gguf_append_tensor_info(
+            ctx,
+            tensorname,
+            namelen,
+            num_dim,
+            dim,
+            gguf_type.value(),
+            tensor_offset)) {
+      throw std::runtime_error("[save_gguf] gguf_append_tensor_info failed");
+    }
+    tensor_offset += arr.nbytes();
+  }
+
+  // Then, append the tensor weights
+  for (const auto& [key, arr] : a) {
+    if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
+      throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
+    }
+  }
+  gguf_close(ctx);
+}
+
+} // namespace mlx::core
diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp
index 1ca79441d..406169312 100644
--- a/mlx/io/safetensor.cpp
+++ b/mlx/io/safetensor.cpp
@@ -1,7 +1,32 @@
-#include "mlx/io/safetensor.h"
-
+// Copyright © 2023 Apple Inc.
+//
+#include <json.hpp>
 #include <stack>
 
+#include "mlx/io/load.h"
+#include "mlx/ops.h"
+#include "mlx/primitives.h"
+
+using json = nlohmann::json;
+
+#define ST_F16 "F16"
+#define ST_BF16 "BF16"
+#define ST_F32 "F32"
+
+#define ST_BOOL "BOOL"
+#define ST_I8 "I8"
+#define ST_I16 "I16"
+#define ST_I32 "I32"
+#define ST_I64 "I64"
+#define ST_U8 "U8"
+#define ST_U16 "U16"
+#define ST_U32 "U32"
+#define ST_U64 "U64"
+
+// Note: Complex numbers aren't in the spec yet so this could change -
+// https://github.com/huggingface/safetensors/issues/389
+#define ST_C64 "C64"
+
 namespace mlx::core {
 
 std::string dtype_to_safetensor_str(Dtype t) {
diff --git a/mlx/io/safetensor.h b/mlx/io/safetensor.h
deleted file mode 100644
index 104a226ce..000000000
--- a/mlx/io/safetensor.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#pragma once
-
-#include <json.hpp>
-
-#include "mlx/io/load.h"
-#include "mlx/ops.h"
-#include "mlx/primitives.h"
-
-using json = nlohmann::json;
-
-namespace mlx::core {
-
-#define ST_F16 "F16"
-#define ST_BF16 "BF16"
-#define ST_F32 "F32"
-
-#define ST_BOOL "BOOL"
-#define ST_I8 "I8"
-#define ST_I16 "I16"
-#define ST_I32 "I32"
-#define ST_I64 "I64"
-#define ST_U8 "U8"
-#define ST_U16 "U16"
-#define ST_U32 "U32"
-#define ST_U64 "U64"
-
-// Note: Complex numbers aren't in the spec yet so this could change -
-// https://github.com/huggingface/safetensors/issues/389
-#define ST_C64 "C64"
-} // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index d02d39717..72534bf44 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1104,4 +1104,12 @@ void save_safetensors(
 void save_safetensors(
     const std::string& file,
     std::unordered_map<std::string, array>);
+
+/** Load array map from .gguf file format */
+std::unordered_map<std::string, array> load_gguf(
+    const std::string& file,
+    StreamOrDevice s = {});
+
+void save_gguf(std::string file, std::unordered_map<std::string, array> a);
+
 } // namespace mlx::core
diff --git a/python/src/load.cpp b/python/src/load.cpp
index fcc1cc722..03b108d8e 100644
--- a/python/src/load.cpp
+++ b/python/src/load.cpp
@@ -181,6 +181,16 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
       "[load_safetensors] Input must be a file-like object, or string");
 }
 
+std::unordered_map<std::string, array> mlx_load_gguf_helper(
+    py::object file,
+    StreamOrDevice s) {
+  if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
+    return load_gguf(py::cast<std::string>(file), s);
+  }
+
+  throw std::invalid_argument("[load_gguf] Input must be a string");
+}
+
 std::unordered_map<std::string, array> mlx_load_npz_helper(
     py::object file,
     StreamOrDevice s) {
@@ -264,6 +274,8 @@ DictOrArray mlx_load_helper(
     return mlx_load_npz_helper(file, s);
   } else if (format.value() == "npy") {
     return mlx_load_npy_helper(file, s);
+  } else if (format.value() == "gguf") {
+    return mlx_load_gguf_helper(file, s);
   } else {
     throw std::invalid_argument("[load] Unknown file format " + format.value());
   }
@@ -435,3 +447,13 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
   throw std::invalid_argument(
       "[save_safetensors] Input must be a file-like object, or string");
 }
+
+void mlx_save_gguf_helper(py::object file, py::dict d) {
+  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
+  if (py::isinstance<py::str>(file)) {
+    save_gguf(py::cast<std::string>(file), arrays_map);
+    return;
+  }
+
+  throw std::invalid_argument("[save_gguf] Input must be a string");
+}
diff --git a/python/src/load.h b/python/src/load.h
index d1d8bd59c..19feb5f5a 100644
--- a/python/src/load.h
+++ b/python/src/load.h
@@ -19,6 +19,11 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
     StreamOrDevice s);
 void mlx_save_safetensor_helper(py::object file, py::dict d);
 
+std::unordered_map<std::string, array> mlx_load_gguf_helper(
+    py::object file,
+    StreamOrDevice s);
+void mlx_save_gguf_helper(py::object file, py::dict d);
+
 DictOrArray mlx_load_helper(
     py::object file,
     std::optional<std::string> format,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 90be116c4..2d60db6aa 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3048,7 +3048,9 @@ void init_ops(py::module_& m) {
       R"pbdoc(
         load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 
-        Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
+        Load array(s) from a binary file.
+
+        The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and ``.gguf``.
 
         Args:
             file (file, str): File in which the array is saved.
@@ -3059,6 +3061,12 @@ void init_ops(py::module_& m) {
             result (array, dict): A single array if loading from a ``.npy``
                file or a dict mapping names to arrays if loading from a
                ``.npz`` or ``.safetensors`` file.
+
+        Warning:
+
+          When loading unsupported quantization formats from GGUF, tensors
+          will automatically be cast to ``mx.float16``.
+
       )pbdoc");
   m.def(
       "save_safetensors",
@@ -3070,10 +3078,28 @@ void init_ops(py::module_& m) {
 
         Save array(s) to a binary file in ``.safetensors`` format.
 
-        For more information on the format see https://huggingface.co/docs/safetensors/index.
+        See the `Safetensors documentation <https://huggingface.co/docs/safetensors/index>`_
+        for more information on the format.
 
         Args:
-            file (file, str): File in which the array is saved>
+            file (file, str): File in which the array is saved.
+            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
+      )pbdoc");
+  m.def(
+      "save_gguf",
+      &mlx_save_gguf_helper,
+      "file"_a,
+      "arrays"_a,
+      R"pbdoc(
+        save_gguf(file: str, arrays: Dict[str, array])
+
+        Save array(s) to a binary file in ``.gguf`` format.
+
+        See the `GGUF documentation <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for
+        more information on the format.
+
+        Args:
+            file (file, str): File in which the array is saved.
             arrays (dict(str, array)): The dictionary of names to arrays to be saved.
       )pbdoc");
   m.def(
@@ -3306,7 +3332,7 @@ void init_ops(py::module_& m) {
            ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
            ``b``. If a list of lists is provided, then sum over the corresponding
            dimensions of ``a`` and ``b``. (default: 2)
-
+
         Returns:
             result (array): The tensor dot product.
       )pbdoc");
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 66cf4aa4e..3b7baba54 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -90,6 +90,33 @@ class TestLoad(mlx_tests.MLXTestCase):
                         mx.array_equal(load_dict["test"], save_dict["test"])
                     )
 
+    def test_save_and_load_gguf(self):
+        if not os.path.isdir(self.test_dir):
+            os.mkdir(self.test_dir)
+
+        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
+        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
+        for dt in supported_dtypes:
+            with self.subTest(dtype=dt):
+                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
+                    with self.subTest(shape=shape):
+                        save_file_mlx = os.path.join(
+                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
+                        )
+                        save_dict = {
+                            "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
+                            if dt in ["float32", "float16", "bfloat16"]
+                            else mx.ones(shape, dtype=getattr(mx, dt))
+                        }
+
+                        mx.save_gguf(save_file_mlx, save_dict)
+                        load_dict = mx.load(save_file_mlx)
+
+                        self.assertTrue("test" in load_dict)
+                        self.assertTrue(
+                            mx.array_equal(load_dict["test"], save_dict["test"])
+                        )
+
     def test_save_and_load_fs(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
@@ -194,13 +221,24 @@ class TestLoad(mlx_tests.MLXTestCase):
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
 
-        # safetensors only works with row contiguous
+        save_file = os.path.join(self.test_dir, "a.gguf")
+        mx.save_gguf(save_file, {"a": a})
+        aload = mx.load(save_file)["a"]
+        self.assertTrue(mx.array_equal(a, aload))
+
+        # safetensors and gguf only work with row contiguous
         # make sure col contiguous is handled properly
+        save_file = os.path.join(self.test_dir, "a.safetensors")
         a = mx.arange(4).reshape(2, 2).T
         mx.save_safetensors(save_file, {"a": a})
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
 
+        save_file = os.path.join(self.test_dir, "a.gguf")
+        mx.save_gguf(save_file, {"a": a})
+        aload = mx.load(save_file)["a"]
+        self.assertTrue(mx.array_equal(a, aload))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp
index edff1aff6..8b77a2eb3 100644
--- a/tests/load_tests.cpp
+++ b/tests/load_tests.cpp
@@ -20,20 +20,53 @@ TEST_CASE("test save_safetensors") {
   map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
   map.insert({"test2", ones({2, 2})});
   save_safetensors(file_path, map);
-  auto safeDict = load_safetensors(file_path);
-  CHECK_EQ(safeDict.size(), 2);
-  CHECK_EQ(safeDict.count("test"), 1);
-  CHECK_EQ(safeDict.count("test2"), 1);
-  array test = safeDict.at("test");
+  auto dict = load_safetensors(file_path);
+  CHECK_EQ(dict.size(), 2);
+  CHECK_EQ(dict.count("test"), 1);
+  CHECK_EQ(dict.count("test2"), 1);
+  array test = dict.at("test");
   CHECK_EQ(test.dtype(), float32);
   CHECK_EQ(test.shape(), std::vector<int>({4}));
   CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
-  array test2 = safeDict.at("test2");
+  array test2 = dict.at("test2");
   CHECK_EQ(test2.dtype(), float32);
   CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
   CHECK(array_equal(test2, ones({2, 2})).item<bool>());
 }
 
+TEST_CASE("test gguf") {
+  std::string file_path = get_temp_file("test_arr.gguf");
+  using dict = std::unordered_map<std::string, array>;
+  dict map = {
+      {"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
+      {"test2", reshape(arange(6), {3, 2})}};
+
+  save_gguf(file_path, map);
+  auto loaded = load_gguf(file_path);
+  CHECK_EQ(loaded.size(), 2);
+  CHECK_EQ(loaded.count("test"), 1);
+  CHECK_EQ(loaded.count("test2"), 1);
+  for (auto [k, v] : loaded) {
+    CHECK(array_equal(v, map.at(k)).item<bool>());
+  }
+
+  std::vector<Dtype> unsupported_types = {
+      bool_, uint8, uint32, uint64, int64, bfloat16, complex64};
+  for (auto t : unsupported_types) {
+    dict to_save = {{"test", astype(arange(5), t)}};
+    CHECK_THROWS(save_gguf(file_path, to_save));
+  }
+
+  std::vector<Dtype> supported_types = {int8, int32, float16};
+  for (auto t : supported_types) {
+    auto arr = astype(arange(5), t);
+    dict to_save = {{"test", arr}};
+    save_gguf(file_path, to_save);
+    auto loaded = load_gguf(file_path);
+    CHECK(array_equal(loaded.at("test"), arr).item<bool>());
+  }
+}
+
 TEST_CASE("test single array serialization") {
   // Basic test
   {
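
A minimal usage sketch of the Python API this patch adds. The function names
(mx.save_gguf, mx.load) and the dtype/casting behavior are taken from the diff
above; the file name and array contents are illustrative only:

    import mlx.core as mx

    # Save a dict mapping names to arrays. Dtypes with a GGUF tensor-type
    # equivalent (float32, float16, int8, int16, int32) round-trip exactly;
    # anything else (e.g. bfloat16) raises on save.
    arrays = {"weights": mx.random.normal(shape=(4, 4), dtype=mx.float16)}
    mx.save_gguf("model.gguf", arrays)

    # mx.load dispatches on the .gguf extension (or format="gguf") and
    # returns a dict of arrays. Quantized GGUF tensors without a native
    # MLX equivalent are dequantized to float16 on load.
    loaded = mx.load("model.gguf")
    assert mx.array_equal(loaded["weights"], arrays["weights"])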