	GGUF support (#350)
* Initial GGUF support for tensor fields.

Author: Juarez Bochi
Co-authored-by: Awni Hannun <awni@apple.com>
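For orientation, here is a minimal, hypothetical sketch of how the two functions this commit adds (load_gguf and save_gguf, defined in mlx/io/gguf.cpp below) might be called from C++. The umbrella header mlx/mlx.h, the file name llama.gguf, and the empty StreamOrDevice argument are illustrative assumptions, not part of the commit:

// Sketch only: exercises the load_gguf / save_gguf signatures from this commit.
#include <iostream>
#include <string>
#include <unordered_map>

#include "mlx/mlx.h"  // assumed umbrella header exposing mlx::core::load_gguf / save_gguf

int main() {
  using namespace mlx::core;

  // Load every tensor field from a GGUF file into named MLX arrays.
  // Types without a direct MLX equivalent are dequantized to float16.
  std::unordered_map<std::string, array> weights =
      load_gguf("llama.gguf", /* StreamOrDevice */ {});  // file name is a placeholder

  for (const auto& [name, arr] : weights) {
    std::cout << name << ": " << arr.ndim() << " dims, " << arr.nbytes()
              << " bytes\n";
  }

  // Write the tensors back out; ".gguf" is appended if the name lacks it.
  save_gguf("roundtrip", weights);
  return 0;
}

Passing an empty StreamOrDevice simply falls back to the default stream; any quantized GGUF tensors come back as float16 per extract_tensor_data in the diff below.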
mlx/io/CMakeLists.txt

@@ -3,4 +3,31 @@ target_sources(
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
)

MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
    $<INSTALL_INTERFACE:include/json>
)

MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
    GIT_REPOSITORY     https://github.com/antirez/gguf-tools/
    GIT_TAG            af7d88d808a7608a33723fba067036202910acb3
)
FetchContent_MakeAvailable(gguflib)
target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
    $<INSTALL_INTERFACE:include/gguflib>
)
add_library(
  gguflib SHARED
  ${gguflib_SOURCE_DIR}/fp16.c
  ${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
							
								
								
									
mlx/io/gguf.cpp (new file, 163 lines)
@@ -0,0 +1,163 @@
// Copyright © 2023 Apple Inc.

#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

extern "C" {
#include <gguflib.h>
}

namespace mlx::core {

std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
  switch (dtype) {
    case float32:
      return GGUF_TYPE_F32;
    case float16:
      return GGUF_TYPE_F16;
    case int8:
      return GGUF_TYPE_I8;
    case int16:
      return GGUF_TYPE_I16;
    case int32:
      return GGUF_TYPE_I32;
    default:
      return {};
  }
}

std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
  switch (gguf_type) {
    case GGUF_TYPE_F32:
      return float32;
    case GGUF_TYPE_F16:
      return float16;
    case GGUF_TYPE_I8:
      return int8;
    case GGUF_TYPE_I16:
      return int16;
    case GGUF_TYPE_I32:
      return int32;
    default:
      return {};
  }
}

std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
  std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
  // If there's an equivalent type, we can simply copy.
  if (equivalent_dtype.has_value()) {
    allocator::Buffer buffer = allocator::malloc(tensor->bsize);
    memcpy(
        buffer.raw_ptr(),
        tensor->weights_data,
        tensor->num_weights * equivalent_dtype.value().size);
    return {buffer, equivalent_dtype.value()};
  }
  // Otherwise, we convert to float16.
  // TODO: Add other dequantization options.
  int16_t* data = gguf_tensor_to_f16(tensor);
  if (data == NULL) {
    throw std::runtime_error("[load_gguf] gguf_tensor_to_f16 failed");
  }
  const size_t new_size = tensor->num_weights * sizeof(int16_t);
  allocator::Buffer buffer = allocator::malloc(new_size);
  memcpy(buffer.raw_ptr(), data, new_size);
  free(data);
  return {buffer, float16};
}

std::unordered_map<std::string, array> load_gguf(
    const std::string& file,
    StreamOrDevice s) {
  std::unordered_map<std::string, array> result;
  gguf_ctx* ctx = gguf_open(file.c_str());
  if (!ctx) {
    throw std::runtime_error("[load_gguf] gguf_init failed");
  }
  gguf_skip_key_values_section(ctx);
  gguf_tensor tensor;
  while (gguf_get_tensor(ctx, &tensor)) {
    std::vector<int> shape;
    // The dimension order in GGML is the reverse of the order used in MLX.
    for (int i = tensor.ndim - 1; i >= 0; i--) {
      shape.push_back(tensor.dim[i]);
    }
    const auto& [data, dtype] = extract_tensor_data(&tensor);
    array loaded_array = array(data, shape, dtype);
    std::string name = std::string(tensor.name, tensor.namelen);
    result.insert({name, loaded_array});
  }
  gguf_close(ctx);
  return result;
}

void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
  // Add .gguf to file name if it is not there
  if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
    file += ".gguf";
  }
  gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
  if (!ctx) {
    throw std::runtime_error("[save_gguf] gguf_create failed");
  }

  // Tensor offsets are relative to data section, so we start at offset 0.
  uint64_t tensor_offset = 0;

  // First, append the tensor info
  for (auto& [key, arr] : a) {
    arr.eval();

    // Try to make it row contiguous
    if (!arr.flags().row_contiguous) {
      arr = reshape(flatten(arr), arr.shape());
      arr.eval();
    }

    // Has to be row-major now but, check one more time in case
    // any of the above change in the future
    if (!arr.flags().row_contiguous) {
      throw std::invalid_argument(
          "[save_gguf] can only serialize row-major arrays");
    }

    tensor_offset += gguf_get_alignment_padding(ctx->alignment, tensor_offset);
    const std::optional<uint32_t> gguf_type =
        dtype_to_gguf_tensor_type(arr.dtype());
    if (!gguf_type.has_value()) {
      std::ostringstream msg;
      msg << "[save_gguf] dtype " << arr.dtype() << " is not supported";
      throw std::runtime_error(msg.str());
    }
    const char* tensorname = key.c_str();
    const uint64_t namelen = key.length();
    const uint32_t num_dim = arr.ndim();
    uint64_t dim[num_dim];
    for (int i = 0; i < num_dim; i++) {
      dim[i] = arr.shape()[num_dim - 1 - i];
    }
    if (!gguf_append_tensor_info(
            ctx,
            tensorname,
            namelen,
            num_dim,
            dim,
            gguf_type.value(),
            tensor_offset)) {
      throw std::runtime_error("[save_gguf] gguf_append_tensor_info failed");
    }
    tensor_offset += arr.nbytes();
  }

  // Then, append the tensor weights
  for (const auto& [key, arr] : a) {
    if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
      throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
    }
  }
  gguf_close(ctx);
}

} // namespace mlx::core
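A side note on the dimension-order comment in load_gguf above: GGUF/GGML lists dimensions with the fastest-varying axis first, while MLX shapes are written outermost-first, so loading reverses the order and saving reverses it back. A small, self-contained illustration (the shape values 4096 and 32 are made up for the example):

// Illustration of the dim-order reversal done by load_gguf and save_gguf.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Suppose a GGUF tensor records dim = {4096, 32} (GGML order).
  const uint32_t num_dim = 2;
  uint64_t dim[num_dim] = {4096, 32};

  // load_gguf builds the MLX shape by walking dims in reverse: {32, 4096}.
  std::vector<int> shape;
  for (int i = num_dim - 1; i >= 0; i--) {
    shape.push_back(dim[i]);
  }
  assert(shape[0] == 32 && shape[1] == 4096);

  // save_gguf performs the inverse mapping when writing the tensor info.
  uint64_t out_dim[num_dim];
  for (uint32_t i = 0; i < num_dim; i++) {
    out_dim[i] = shape[num_dim - 1 - i];
  }
  assert(out_dim[0] == 4096 && out_dim[1] == 32);
  return 0;
}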
mlx/io/safetensor.cpp

@@ -1,7 +1,32 @@
#include "mlx/io/safetensor.h"

// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <stack>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"

namespace mlx::core {

std::string dtype_to_safetensor_str(Dtype t) {
mlx/io/safetensor.h (deleted)

@@ -1,32 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <json.hpp>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

namespace mlx::core {

#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core