From 020f048cd0e7ff83c055281c59cdca405eaf04dd Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Oct 2024 12:45:49 -0700 Subject: [PATCH] A few updates for CPU (#1482) * some updates * format * fix * nit --- mlx/backend/common/compiled_cpu.cpp | 32 +++++++++++--------- mlx/backend/common/make_compiled_preamble.sh | 6 ++-- mlx/fast.cpp | 10 +++--- mlx/io/safetensors.cpp | 14 ++++++++- python/src/convert.cpp | 9 ++++-- python/src/ops.cpp | 4 +-- 6 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 13f2233ad..2c0df6073 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" @@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) { // Return a pointer to a compiled function void* compile( const std::string& kernel_name, - const std::string& source_code = "") { + const std::function& source_builder) { struct DLib { DLib(const std::string& libname) { lib = dlopen(libname.c_str(), RTLD_NOW); @@ -46,13 +48,20 @@ void* compile( // Statics to cache compiled libraries and functions static std::list libs; static std::unordered_map kernels; + static std::shared_mutex compile_mtx; + + { + std::shared_lock lock(compile_mtx); + if (auto it = kernels.find(kernel_name); it != kernels.end()) { + return it->second; + } + } + + std::unique_lock lock(compile_mtx); if (auto it = kernels.find(kernel_name); it != kernels.end()) { return it->second; } - if (source_code.empty()) { - return nullptr; - } - + std::string source_code = source_builder(); std::string kernel_file_name; // Deal with long kernel names. Maximum length for files on macOS is 255 @@ -90,7 +99,7 @@ void* compile( source_file.close(); std::ostringstream build_command; - build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " + build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared " << source_file_path << " -o " << shared_lib_path; std::string build_command_str = build_command.str(); auto return_code = system(build_command_str.c_str()); @@ -316,10 +325,7 @@ void Compiled::eval_cpu( } // Get the function - auto fn_ptr = compile(kernel_name); - - // If it doesn't exist, compile it - if (fn_ptr == nullptr) { + auto fn_ptr = compile(kernel_name, [&]() { std::ostringstream kernel; kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; @@ -334,10 +340,8 @@ void Compiled::eval_cpu( ndim); // Close extern "C" kernel << "}" << std::endl; - - // Compile and get function pointer - fn_ptr = compile(kernel_name, kernel.str()); - } + return kernel.str(); + }); compiled_allocate_outputs( inputs, outputs, inputs_, constant_ids_, contiguous, false); diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh index bbb1187d1..149dc2886 100644 --- a/mlx/backend/common/make_compiled_preamble.sh +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then #include #include EOM - +CC_FLAGS="" +else +CC_FLAGS="-std=c++17" fi -CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null) +CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null) cat << EOF > "$OUTPUT_FILE" const char* get_kernel_preamble() { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a3fe1ea1e..6fd4862d4 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -682,8 +682,10 @@ array pack_and_quantize( clip( round(divide(subtract(packed_w, biases, s), scales, s), s), zero, - n_bins), - uint32); + n_bins, + s), + uint32, + s); packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); packed_w = sum( multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); @@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { array mask = greater(abs(w_min, s), abs(w_max, s), s); array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales), s); + scales = where(mask, scales, negative(scales, s), s); array edge = where(mask, w_min, w_max, s); array q0 = round(divide(edge, scales, s), s); scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge); + array biases = where(equal(q0, zero, s), zero, edge, s); packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); return { diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index 0a41f0826..f022fb25f 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -7,6 +7,7 @@ #include "mlx/io/load.h" #include "mlx/ops.h" #include "mlx/primitives.h" +#include "mlx/transforms.h" using json = nlohmann::json; @@ -58,6 +59,8 @@ std::string dtype_to_safetensor_str(Dtype t) { return ST_BOOL; case complex64: return ST_C64; + default: + throw std::runtime_error("[save_safetensors] received invalid dtype."); } } @@ -169,9 +172,18 @@ void save_safetensors( _metadata[key] = value; } parent["__metadata__"] = _metadata; + + { + std::vector to_eval; + to_eval.reserve(a.size()); + for (auto& [_, arr] : a) { + to_eval.push_back(arr); + } + eval(std::move(to_eval)); + } + size_t offset = 0; for (auto& [key, arr] : a) { - arr.eval(); if (arr.nbytes() == 0) { throw std::invalid_argument( "[save_safetensors] cannot serialize an empty array key: " + key); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index b3554eef8..6cd874b65 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -142,12 +142,13 @@ nb::ndarray mlx_to_nd_array(const array& a) { case float16: return mlx_to_nd_array_impl(a); case bfloat16: - throw nb::type_error( - "bfloat16 arrays cannot be converted directly to NumPy."); + throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); case float32: return mlx_to_nd_array_impl(a); case complex64: return mlx_to_nd_array_impl, NDParams...>(a); + default: + throw nb::type_error("type cannot be converted to NumPy."); } } @@ -195,6 +196,8 @@ nb::object to_scalar(array& a) { return nb::cast(static_cast(a.item())); case complex64: return nb::cast(a.item>()); + default: + throw nb::type_error("type cannot be converted to Python scalar."); } } @@ -248,6 +251,8 @@ nb::object tolist(array& a) { return to_list(a, 0, 0); case complex64: return to_list>(a, 0, 0); + default: + throw nb::type_error("data type cannot be converted to Python list."); } } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 20278b40d..b0b308a91 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1308,8 +1308,8 @@ void init_ops(nb::module_& m) { "start"_a, "stop"_a, "step"_a = nb::none(), - nb::kw_only(), "dtype"_a = nb::none(), + nb::kw_only(), "stream"_a = nb::none(), nb::sig( "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), @@ -1356,8 +1356,8 @@ void init_ops(nb::module_& m) { }, "stop"_a, "step"_a = nb::none(), - nb::kw_only(), "dtype"_a = nb::none(), + nb::kw_only(), "stream"_a = nb::none(), nb::sig( "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));