Add compiler flags to disable safetensors and gguf (#1098)

* with docs

* nit
This commit is contained in:
Awni Hannun 2024-05-09 17:39:44 -07:00 committed by GitHub
parent 06375e6605
commit 8b1906abd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 118 additions and 30 deletions

View File

@ -17,6 +17,8 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION) if(NOT MLX_VERSION)

View File

@ -157,7 +157,10 @@ should point to the path to the built metal library.
- OFF - OFF
* - MLX_METAL_DEBUG * - MLX_METAL_DEBUG
- OFF - OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
.. note:: .. note::

View File

@ -6,8 +6,8 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/io/load.h" #include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
using GGUFMetaData = using GGUFMetaData =

View File

@ -1,33 +1,58 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
) )
MESSAGE(STATUS "Downloading json") if (MLX_BUILD_SAFETENSORS)
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) MESSAGE(STATUS "Downloading json")
FetchContent_MakeAvailable(json) FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
target_include_directories( FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE mlx PRIVATE
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann> $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
) )
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp
)
endif()
MESSAGE(STATUS "Downloading gguflib") if (MLX_BUILD_GGUF)
FetchContent_Declare(gguflib MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
GIT_REPOSITORY https://github.com/antirez/gguf-tools/ GIT_REPOSITORY https://github.com/antirez/gguf-tools/
GIT_TAG af7d88d808a7608a33723fba067036202910acb3 GIT_TAG af7d88d808a7608a33723fba067036202910acb3
) )
FetchContent_MakeAvailable(gguflib) FetchContent_MakeAvailable(gguflib)
target_include_directories( target_include_directories(
mlx PRIVATE mlx PRIVATE
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}> $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
) )
add_library(
add_library(
gguflib STATIC gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c ${gguflib_SOURCE_DIR}/fp16.c
${gguflib_SOURCE_DIR}/gguflib.c) ${gguflib_SOURCE_DIR}/gguflib.c)
target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>) target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp
)
endif()

View File

@ -4,7 +4,8 @@
#include <cstring> #include <cstring>
#include <numeric> #include <numeric>
#include <mlx/io/gguf.h> #include "mlx/io/gguf.h"
#include "mlx/ops.h"
namespace mlx::core { namespace mlx::core {

View File

@ -4,7 +4,7 @@
#include <cstring> #include <cstring>
#include <numeric> #include <numeric>
#include <mlx/io/gguf.h> #include "mlx/io/gguf.h"
namespace mlx::core { namespace mlx::core {

20
mlx/io/no_gguf.cpp Normal file
View File

@ -0,0 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/io.h"
namespace mlx::core {
/**
 * Stand-in for load_gguf, compiled when the build is configured with
 * MLX_BUILD_GGUF=OFF (the io CMake file then adds no_gguf.cpp in place of
 * gguf.cpp and gguf_quants.cpp).
 *
 * Both arguments are ignored, so they are left unnamed; this matches the
 * other stubs and avoids an unused-parameter warning.
 *
 * @throws std::runtime_error always, telling the caller how to enable GGUF.
 */
GGUFLoad load_gguf(const std::string&, StreamOrDevice) {
  throw std::runtime_error(
      "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
}
/**
 * Stand-in for save_gguf, compiled when the build is configured with
 * MLX_BUILD_GGUF=OFF. All arguments are ignored.
 *
 * @throws std::runtime_error always, telling the caller how to enable GGUF.
 */
void save_gguf(
    std::string,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, GGUFMetaData>) {
  const char* const message =
      "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.";
  throw std::runtime_error(message);
}
} // namespace mlx::core

37
mlx/io/no_safetensors.cpp Normal file
View File

@ -0,0 +1,37 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/io.h"
namespace mlx::core {
/**
 * Stand-in for the reader-based load_safetensors overload, compiled when the
 * build is configured with MLX_BUILD_SAFETENSORS=OFF. Both arguments are
 * ignored.
 *
 * @throws std::runtime_error always, telling the caller how to enable
 *         safetensors support.
 */
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) {
  const char* const message =
      "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.";
  throw std::runtime_error(message);
}
/**
 * Stand-in for the path-based load_safetensors overload, compiled when the
 * build is configured with MLX_BUILD_SAFETENSORS=OFF. Both arguments are
 * ignored.
 *
 * @throws std::runtime_error always, telling the caller how to enable
 *         safetensors support.
 */
SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {
  const char* const message =
      "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.";
  throw std::runtime_error(message);
}
/**
 * Stand-in for the writer-based save_safetensors overload, compiled when the
 * build is configured with MLX_BUILD_SAFETENSORS=OFF. All arguments are
 * ignored.
 *
 * @throws std::runtime_error always, telling the caller how to enable
 *         safetensors support.
 */
void save_safetensors(
    std::shared_ptr<io::Writer>,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, std::string>) {
  const char* const message =
      "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.";
  throw std::runtime_error(message);
}
/**
 * Stand-in for the path-based save_safetensors overload, compiled when the
 * build is configured with MLX_BUILD_SAFETENSORS=OFF.
 *
 * All arguments are ignored, so they are left unnamed; this matches the
 * other stubs and avoids an unused-parameter warning.
 *
 * @throws std::runtime_error always, telling the caller how to enable
 *         safetensors support.
 */
void save_safetensors(
    std::string,
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, std::string>) {
  throw std::runtime_error(
      "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
      "to enable safetensors support.");
}
} // namespace mlx::core

View File

@ -5,6 +5,7 @@
#include "mlx/io.h" #include "mlx/io.h"
#include "mlx/io/load.h" #include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -149,7 +150,6 @@ SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s); return load_safetensors(std::make_shared<io::FileReader>(file), s);
} }
/** Save array to out stream in .npy format */
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> out_stream, std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a, std::unordered_map<std::string, array> a,