mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	GGUF support (#350)
* Initial GGUF support for tensor fields. --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -1,6 +1,6 @@
 | 
				
			|||||||
cmake_minimum_required(VERSION 3.24)
 | 
					cmake_minimum_required(VERSION 3.24)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
project(mlx LANGUAGES CXX)
 | 
					project(mlx LANGUAGES C CXX)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# ----------------------------- Setup -----------------------------
 | 
					# ----------------------------- Setup -----------------------------
 | 
				
			||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 | 
					set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 | 
				
			||||||
@@ -98,15 +98,6 @@ elseif (MLX_BUILD_METAL)
 | 
				
			|||||||
    ${QUARTZ_LIB})
 | 
					    ${QUARTZ_LIB})
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MESSAGE(STATUS "Downloading json")
 | 
					 | 
				
			||||||
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
 | 
					 | 
				
			||||||
FetchContent_MakeAvailable(json)
 | 
					 | 
				
			||||||
target_include_directories(
 | 
					 | 
				
			||||||
    mlx PUBLIC
 | 
					 | 
				
			||||||
    $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
 | 
					 | 
				
			||||||
    $<INSTALL_INTERFACE:include/json>
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
find_library(ACCELERATE_LIBRARY Accelerate)
 | 
					find_library(ACCELERATE_LIBRARY Accelerate)
 | 
				
			||||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
 | 
					if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
 | 
				
			||||||
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
 | 
					  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -89,6 +89,7 @@ Operations
 | 
				
			|||||||
   save
 | 
					   save
 | 
				
			||||||
   savez
 | 
					   savez
 | 
				
			||||||
   savez_compressed
 | 
					   savez_compressed
 | 
				
			||||||
 | 
					   save_gguf
 | 
				
			||||||
   save_safetensors
 | 
					   save_safetensors
 | 
				
			||||||
   sigmoid
 | 
					   sigmoid
 | 
				
			||||||
   sign
 | 
					   sign
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,4 +3,31 @@ target_sources(
 | 
				
			|||||||
  PRIVATE
 | 
					  PRIVATE
 | 
				
			||||||
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
 | 
				
			||||||
  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
 | 
				
			||||||
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MESSAGE(STATUS "Downloading json")
 | 
				
			||||||
 | 
					FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
 | 
				
			||||||
 | 
					FetchContent_MakeAvailable(json)
 | 
				
			||||||
 | 
					target_include_directories(
 | 
				
			||||||
 | 
					    mlx PUBLIC
 | 
				
			||||||
 | 
					    $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
 | 
				
			||||||
 | 
					    $<INSTALL_INTERFACE:include/json>
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MESSAGE(STATUS "Downloading gguflib")
 | 
				
			||||||
 | 
					FetchContent_Declare(gguflib
 | 
				
			||||||
 | 
					    GIT_REPOSITORY     https://github.com/antirez/gguf-tools/
 | 
				
			||||||
 | 
					    GIT_TAG            af7d88d808a7608a33723fba067036202910acb3 
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					FetchContent_MakeAvailable(gguflib)
 | 
				
			||||||
 | 
					target_include_directories(
 | 
				
			||||||
 | 
					    mlx PUBLIC
 | 
				
			||||||
 | 
					    $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
 | 
				
			||||||
 | 
					    $<INSTALL_INTERFACE:include/gguflib>
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					add_library(
 | 
				
			||||||
 | 
					  gguflib SHARED
 | 
				
			||||||
 | 
					  ${gguflib_SOURCE_DIR}/fp16.c
 | 
				
			||||||
 | 
					  ${gguflib_SOURCE_DIR}/gguflib.c)
 | 
				
			||||||
 | 
					target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										163
									
								
								mlx/io/gguf.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								mlx/io/gguf.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,163 @@
 | 
				
			|||||||
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlx/ops.h"
 | 
				
			||||||
 | 
					#include "mlx/primitives.h"
 | 
				
			||||||
 | 
					#include "mlx/utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					extern "C" {
 | 
				
			||||||
 | 
					#include <gguflib.h>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mlx::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
 | 
				
			||||||
 | 
					  switch (dtype) {
 | 
				
			||||||
 | 
					    case float32:
 | 
				
			||||||
 | 
					      return GGUF_TYPE_F32;
 | 
				
			||||||
 | 
					    case float16:
 | 
				
			||||||
 | 
					      return GGUF_TYPE_F16;
 | 
				
			||||||
 | 
					    case int8:
 | 
				
			||||||
 | 
					      return GGUF_TYPE_I8;
 | 
				
			||||||
 | 
					    case int16:
 | 
				
			||||||
 | 
					      return GGUF_TYPE_I16;
 | 
				
			||||||
 | 
					    case int32:
 | 
				
			||||||
 | 
					      return GGUF_TYPE_I32;
 | 
				
			||||||
 | 
					    default:
 | 
				
			||||||
 | 
					      return {};
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
 | 
				
			||||||
 | 
					  switch (gguf_type) {
 | 
				
			||||||
 | 
					    case GGUF_TYPE_F32:
 | 
				
			||||||
 | 
					      return float32;
 | 
				
			||||||
 | 
					    case GGUF_TYPE_F16:
 | 
				
			||||||
 | 
					      return float16;
 | 
				
			||||||
 | 
					    case GGUF_TYPE_I8:
 | 
				
			||||||
 | 
					      return int8;
 | 
				
			||||||
 | 
					    case GGUF_TYPE_I16:
 | 
				
			||||||
 | 
					      return int16;
 | 
				
			||||||
 | 
					    case GGUF_TYPE_I32:
 | 
				
			||||||
 | 
					      return int32;
 | 
				
			||||||
 | 
					    default:
 | 
				
			||||||
 | 
					      return {};
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
 | 
				
			||||||
 | 
					  std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
 | 
				
			||||||
 | 
					  // If there's an equivalent type, we can simply copy.
 | 
				
			||||||
 | 
					  if (equivalent_dtype.has_value()) {
 | 
				
			||||||
 | 
					    allocator::Buffer buffer = allocator::malloc(tensor->bsize);
 | 
				
			||||||
 | 
					    memcpy(
 | 
				
			||||||
 | 
					        buffer.raw_ptr(),
 | 
				
			||||||
 | 
					        tensor->weights_data,
 | 
				
			||||||
 | 
					        tensor->num_weights * equivalent_dtype.value().size);
 | 
				
			||||||
 | 
					    return {buffer, equivalent_dtype.value()};
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Otherwise, we convert to float16.
 | 
				
			||||||
 | 
					  // TODO: Add other dequantization options.
 | 
				
			||||||
 | 
					  int16_t* data = gguf_tensor_to_f16(tensor);
 | 
				
			||||||
 | 
					  if (data == NULL) {
 | 
				
			||||||
 | 
					    throw std::runtime_error("[load_gguf] gguf_tensor_to_f16 failed");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const size_t new_size = tensor->num_weights * sizeof(int16_t);
 | 
				
			||||||
 | 
					  allocator::Buffer buffer = allocator::malloc(new_size);
 | 
				
			||||||
 | 
					  memcpy(buffer.raw_ptr(), data, new_size);
 | 
				
			||||||
 | 
					  free(data);
 | 
				
			||||||
 | 
					  return {buffer, float16};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unordered_map<std::string, array> load_gguf(
 | 
				
			||||||
 | 
					    const std::string& file,
 | 
				
			||||||
 | 
					    StreamOrDevice s) {
 | 
				
			||||||
 | 
					  std::unordered_map<std::string, array> result;
 | 
				
			||||||
 | 
					  gguf_ctx* ctx = gguf_open(file.c_str());
 | 
				
			||||||
 | 
					  if (!ctx) {
 | 
				
			||||||
 | 
					    throw std::runtime_error("[load_gguf] gguf_init failed");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  gguf_skip_key_values_section(ctx);
 | 
				
			||||||
 | 
					  gguf_tensor tensor;
 | 
				
			||||||
 | 
					  while (gguf_get_tensor(ctx, &tensor)) {
 | 
				
			||||||
 | 
					    std::vector<int> shape;
 | 
				
			||||||
 | 
					    // The dimension order in GGML is the reverse of the order used in MLX.
 | 
				
			||||||
 | 
					    for (int i = tensor.ndim - 1; i >= 0; i--) {
 | 
				
			||||||
 | 
					      shape.push_back(tensor.dim[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    const auto& [data, dtype] = extract_tensor_data(&tensor);
 | 
				
			||||||
 | 
					    array loaded_array = array(data, shape, dtype);
 | 
				
			||||||
 | 
					    std::string name = std::string(tensor.name, tensor.namelen);
 | 
				
			||||||
 | 
					    result.insert({name, loaded_array});
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  gguf_close(ctx);
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
 | 
				
			||||||
 | 
					  // Add .gguf to file name if it is not there
 | 
				
			||||||
 | 
					  if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
 | 
				
			||||||
 | 
					    file += ".gguf";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
 | 
				
			||||||
 | 
					  if (!ctx) {
 | 
				
			||||||
 | 
					    throw std::runtime_error("[save_gguf] gguf_create failed");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Tensor offsets are relative to data section, so we start at offset 0.
 | 
				
			||||||
 | 
					  uint64_t tensor_offset = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // First, append the tensor info
 | 
				
			||||||
 | 
					  for (auto& [key, arr] : a) {
 | 
				
			||||||
 | 
					    arr.eval();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Try to make it row contiguous
 | 
				
			||||||
 | 
					    if (!arr.flags().row_contiguous) {
 | 
				
			||||||
 | 
					      arr = reshape(flatten(arr), arr.shape());
 | 
				
			||||||
 | 
					      arr.eval();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Has to be row-major now but, check one more time in case
 | 
				
			||||||
 | 
					    // any of the above change in the future
 | 
				
			||||||
 | 
					    if (!arr.flags().row_contiguous) {
 | 
				
			||||||
 | 
					      throw std::invalid_argument(
 | 
				
			||||||
 | 
					          "[save_gguf] can only serialize row-major arrays");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tensor_offset += gguf_get_alignment_padding(ctx->alignment, tensor_offset);
 | 
				
			||||||
 | 
					    const std::optional<uint32_t> gguf_type =
 | 
				
			||||||
 | 
					        dtype_to_gguf_tensor_type(arr.dtype());
 | 
				
			||||||
 | 
					    if (!gguf_type.has_value()) {
 | 
				
			||||||
 | 
					      std::ostringstream msg;
 | 
				
			||||||
 | 
					      msg << "[save_gguf] dtype " << arr.dtype() << " is not supported";
 | 
				
			||||||
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    const char* tensorname = key.c_str();
 | 
				
			||||||
 | 
					    const uint64_t namelen = key.length();
 | 
				
			||||||
 | 
					    const uint32_t num_dim = arr.ndim();
 | 
				
			||||||
 | 
					    uint64_t dim[num_dim];
 | 
				
			||||||
 | 
					    for (int i = 0; i < num_dim; i++) {
 | 
				
			||||||
 | 
					      dim[i] = arr.shape()[num_dim - 1 - i];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (!gguf_append_tensor_info(
 | 
				
			||||||
 | 
					            ctx,
 | 
				
			||||||
 | 
					            tensorname,
 | 
				
			||||||
 | 
					            namelen,
 | 
				
			||||||
 | 
					            num_dim,
 | 
				
			||||||
 | 
					            dim,
 | 
				
			||||||
 | 
					            gguf_type.value(),
 | 
				
			||||||
 | 
					            tensor_offset)) {
 | 
				
			||||||
 | 
					      throw std::runtime_error("[save_gguf] gguf_append_tensor_info failed");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    tensor_offset += arr.nbytes();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Then, append the tensor weights
 | 
				
			||||||
 | 
					  for (const auto& [key, arr] : a) {
 | 
				
			||||||
 | 
					    if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
 | 
				
			||||||
 | 
					      throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  gguf_close(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					} // namespace mlx::core
 | 
				
			||||||
@@ -1,7 +1,32 @@
 | 
				
			|||||||
#include "mlx/io/safetensor.h"
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					#include <json.hpp>
 | 
				
			||||||
#include <stack>
 | 
					#include <stack>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlx/io/load.h"
 | 
				
			||||||
 | 
					#include "mlx/ops.h"
 | 
				
			||||||
 | 
					#include "mlx/primitives.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using json = nlohmann::json;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define ST_F16 "F16"
 | 
				
			||||||
 | 
					#define ST_BF16 "BF16"
 | 
				
			||||||
 | 
					#define ST_F32 "F32"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define ST_BOOL "BOOL"
 | 
				
			||||||
 | 
					#define ST_I8 "I8"
 | 
				
			||||||
 | 
					#define ST_I16 "I16"
 | 
				
			||||||
 | 
					#define ST_I32 "I32"
 | 
				
			||||||
 | 
					#define ST_I64 "I64"
 | 
				
			||||||
 | 
					#define ST_U8 "U8"
 | 
				
			||||||
 | 
					#define ST_U16 "U16"
 | 
				
			||||||
 | 
					#define ST_U32 "U32"
 | 
				
			||||||
 | 
					#define ST_U64 "U64"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Note: Complex numbers aren't in the spec yet so this could change -
 | 
				
			||||||
 | 
					// https://github.com/huggingface/safetensors/issues/389
 | 
				
			||||||
 | 
					#define ST_C64 "C64"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mlx::core {
 | 
					namespace mlx::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::string dtype_to_safetensor_str(Dtype t) {
 | 
					std::string dtype_to_safetensor_str(Dtype t) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,32 +0,0 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#pragma once
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include <json.hpp>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "mlx/io/load.h"
 | 
					 | 
				
			||||||
#include "mlx/ops.h"
 | 
					 | 
				
			||||||
#include "mlx/primitives.h"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
using json = nlohmann::json;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace mlx::core {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#define ST_F16 "F16"
 | 
					 | 
				
			||||||
#define ST_BF16 "BF16"
 | 
					 | 
				
			||||||
#define ST_F32 "F32"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#define ST_BOOL "BOOL"
 | 
					 | 
				
			||||||
#define ST_I8 "I8"
 | 
					 | 
				
			||||||
#define ST_I16 "I16"
 | 
					 | 
				
			||||||
#define ST_I32 "I32"
 | 
					 | 
				
			||||||
#define ST_I64 "I64"
 | 
					 | 
				
			||||||
#define ST_U8 "U8"
 | 
					 | 
				
			||||||
#define ST_U16 "U16"
 | 
					 | 
				
			||||||
#define ST_U32 "U32"
 | 
					 | 
				
			||||||
#define ST_U64 "U64"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Note: Complex numbers aren't in the spec yet so this could change -
 | 
					 | 
				
			||||||
// https://github.com/huggingface/safetensors/issues/389
 | 
					 | 
				
			||||||
#define ST_C64 "C64"
 | 
					 | 
				
			||||||
} // namespace mlx::core
 | 
					 | 
				
			||||||
@@ -1104,4 +1104,12 @@ void save_safetensors(
 | 
				
			|||||||
void save_safetensors(
 | 
					void save_safetensors(
 | 
				
			||||||
    const std::string& file,
 | 
					    const std::string& file,
 | 
				
			||||||
    std::unordered_map<std::string, array>);
 | 
					    std::unordered_map<std::string, array>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Load array map from .gguf file format */
 | 
				
			||||||
 | 
					std::unordered_map<std::string, array> load_gguf(
 | 
				
			||||||
 | 
					    const std::string& file,
 | 
				
			||||||
 | 
					    StreamOrDevice s = {});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void save_gguf(std::string file, std::unordered_map<std::string, array> a);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace mlx::core
 | 
					} // namespace mlx::core
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -181,6 +181,16 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
				
			|||||||
      "[load_safetensors] Input must be a file-like object, or string");
 | 
					      "[load_safetensors] Input must be a file-like object, or string");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unordered_map<std::string, array> mlx_load_gguf_helper(
 | 
				
			||||||
 | 
					    py::object file,
 | 
				
			||||||
 | 
					    StreamOrDevice s) {
 | 
				
			||||||
 | 
					  if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
 | 
				
			||||||
 | 
					    return load_gguf(py::cast<std::string>(file), s);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  throw std::invalid_argument("[load_gguf] Input must be a string");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
 | 
					std::unordered_map<std::string, array> mlx_load_npz_helper(
 | 
				
			||||||
    py::object file,
 | 
					    py::object file,
 | 
				
			||||||
    StreamOrDevice s) {
 | 
					    StreamOrDevice s) {
 | 
				
			||||||
@@ -264,6 +274,8 @@ DictOrArray mlx_load_helper(
 | 
				
			|||||||
    return mlx_load_npz_helper(file, s);
 | 
					    return mlx_load_npz_helper(file, s);
 | 
				
			||||||
  } else if (format.value() == "npy") {
 | 
					  } else if (format.value() == "npy") {
 | 
				
			||||||
    return mlx_load_npy_helper(file, s);
 | 
					    return mlx_load_npy_helper(file, s);
 | 
				
			||||||
 | 
					  } else if (format.value() == "gguf") {
 | 
				
			||||||
 | 
					    return mlx_load_gguf_helper(file, s);
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    throw std::invalid_argument("[load] Unknown file format " + format.value());
 | 
					    throw std::invalid_argument("[load] Unknown file format " + format.value());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -435,3 +447,13 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
 | 
				
			|||||||
  throw std::invalid_argument(
 | 
					  throw std::invalid_argument(
 | 
				
			||||||
      "[save_safetensors] Input must be a file-like object, or string");
 | 
					      "[save_safetensors] Input must be a file-like object, or string");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void mlx_save_gguf_helper(py::object file, py::dict d) {
 | 
				
			||||||
 | 
					  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
 | 
				
			||||||
 | 
					  if (py::isinstance<py::str>(file)) {
 | 
				
			||||||
 | 
					    save_gguf(py::cast<std::string>(file), arrays_map);
 | 
				
			||||||
 | 
					    return;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  throw std::invalid_argument("[save_safetensors] Input must be a string");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,6 +19,11 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
				
			|||||||
    StreamOrDevice s);
 | 
					    StreamOrDevice s);
 | 
				
			||||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
 | 
					void mlx_save_safetensor_helper(py::object file, py::dict d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unordered_map<std::string, array> mlx_load_gguf_helper(
 | 
				
			||||||
 | 
					    py::object file,
 | 
				
			||||||
 | 
					    StreamOrDevice s);
 | 
				
			||||||
 | 
					void mlx_save_gguf_helper(py::object file, py::dict d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DictOrArray mlx_load_helper(
 | 
					DictOrArray mlx_load_helper(
 | 
				
			||||||
    py::object file,
 | 
					    py::object file,
 | 
				
			||||||
    std::optional<std::string> format,
 | 
					    std::optional<std::string> format,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3048,7 +3048,9 @@ void init_ops(py::module_& m) {
 | 
				
			|||||||
      R"pbdoc(
 | 
					      R"pbdoc(
 | 
				
			||||||
        load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 | 
					        load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
 | 
					        Load array(s) from a binary file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and ``.gguf``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            file (file, str): File in which the array is saved.
 | 
					            file (file, str): File in which the array is saved.
 | 
				
			||||||
@@ -3059,6 +3061,12 @@ void init_ops(py::module_& m) {
 | 
				
			|||||||
            result (array, dict):
 | 
					            result (array, dict):
 | 
				
			||||||
                A single array if loading from a ``.npy`` file or a dict mapping
 | 
					                A single array if loading from a ``.npy`` file or a dict mapping
 | 
				
			||||||
                names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
 | 
					                names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Warning:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          When loading unsupported quantization formats from GGUF, tensors will
 | 
				
			||||||
 | 
					          automatically cast to ``mx.float16``
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
      "save_safetensors",
 | 
					      "save_safetensors",
 | 
				
			||||||
@@ -3070,10 +3078,28 @@ void init_ops(py::module_& m) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        Save array(s) to a binary file in ``.safetensors`` format.
 | 
					        Save array(s) to a binary file in ``.safetensors`` format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        For more information on the format see https://huggingface.co/docs/safetensors/index.
 | 
					        See the `Safetensors documentation <https://huggingface.co/docs/safetensors/index>`_
 | 
				
			||||||
 | 
					        for more information on the format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            file (file, str): File in which the array is saved>
 | 
					            file (file, str): File in which the array is saved.
 | 
				
			||||||
 | 
					            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
 | 
				
			||||||
 | 
					      )pbdoc");
 | 
				
			||||||
 | 
					  m.def(
 | 
				
			||||||
 | 
					      "save_gguf",
 | 
				
			||||||
 | 
					      &mlx_save_gguf_helper,
 | 
				
			||||||
 | 
					      "file"_a,
 | 
				
			||||||
 | 
					      "arrays"_a,
 | 
				
			||||||
 | 
					      R"pbdoc(
 | 
				
			||||||
 | 
					        save_gguf(file: str, arrays: Dict[str, array])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Save array(s) to a binary file in ``.gguf`` format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        See the `GGUF documentation <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for
 | 
				
			||||||
 | 
					        more information on the format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            file (file, str): File in which the array is saved.
 | 
				
			||||||
            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
 | 
					            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
@@ -3306,7 +3332,7 @@ void init_ops(py::module_& m) {
 | 
				
			|||||||
            ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
 | 
					            ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
 | 
				
			||||||
            ``b``. If a list of lists is provided, then sum over the
 | 
					            ``b``. If a list of lists is provided, then sum over the
 | 
				
			||||||
            corresponding dimensions of ``a`` and ``b``. (default: 2)
 | 
					            corresponding dimensions of ``a`` and ``b``. (default: 2)
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
          result (array): The tensor dot product.
 | 
					          result (array): The tensor dot product.
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -90,6 +90,33 @@ class TestLoad(mlx_tests.MLXTestCase):
 | 
				
			|||||||
                            mx.array_equal(load_dict["test"], save_dict["test"])
 | 
					                            mx.array_equal(load_dict["test"], save_dict["test"])
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_save_and_load_gguf(self):
 | 
				
			||||||
 | 
					        if not os.path.isdir(self.test_dir):
 | 
				
			||||||
 | 
					            os.mkdir(self.test_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
 | 
				
			||||||
 | 
					        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
 | 
				
			||||||
 | 
					        for dt in supported_dtypes:
 | 
				
			||||||
 | 
					            with self.subTest(dtype=dt):
 | 
				
			||||||
 | 
					                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
 | 
				
			||||||
 | 
					                    with self.subTest(shape=shape):
 | 
				
			||||||
 | 
					                        save_file_mlx = os.path.join(
 | 
				
			||||||
 | 
					                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                        save_dict = {
 | 
				
			||||||
 | 
					                            "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
 | 
				
			||||||
 | 
					                            if dt in ["float32", "float16", "bfloat16"]
 | 
				
			||||||
 | 
					                            else mx.ones(shape, dtype=getattr(mx, dt))
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        mx.save_gguf(save_file_mlx, save_dict)
 | 
				
			||||||
 | 
					                        load_dict = mx.load(save_file_mlx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        self.assertTrue("test" in load_dict)
 | 
				
			||||||
 | 
					                        self.assertTrue(
 | 
				
			||||||
 | 
					                            mx.array_equal(load_dict["test"], save_dict["test"])
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_save_and_load_fs(self):
 | 
					    def test_save_and_load_fs(self):
 | 
				
			||||||
        if not os.path.isdir(self.test_dir):
 | 
					        if not os.path.isdir(self.test_dir):
 | 
				
			||||||
            os.mkdir(self.test_dir)
 | 
					            os.mkdir(self.test_dir)
 | 
				
			||||||
@@ -194,13 +221,24 @@ class TestLoad(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        aload = mx.load(save_file)["a"]
 | 
					        aload = mx.load(save_file)["a"]
 | 
				
			||||||
        self.assertTrue(mx.array_equal(a, aload))
 | 
					        self.assertTrue(mx.array_equal(a, aload))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # safetensors only works with row contiguous
 | 
					        save_file = os.path.join(self.test_dir, "a.gguf")
 | 
				
			||||||
 | 
					        mx.save_gguf(save_file, {"a": a})
 | 
				
			||||||
 | 
					        aload = mx.load(save_file)["a"]
 | 
				
			||||||
 | 
					        self.assertTrue(mx.array_equal(a, aload))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # safetensors and gguf only work with row contiguous
 | 
				
			||||||
        # make sure col contiguous is handled properly
 | 
					        # make sure col contiguous is handled properly
 | 
				
			||||||
 | 
					        save_file = os.path.join(self.test_dir, "a.safetensors")
 | 
				
			||||||
        a = mx.arange(4).reshape(2, 2).T
 | 
					        a = mx.arange(4).reshape(2, 2).T
 | 
				
			||||||
        mx.save_safetensors(save_file, {"a": a})
 | 
					        mx.save_safetensors(save_file, {"a": a})
 | 
				
			||||||
        aload = mx.load(save_file)["a"]
 | 
					        aload = mx.load(save_file)["a"]
 | 
				
			||||||
        self.assertTrue(mx.array_equal(a, aload))
 | 
					        self.assertTrue(mx.array_equal(a, aload))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        save_file = os.path.join(self.test_dir, "a.gguf")
 | 
				
			||||||
 | 
					        mx.save_gguf(save_file, {"a": a})
 | 
				
			||||||
 | 
					        aload = mx.load(save_file)["a"]
 | 
				
			||||||
 | 
					        self.assertTrue(mx.array_equal(a, aload))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    unittest.main()
 | 
					    unittest.main()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,20 +20,53 @@ TEST_CASE("test save_safetensors") {
 | 
				
			|||||||
  map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
 | 
					  map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
 | 
				
			||||||
  map.insert({"test2", ones({2, 2})});
 | 
					  map.insert({"test2", ones({2, 2})});
 | 
				
			||||||
  save_safetensors(file_path, map);
 | 
					  save_safetensors(file_path, map);
 | 
				
			||||||
  auto safeDict = load_safetensors(file_path);
 | 
					  auto dict = load_safetensors(file_path);
 | 
				
			||||||
  CHECK_EQ(safeDict.size(), 2);
 | 
					  CHECK_EQ(dict.size(), 2);
 | 
				
			||||||
  CHECK_EQ(safeDict.count("test"), 1);
 | 
					  CHECK_EQ(dict.count("test"), 1);
 | 
				
			||||||
  CHECK_EQ(safeDict.count("test2"), 1);
 | 
					  CHECK_EQ(dict.count("test2"), 1);
 | 
				
			||||||
  array test = safeDict.at("test");
 | 
					  array test = dict.at("test");
 | 
				
			||||||
  CHECK_EQ(test.dtype(), float32);
 | 
					  CHECK_EQ(test.dtype(), float32);
 | 
				
			||||||
  CHECK_EQ(test.shape(), std::vector<int>({4}));
 | 
					  CHECK_EQ(test.shape(), std::vector<int>({4}));
 | 
				
			||||||
  CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
 | 
					  CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
 | 
				
			||||||
  array test2 = safeDict.at("test2");
 | 
					  array test2 = dict.at("test2");
 | 
				
			||||||
  CHECK_EQ(test2.dtype(), float32);
 | 
					  CHECK_EQ(test2.dtype(), float32);
 | 
				
			||||||
  CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
 | 
					  CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
 | 
				
			||||||
  CHECK(array_equal(test2, ones({2, 2})).item<bool>());
 | 
					  CHECK(array_equal(test2, ones({2, 2})).item<bool>());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test gguf") {
 | 
				
			||||||
 | 
					  std::string file_path = get_temp_file("test_arr.gguf");
 | 
				
			||||||
 | 
					  using dict = std::unordered_map<std::string, array>;
 | 
				
			||||||
 | 
					  dict map = {
 | 
				
			||||||
 | 
					      {"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
 | 
				
			||||||
 | 
					      {"test2", reshape(arange(6), {3, 2})}};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  save_gguf(file_path, map);
 | 
				
			||||||
 | 
					  auto loaded = load_gguf(file_path);
 | 
				
			||||||
 | 
					  CHECK_EQ(loaded.size(), 2);
 | 
				
			||||||
 | 
					  CHECK_EQ(loaded.count("test"), 1);
 | 
				
			||||||
 | 
					  CHECK_EQ(loaded.count("test2"), 1);
 | 
				
			||||||
 | 
					  for (auto [k, v] : loaded) {
 | 
				
			||||||
 | 
					    CHECK(array_equal(v, map.at(k)).item<bool>());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<Dtype> unsupported_types = {
 | 
				
			||||||
 | 
					      bool_, uint8, uint32, uint64, int64, bfloat16, complex64};
 | 
				
			||||||
 | 
					  for (auto t : unsupported_types) {
 | 
				
			||||||
 | 
					    dict to_save = {{"test", astype(arange(5), t)}};
 | 
				
			||||||
 | 
					    CHECK_THROWS(save_gguf(file_path, to_save));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<Dtype> supported_types = {int8, int32, float16};
 | 
				
			||||||
 | 
					  for (auto t : supported_types) {
 | 
				
			||||||
 | 
					    auto arr = astype(arange(5), t);
 | 
				
			||||||
 | 
					    dict to_save = {{"test", arr}};
 | 
				
			||||||
 | 
					    save_gguf(file_path, to_save);
 | 
				
			||||||
 | 
					    auto loaded = load_gguf(file_path);
 | 
				
			||||||
 | 
					    CHECK(array_equal(loaded.at("test"), arr).item<bool>());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test single array serialization") {
 | 
					TEST_CASE("test single array serialization") {
 | 
				
			||||||
  // Basic test
 | 
					  // Basic test
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user