mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Make mx.compile work on Windows (#1697)
* Invoke MSVC on Windows in mx.compile * Export kernel symbol on MSVC * Remove unused template * Parse env pairs in a robust way * No need of cassert * Remove unnecessary helpers * Fix right trim * Move command building to a separate file * Missing header * Do not pollute cwd with cl.exe * Simplify str concat * Pass output dir * Fix styling
This commit is contained in:
parent
88f993da38
commit
935c8c4bb1
@ -66,5 +66,6 @@ target_sources(
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
|
||||
endif()
|
||||
|
@ -9,6 +9,7 @@
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
|
||||
@ -44,11 +45,8 @@ namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name).string();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
@ -88,9 +86,10 @@ void* compile(
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
std::ostringstream shared_lib_name;
|
||||
shared_lib_name << "lib" << kernel_file_name << ".so";
|
||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||
auto output_dir = std::filesystem::temp_directory_path();
|
||||
|
||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||
bool lib_exists = false;
|
||||
{
|
||||
std::ifstream f(shared_lib_path.c_str());
|
||||
@ -99,19 +98,16 @@ void* compile(
|
||||
|
||||
if (!lib_exists) {
|
||||
// Open source file and write source code to it
|
||||
std::ostringstream source_file_name;
|
||||
source_file_name << kernel_file_name << ".cpp";
|
||||
auto source_file_path = get_temp_file(source_file_name.str());
|
||||
std::string source_file_name = kernel_file_name + ".cpp";
|
||||
auto source_file_path = (output_dir / source_file_name).string();
|
||||
|
||||
std::ofstream source_file(source_file_path);
|
||||
source_file << source_code;
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
std::string command = JitCompiler::build_command(
|
||||
output_dir, source_file_name, shared_lib_name);
|
||||
auto return_code = system(command.c_str());
|
||||
if (return_code) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
|
||||
@ -156,6 +152,11 @@ inline void build_kernel(
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// Export the symbol
|
||||
os << "__declspec(dllexport) ";
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
|
||||
|
128
mlx/backend/common/jit_compiler.cpp
Normal file
128
mlx/backend/common/jit_compiler.cpp
Normal file
@ -0,0 +1,128 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
||||
namespace {
|
||||
|
||||
// Split string into array.
|
||||
std::vector<std::string> str_split(const std::string& str, char delimiter) {
|
||||
std::vector<std::string> tokens;
|
||||
std::string token;
|
||||
std::istringstream tokenStream(str);
|
||||
while (std::getline(tokenStream, token, delimiter)) {
|
||||
tokens.push_back(token);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
// Run a command and get its output.
|
||||
std::string exec(const std::string& cmd) {
|
||||
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
|
||||
_popen(cmd.c_str(), "r"), _pclose);
|
||||
if (!pipe) {
|
||||
throw std::runtime_error("popen() failed.");
|
||||
}
|
||||
char buffer[128];
|
||||
std::string ret;
|
||||
while (fgets(buffer, sizeof(buffer), pipe.get())) {
|
||||
ret += buffer;
|
||||
}
|
||||
// Trim trailing spaces.
|
||||
ret.erase(
|
||||
std::find_if(
|
||||
ret.rbegin(),
|
||||
ret.rend(),
|
||||
[](unsigned char ch) { return !std::isspace(ch); })
|
||||
.base(),
|
||||
ret.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Get path information about MSVC.
|
||||
struct VisualStudioInfo {
|
||||
VisualStudioInfo() {
|
||||
#ifdef _M_ARM64
|
||||
arch = "arm64";
|
||||
#else
|
||||
arch = "x64";
|
||||
#endif
|
||||
// Get path of Visual Studio.
|
||||
std::string vs_path = exec(fmt::format(
|
||||
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
|
||||
" -property installationPath",
|
||||
std::getenv("ProgramFiles(x86)")));
|
||||
if (vs_path.empty()) {
|
||||
throw std::runtime_error("Can not find Visual Studio.");
|
||||
}
|
||||
// Read the envs from vcvarsall.
|
||||
std::string envs = exec(fmt::format(
|
||||
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
|
||||
vs_path,
|
||||
arch));
|
||||
for (const std::string& line : str_split(envs, '\n')) {
|
||||
// Each line is in the format "ENV_NAME=values".
|
||||
auto pos = line.find_first_of('=');
|
||||
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
|
||||
continue;
|
||||
std::string name = line.substr(0, pos);
|
||||
std::string value = line.substr(pos + 1);
|
||||
if (name == "LIB") {
|
||||
libpaths = str_split(value, ';');
|
||||
} else if (name == "VCToolsInstallDir") {
|
||||
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string arch;
|
||||
std::string cl_exe;
|
||||
std::vector<std::string> libpaths;
|
||||
};
|
||||
|
||||
const VisualStudioInfo& GetVisualStudioInfo() {
|
||||
static VisualStudioInfo info;
|
||||
return info;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // _MSC_VER
|
||||
|
||||
std::string JitCompiler::build_command(
|
||||
const std::filesystem::path& dir,
|
||||
const std::string& source_file_name,
|
||||
const std::string& shared_lib_name) {
|
||||
#ifdef _MSC_VER
|
||||
const VisualStudioInfo& info = GetVisualStudioInfo();
|
||||
std::string libpaths;
|
||||
for (const std::string& lib : info.libpaths) {
|
||||
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
|
||||
}
|
||||
return fmt::format(
|
||||
"\""
|
||||
"cd /D \"{0}\" && "
|
||||
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
|
||||
"/link /out:\"{3}\" {4} >nul"
|
||||
"\"",
|
||||
dir.string(),
|
||||
info.cl_exe,
|
||||
source_file_name,
|
||||
shared_lib_name,
|
||||
libpaths);
|
||||
#else
|
||||
return fmt::format(
|
||||
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
|
||||
(dir / source_file_name).string(),
|
||||
(dir / shared_lib_name).string());
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
17
mlx/backend/common/jit_compiler.h
Normal file
17
mlx/backend/common/jit_compiler.h
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
class JitCompiler {
|
||||
public:
|
||||
// Build a shell command that compiles a source code file to a shared library.
|
||||
static std::string build_command(
|
||||
const std::filesystem::path& dir,
|
||||
const std::string& source_file_name,
|
||||
const std::string& shared_lib_name);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@ -13,7 +13,7 @@ $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/comp
|
||||
# Otherwise there will be too much empty lines making the result unreadable.
|
||||
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
|
||||
# Concatenate to string.
|
||||
$CONTENT = $CONTENT -join '`n'
|
||||
$CONTENT = $CONTENT -join "`n"
|
||||
|
||||
# Append extra content.
|
||||
$CONTENT = @"
|
||||
|
Loading…
Reference in New Issue
Block a user