mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -4,6 +4,8 @@
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
@@ -27,7 +29,7 @@ std::string get_temp_file(const std::string& name) {
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& source_code = "") {
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
struct DLib {
|
||||
DLib(const std::string& libname) {
|
||||
lib = dlopen(libname.c_str(), RTLD_NOW);
|
||||
@@ -46,13 +48,20 @@ void* compile(
|
||||
// Statics to cache compiled libraries and functions
|
||||
static std::list<DLib> libs;
|
||||
static std::unordered_map<std::string, void*> kernels;
|
||||
static std::shared_mutex compile_mtx;
|
||||
|
||||
{
|
||||
std::shared_lock lock(compile_mtx);
|
||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(compile_mtx);
|
||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
if (source_code.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string source_code = source_builder();
|
||||
std::string kernel_file_name;
|
||||
|
||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||
@@ -90,7 +99,7 @@ void* compile(
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
|
||||
<< source_file_path << " -o " << shared_lib_path;
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
@@ -316,10 +325,7 @@ void Compiled::eval_cpu(
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name);
|
||||
|
||||
// If it doesn't exist, compile it
|
||||
if (fn_ptr == nullptr) {
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -334,10 +340,8 @@ void Compiled::eval_cpu(
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
kernel << "}" << std::endl;
|
||||
|
||||
// Compile and get function pointer
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
||||
@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
|
||||
CC_FLAGS=""
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
||||
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
const char* get_kernel_preamble() {
|
||||
|
||||
Reference in New Issue
Block a user