[CUDA] Bundle CCCL for JIT compilation (#2357)

* Ship CCCL for JIT compilation

* Remove cexpf
This commit is contained in:
Cheng
2025-07-12 10:45:37 +09:00
committed by GitHub
parent 42cc9cfbc7
commit 6325f60d52
9 changed files with 48 additions and 176 deletions

View File

@@ -13,6 +13,7 @@
#include <fmt/format.h>
#include <nvrtc.h>
#include <unistd.h>
namespace mlx::core::cu {
@@ -50,6 +51,16 @@ const std::string& cuda_home() {
return home;
}
// Return the location of CCCL headers shipped with the distribution.
bool get_cccl_include(std::string* out) {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
if (!std::filesystem::exists(cccl_headers)) {
return false;
}
*out = fmt::format("--include-path={}", cccl_headers.string());
return true;
}
// Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() {
static std::filesystem::path cache = []() -> std::filesystem::path {
@@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",
INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "cexpf.cuh",
INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh",
@@ -178,7 +188,6 @@ constexpr const char* g_headers[] = {
jit_source_atomic_ops,
jit_source_binary_ops,
jit_source_cast_op,
jit_source_cexpf,
jit_source_config,
jit_source_cucomplex_math,
jit_source_fp16_math,
@@ -217,16 +226,23 @@ JitModule::JitModule(
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
std::string include = fmt::format("--include-path={}/include", cuda_home());
const char* args[] = {compute.c_str(), include.c_str()};
args.push_back(compute.c_str());
std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, std::size(args), args);
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));