A few updates for CPU (#1482)

* some updates

* format

* fix

* nit
This commit is contained in:
Awni Hannun
2024-10-14 12:45:49 -07:00
committed by GitHub
parent 881615b072
commit 020f048cd0
6 changed files with 50 additions and 25 deletions

View File

@@ -4,6 +4,8 @@
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
@@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) {
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
const std::function<std::string(void)>& source_builder) {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -46,13 +48,20 @@ void* compile(
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
static std::shared_mutex compile_mtx;
{
std::shared_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
}
std::unique_lock lock(compile_mtx);
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,7 +99,7 @@ void* compile(
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
@@ -316,10 +325,7 @@ void Compiled::eval_cpu(
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +340,8 @@ void Compiled::eval_cpu(
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);

View File

@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {