diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 650f038c8..e5c0156c8 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -7,6 +7,8 @@ #include #include +#include + #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" #include "mlx/backend/common/jit_compiler.h" @@ -105,14 +107,14 @@ void* compile( source_file << source_code; source_file.close(); - std::string command = JitCompiler::build_command( - output_dir, source_file_name, shared_lib_name); - auto return_code = system(command.c_str()); - if (return_code) { - std::ostringstream msg; - msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name - << " with error code " << return_code << "." << std::endl; - throw std::runtime_error(msg.str()); + try { + JitCompiler::exec(JitCompiler::build_command( + output_dir, source_file_name, shared_lib_name)); + } catch (const std::exception& error) { + throw std::runtime_error(fmt::format( + "[Compile::eval_cpu] Failed to compile function {0}: {1}", + kernel_name, + error.what())); } } diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp index 27fb9e723..34d57138c 100644 --- a/mlx/backend/common/jit_compiler.cpp +++ b/mlx/backend/common/jit_compiler.cpp @@ -24,29 +24,6 @@ std::vector str_split(const std::string& str, char delimiter) { return tokens; } -// Run a command and get its output. -std::string exec(const std::string& cmd) { - std::unique_ptr pipe( - _popen(cmd.c_str(), "r"), _pclose); - if (!pipe) { - throw std::runtime_error("popen() failed."); - } - char buffer[128]; - std::string ret; - while (fgets(buffer, sizeof(buffer), pipe.get())) { - ret += buffer; - } - // Trim trailing spaces. - ret.erase( - std::find_if( - ret.rbegin(), - ret.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) - .base(), - ret.end()); - return ret; -} - // Get path information about MSVC. struct VisualStudioInfo { VisualStudioInfo() { @@ -56,7 +33,7 @@ struct VisualStudioInfo { arch = "x64"; #endif // Get path of Visual Studio. - std::string vs_path = exec(fmt::format( + std::string vs_path = JitCompiler::exec(fmt::format( "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" " -property installationPath", std::getenv("ProgramFiles(x86)"))); @@ -64,7 +41,7 @@ struct VisualStudioInfo { throw std::runtime_error("Can not find Visual Studio."); } // Read the envs from vcvarsall. - std::string envs = exec(fmt::format( + std::string envs = JitCompiler::exec(fmt::format( "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", vs_path, arch)); @@ -110,7 +87,7 @@ std::string JitCompiler::build_command( "\"" "cd /D \"{0}\" && " "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " - "/link /out:\"{3}\" {4} >nul" + "/link /out:\"{3}\" {4} 2>&1" "\"", dir.string(), info.cl_exe, @@ -119,10 +96,57 @@ std::string JitCompiler::build_command( libpaths); #else return fmt::format( - "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", + "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1", (dir / source_file_name).string(), (dir / shared_lib_name).string()); #endif } +std::string JitCompiler::exec(const std::string& cmd) { +#ifdef _MSC_VER + FILE* pipe = _popen(cmd.c_str(), "r"); +#else + FILE* pipe = popen(cmd.c_str(), "r"); +#endif + if (!pipe) { + throw std::runtime_error("popen() failed."); + } + char buffer[128]; + std::string ret; + while (fgets(buffer, sizeof(buffer), pipe)) { + ret += buffer; + } + // Trim trailing spaces. + ret.erase( + std::find_if( + ret.rbegin(), + ret.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + ret.end()); + +#ifdef _MSC_VER + int status = _pclose(pipe); +#else + int status = pclose(pipe); +#endif + if (status == -1) { + throw std::runtime_error("pclose() failed."); + } +#ifdef _MSC_VER + int code = status; +#else + int code = WEXITSTATUS(status); +#endif + if (code != 0) { + throw std::runtime_error(fmt::format( + "Failed to execute command with return code {0}: \"{1}\", " + "the output is: {2}", + code, + cmd, + ret)); + } + return ret; +} + } // namespace mlx::core diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/common/jit_compiler.h index b0bf8c0de..3a9e988da 100644 --- a/mlx/backend/common/jit_compiler.h +++ b/mlx/backend/common/jit_compiler.h @@ -12,6 +12,9 @@ class JitCompiler { const std::filesystem::path& dir, const std::string& source_file_name, const std::string& shared_lib_name); + + // Run a command and get its output. + static std::string exec(const std::string& cmd); }; } // namespace mlx::core