A few updates for CPU (#1482)

* some updates

* format

* fix

* nit
This commit is contained in:
Awni Hannun
2024-10-14 12:45:49 -07:00
committed by GitHub
parent 881615b072
commit 020f048cd0
6 changed files with 50 additions and 25 deletions

View File

@@ -4,6 +4,8 @@
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
@@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) {
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
const std::function<std::string(void)>& source_builder) {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -46,13 +48,20 @@ void* compile(
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
static std::shared_mutex compile_mtx;
{
std::shared_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
}
std::unique_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,7 +99,7 @@ void* compile(
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
@@ -316,10 +325,7 @@ void Compiled::eval_cpu(
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +340,8 @@ void Compiled::eval_cpu(
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);

View File

@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -682,8 +682,10 @@ array pack_and_quantize(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
n_bins,
s),
uint32,
s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
@@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
scales = where(mask, scales, negative(scales, s), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {

View File

@@ -7,6 +7,7 @@
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
using json = nlohmann::json;
@@ -58,6 +59,8 @@ std::string dtype_to_safetensor_str(Dtype t) {
return ST_BOOL;
case complex64:
return ST_C64;
default:
throw std::runtime_error("[save_safetensors] received invalid dtype.");
}
}
@@ -169,9 +172,18 @@ void save_safetensors(
_metadata[key] = value;
}
parent["__metadata__"] = _metadata;
{
std::vector<array> to_eval;
to_eval.reserve(a.size());
for (auto& [_, arr] : a) {
to_eval.push_back(arr);
}
eval(std::move(to_eval));
}
size_t offset = 0;
for (auto& [key, arr] : a) {
arr.eval();
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);