A few updates for CPU (#1482)

* some updates

* format

* fix

* nit
This commit is contained in:
Awni Hannun 2024-10-14 12:45:49 -07:00 committed by GitHub
parent 881615b072
commit 020f048cd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 50 additions and 25 deletions

View File

@ -4,6 +4,8 @@
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h" #include "mlx/backend/common/compiled_preamble.h"
@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) {
// Return a pointer to a compiled function // Return a pointer to a compiled function
void* compile( void* compile(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& source_code = "") { const std::function<std::string(void)>& source_builder) {
struct DLib { struct DLib {
DLib(const std::string& libname) { DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW); lib = dlopen(libname.c_str(), RTLD_NOW);
@ -46,13 +48,20 @@ void* compile(
// Statics to cache compiled libraries and functions // Statics to cache compiled libraries and functions
static std::list<DLib> libs; static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels; static std::unordered_map<std::string, void*> kernels;
static std::shared_mutex compile_mtx;
{
std::shared_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) { if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second; return it->second;
} }
if (source_code.empty()) {
return nullptr;
} }
std::unique_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name; std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255 // Deal with long kernel names. Maximum length for files on macOS is 255
@ -90,7 +99,7 @@ void* compile(
source_file.close(); source_file.close();
std::ostringstream build_command; std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path; << source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str(); std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str()); auto return_code = system(build_command_str.c_str());
@ -316,10 +325,7 @@ void Compiled::eval_cpu(
} }
// Get the function // Get the function
auto fn_ptr = compile(kernel_name); auto fn_ptr = compile(kernel_name, [&]() {
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel; std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl; kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl; kernel << "extern \"C\" {" << std::endl;
@ -334,10 +340,8 @@ void Compiled::eval_cpu(
ndim); ndim);
// Close extern "C" // Close extern "C"
kernel << "}" << std::endl; kernel << "}" << std::endl;
return kernel.str();
// Compile and get function pointer });
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs( compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false); inputs, outputs, inputs_, constant_ids_, contiguous, false);

View File

@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
EOM EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null) CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE" cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() { const char* get_kernel_preamble() {

View File

@ -682,8 +682,10 @@ array pack_and_quantize(
clip( clip(
round(divide(subtract(packed_w, biases, s), scales, s), s), round(divide(subtract(packed_w, biases, s), scales, s), s),
zero, zero,
n_bins), n_bins,
uint32); s),
uint32,
s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum( packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
array mask = greater(abs(w_min, s), abs(w_max, s), s); array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales = array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s); scales = where(mask, scales, negative(scales, s), s);
array edge = where(mask, w_min, w_max, s); array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s); array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge); array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return { return {

View File

@ -7,6 +7,7 @@
#include "mlx/io/load.h" #include "mlx/io/load.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/transforms.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -58,6 +59,8 @@ std::string dtype_to_safetensor_str(Dtype t) {
return ST_BOOL; return ST_BOOL;
case complex64: case complex64:
return ST_C64; return ST_C64;
default:
throw std::runtime_error("[save_safetensors] received invalid dtype.");
} }
} }
@ -169,9 +172,18 @@ void save_safetensors(
_metadata[key] = value; _metadata[key] = value;
} }
parent["__metadata__"] = _metadata; parent["__metadata__"] = _metadata;
{
std::vector<array> to_eval;
to_eval.reserve(a.size());
for (auto& [_, arr] : a) {
to_eval.push_back(arr);
}
eval(std::move(to_eval));
}
size_t offset = 0; size_t offset = 0;
for (auto& [key, arr] : a) { for (auto& [key, arr] : a) {
arr.eval();
if (arr.nbytes() == 0) { if (arr.nbytes() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key); "[save_safetensors] cannot serialize an empty array key: " + key);

View File

@ -142,12 +142,13 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
case float16: case float16:
return mlx_to_nd_array_impl<float16_t, NDParams...>(a); return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
case bfloat16: case bfloat16:
throw nb::type_error( throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
"bfloat16 arrays cannot be converted directly to NumPy.");
case float32: case float32:
return mlx_to_nd_array_impl<float, NDParams...>(a); return mlx_to_nd_array_impl<float, NDParams...>(a);
case complex64: case complex64:
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a); return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
default:
throw nb::type_error("type cannot be converted to NumPy.");
} }
} }
@ -195,6 +196,8 @@ nb::object to_scalar(array& a) {
return nb::cast(static_cast<float>(a.item<bfloat16_t>())); return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64: case complex64:
return nb::cast(a.item<std::complex<float>>()); return nb::cast(a.item<std::complex<float>>());
default:
throw nb::type_error("type cannot be converted to Python scalar.");
} }
} }
@ -248,6 +251,8 @@ nb::object tolist(array& a) {
return to_list<bfloat16_t, float>(a, 0, 0); return to_list<bfloat16_t, float>(a, 0, 0);
case complex64: case complex64:
return to_list<std::complex<float>>(a, 0, 0); return to_list<std::complex<float>>(a, 0, 0);
default:
throw nb::type_error("data type cannot be converted to Python list.");
} }
} }

View File

@ -1308,8 +1308,8 @@ void init_ops(nb::module_& m) {
"start"_a, "start"_a,
"stop"_a, "stop"_a,
"step"_a = nb::none(), "step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(), "dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
@ -1356,8 +1356,8 @@ void init_ops(nb::module_& m) {
}, },
"stop"_a, "stop"_a,
"step"_a = nb::none(), "step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(), "dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array")); "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));