mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
881615b072
commit
020f048cd0
@ -4,6 +4,8 @@
|
|||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <mutex>
|
||||||
|
#include <shared_mutex>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/common/compiled_preamble.h"
|
#include "mlx/backend/common/compiled_preamble.h"
|
||||||
@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) {
|
|||||||
// Return a pointer to a compiled function
|
// Return a pointer to a compiled function
|
||||||
void* compile(
|
void* compile(
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::string& source_code = "") {
|
const std::function<std::string(void)>& source_builder) {
|
||||||
struct DLib {
|
struct DLib {
|
||||||
DLib(const std::string& libname) {
|
DLib(const std::string& libname) {
|
||||||
lib = dlopen(libname.c_str(), RTLD_NOW);
|
lib = dlopen(libname.c_str(), RTLD_NOW);
|
||||||
@ -46,13 +48,20 @@ void* compile(
|
|||||||
// Statics to cache compiled libraries and functions
|
// Statics to cache compiled libraries and functions
|
||||||
static std::list<DLib> libs;
|
static std::list<DLib> libs;
|
||||||
static std::unordered_map<std::string, void*> kernels;
|
static std::unordered_map<std::string, void*> kernels;
|
||||||
|
static std::shared_mutex compile_mtx;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::shared_lock lock(compile_mtx);
|
||||||
|
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock lock(compile_mtx);
|
||||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
if (source_code.empty()) {
|
std::string source_code = source_builder();
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string kernel_file_name;
|
std::string kernel_file_name;
|
||||||
|
|
||||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||||
@ -90,7 +99,7 @@ void* compile(
|
|||||||
source_file.close();
|
source_file.close();
|
||||||
|
|
||||||
std::ostringstream build_command;
|
std::ostringstream build_command;
|
||||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
|
||||||
<< source_file_path << " -o " << shared_lib_path;
|
<< source_file_path << " -o " << shared_lib_path;
|
||||||
std::string build_command_str = build_command.str();
|
std::string build_command_str = build_command.str();
|
||||||
auto return_code = system(build_command_str.c_str());
|
auto return_code = system(build_command_str.c_str());
|
||||||
@ -316,10 +325,7 @@ void Compiled::eval_cpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the function
|
// Get the function
|
||||||
auto fn_ptr = compile(kernel_name);
|
auto fn_ptr = compile(kernel_name, [&]() {
|
||||||
|
|
||||||
// If it doesn't exist, compile it
|
|
||||||
if (fn_ptr == nullptr) {
|
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << get_kernel_preamble() << std::endl;
|
kernel << get_kernel_preamble() << std::endl;
|
||||||
kernel << "extern \"C\" {" << std::endl;
|
kernel << "extern \"C\" {" << std::endl;
|
||||||
@ -334,10 +340,8 @@ void Compiled::eval_cpu(
|
|||||||
ndim);
|
ndim);
|
||||||
// Close extern "C"
|
// Close extern "C"
|
||||||
kernel << "}" << std::endl;
|
kernel << "}" << std::endl;
|
||||||
|
return kernel.str();
|
||||||
// Compile and get function pointer
|
});
|
||||||
fn_ptr = compile(kernel_name, kernel.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
compiled_allocate_outputs(
|
compiled_allocate_outputs(
|
||||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||||
|
@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
EOM
|
EOM
|
||||||
|
CC_FLAGS=""
|
||||||
|
else
|
||||||
|
CC_FLAGS="-std=c++17"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||||
|
|
||||||
cat << EOF > "$OUTPUT_FILE"
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
const char* get_kernel_preamble() {
|
const char* get_kernel_preamble() {
|
||||||
|
10
mlx/fast.cpp
10
mlx/fast.cpp
@ -682,8 +682,10 @@ array pack_and_quantize(
|
|||||||
clip(
|
clip(
|
||||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||||
zero,
|
zero,
|
||||||
n_bins),
|
n_bins,
|
||||||
uint32);
|
s),
|
||||||
|
uint32,
|
||||||
|
s);
|
||||||
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
||||||
packed_w = sum(
|
packed_w = sum(
|
||||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||||
@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
|||||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||||
array scales =
|
array scales =
|
||||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||||
scales = where(mask, scales, negative(scales), s);
|
scales = where(mask, scales, negative(scales, s), s);
|
||||||
array edge = where(mask, w_min, w_max, s);
|
array edge = where(mask, w_min, w_max, s);
|
||||||
array q0 = round(divide(edge, scales, s), s);
|
array q0 = round(divide(edge, scales, s), s);
|
||||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||||
array biases = where(equal(q0, zero, s), zero, edge);
|
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||||
|
|
||||||
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
|
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
|
||||||
return {
|
return {
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include "mlx/io/load.h"
|
#include "mlx/io/load.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
@ -58,6 +59,8 @@ std::string dtype_to_safetensor_str(Dtype t) {
|
|||||||
return ST_BOOL;
|
return ST_BOOL;
|
||||||
case complex64:
|
case complex64:
|
||||||
return ST_C64;
|
return ST_C64;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[save_safetensors] received invalid dtype.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,9 +172,18 @@ void save_safetensors(
|
|||||||
_metadata[key] = value;
|
_metadata[key] = value;
|
||||||
}
|
}
|
||||||
parent["__metadata__"] = _metadata;
|
parent["__metadata__"] = _metadata;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<array> to_eval;
|
||||||
|
to_eval.reserve(a.size());
|
||||||
|
for (auto& [_, arr] : a) {
|
||||||
|
to_eval.push_back(arr);
|
||||||
|
}
|
||||||
|
eval(std::move(to_eval));
|
||||||
|
}
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval();
|
|
||||||
if (arr.nbytes() == 0) {
|
if (arr.nbytes() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||||
|
@ -142,12 +142,13 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
|||||||
case float16:
|
case float16:
|
||||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
throw nb::type_error(
|
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||||
"bfloat16 arrays cannot be converted directly to NumPy.");
|
|
||||||
case float32:
|
case float32:
|
||||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||||
case complex64:
|
case complex64:
|
||||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||||
|
default:
|
||||||
|
throw nb::type_error("type cannot be converted to NumPy.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,6 +196,8 @@ nb::object to_scalar(array& a) {
|
|||||||
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||||
case complex64:
|
case complex64:
|
||||||
return nb::cast(a.item<std::complex<float>>());
|
return nb::cast(a.item<std::complex<float>>());
|
||||||
|
default:
|
||||||
|
throw nb::type_error("type cannot be converted to Python scalar.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,6 +251,8 @@ nb::object tolist(array& a) {
|
|||||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||||
case complex64:
|
case complex64:
|
||||||
return to_list<std::complex<float>>(a, 0, 0);
|
return to_list<std::complex<float>>(a, 0, 0);
|
||||||
|
default:
|
||||||
|
throw nb::type_error("data type cannot be converted to Python list.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1308,8 +1308,8 @@ void init_ops(nb::module_& m) {
|
|||||||
"start"_a,
|
"start"_a,
|
||||||
"stop"_a,
|
"stop"_a,
|
||||||
"step"_a = nb::none(),
|
"step"_a = nb::none(),
|
||||||
nb::kw_only(),
|
|
||||||
"dtype"_a = nb::none(),
|
"dtype"_a = nb::none(),
|
||||||
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
@ -1356,8 +1356,8 @@ void init_ops(nb::module_& m) {
|
|||||||
},
|
},
|
||||||
"stop"_a,
|
"stop"_a,
|
||||||
"step"_a = nb::none(),
|
"step"_a = nb::none(),
|
||||||
nb::kw_only(),
|
|
||||||
"dtype"_a = nb::none(),
|
"dtype"_a = nb::none(),
|
||||||
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
||||||
|
Loading…
Reference in New Issue
Block a user