diff --git a/CMakeLists.txt b/CMakeLists.txt index 151017b9a..373ae1cbe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,8 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) +option(MLX_BUILD_GGUF "Include support for GGUF format" ON) +option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) diff --git a/docs/src/install.rst b/docs/src/install.rst index f34db7270..252b234e6 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -157,7 +157,10 @@ should point to the path to the built metal library. - OFF * - MLX_METAL_DEBUG - OFF - + * - MLX_BUILD_SAFETENSORS + - ON + * - MLX_BUILD_GGUF + - ON .. note:: diff --git a/mlx/io.h b/mlx/io.h index e30c0de34..4805d1c87 100644 --- a/mlx/io.h +++ b/mlx/io.h @@ -6,8 +6,8 @@ #include "mlx/array.h" #include "mlx/io/load.h" -#include "mlx/ops.h" #include "mlx/stream.h" +#include "mlx/utils.h" namespace mlx::core { using GGUFMetaData = diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt index 193488ad9..38caeff9a 100644 --- a/mlx/io/CMakeLists.txt +++ b/mlx/io/CMakeLists.txt @@ -1,33 +1,58 @@ + target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp ) -MESSAGE(STATUS "Downloading json") -FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) -FetchContent_MakeAvailable(json) -target_include_directories( - mlx PRIVATE - $ -) +if (MLX_BUILD_SAFETENSORS) + MESSAGE(STATUS "Downloading json") + FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) + FetchContent_MakeAvailable(json) + target_include_directories( + mlx PRIVATE + $ + ) + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp + ) +else() + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp + ) +endif() -MESSAGE(STATUS "Downloading gguflib") -FetchContent_Declare(gguflib - GIT_REPOSITORY https://github.com/antirez/gguf-tools/ - GIT_TAG af7d88d808a7608a33723fba067036202910acb3 -) -FetchContent_MakeAvailable(gguflib) -target_include_directories( - mlx PRIVATE - $ -) +if (MLX_BUILD_GGUF) + MESSAGE(STATUS "Downloading gguflib") + FetchContent_Declare(gguflib + GIT_REPOSITORY https://github.com/antirez/gguf-tools/ + GIT_TAG af7d88d808a7608a33723fba067036202910acb3 + ) + FetchContent_MakeAvailable(gguflib) + target_include_directories( + mlx PRIVATE + $ + ) + add_library( + gguflib STATIC + ${gguflib_SOURCE_DIR}/fp16.c + ${gguflib_SOURCE_DIR}/gguflib.c) + target_link_libraries(mlx $) + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp + ) +else() + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp + ) +endif() -add_library( - gguflib STATIC - ${gguflib_SOURCE_DIR}/fp16.c - ${gguflib_SOURCE_DIR}/gguflib.c) -target_link_libraries(mlx $) diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 0193d2d09..c452886fd 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -4,7 +4,8 @@ #include #include -#include +#include "mlx/io/gguf.h" +#include "mlx/ops.h" namespace mlx::core { diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 7c4e87253..a06ccfe65 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include "mlx/io/gguf.h" namespace mlx::core { diff --git a/mlx/io/no_gguf.cpp b/mlx/io/no_gguf.cpp new file mode 100644 index 000000000..822d8fcd6 --- /dev/null +++ b/mlx/io/no_gguf.cpp @@ -0,0 +1,20 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/io.h" + +namespace mlx::core { + +GGUFLoad load_gguf(const std::string&, StreamOrDevice s) { + throw std::runtime_error( + "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); +} + +void save_gguf( + std::string, + std::unordered_map, + std::unordered_map) { + throw std::runtime_error( + "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); +} + +} // namespace mlx::core diff --git a/mlx/io/no_safetensors.cpp b/mlx/io/no_safetensors.cpp new file mode 100644 index 000000000..949c5d2e9 --- /dev/null +++ b/mlx/io/no_safetensors.cpp @@ -0,0 +1,37 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/io.h" + +namespace mlx::core { + +SafetensorsLoad load_safetensors(std::shared_ptr, StreamOrDevice) { + throw std::runtime_error( + "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " + "to enable safetensors support."); +} + +SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) { + throw std::runtime_error( + "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " + "to enable safetensors support."); +} + +void save_safetensors( + std::shared_ptr, + std::unordered_map, + std::unordered_map) { + throw std::runtime_error( + "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " + "to enable safetensors support."); +} + +void save_safetensors( + std::string file, + std::unordered_map, + std::unordered_map) { + throw std::runtime_error( + "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " + "to enable safetensors support."); +} + +} // namespace mlx::core diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensors.cpp similarity index 99% rename from mlx/io/safetensor.cpp rename to mlx/io/safetensors.cpp index 69ebd46c8..76d8151b9 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensors.cpp @@ -5,6 +5,7 @@ #include "mlx/io.h" #include "mlx/io/load.h" +#include "mlx/ops.h" #include "mlx/primitives.h" using json = nlohmann::json; @@ -149,7 +150,6 @@ SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { return load_safetensors(std::make_shared(file), s); } -/** Save array to out stream in .npy format */ void save_safetensors( std::shared_ptr out_stream, std::unordered_map a,