[CUDA] Add more ways finding CCCL headers in JIT (#2382)

This commit is contained in:
Cheng 2025-07-18 07:25:34 +09:00 committed by GitHub
parent fbb3f65a1a
commit 31fc530c76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 9 deletions

View File

@ -52,13 +52,29 @@ const std::string& cuda_home() {
} }
// Return the location of CCCL headers shipped with the distribution. // Return the location of CCCL headers shipped with the distribution.
bool get_cccl_include(std::string* out) { const std::string& cccl_dir() {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl"; static std::string dir = []() {
if (!std::filesystem::exists(cccl_headers)) { std::filesystem::path path;
return false; #if defined(MLX_CCCL_DIR)
} // First search the install dir if defined.
*out = fmt::format("--include-path={}", cccl_headers.string()); path = MLX_CCCL_DIR;
return true; if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
} }
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
@ -238,8 +254,9 @@ JitModule::JitModule(
device.compute_capability_major(), device.compute_capability_major(),
device.compute_capability_minor()); device.compute_capability_minor());
args.push_back(compute.c_str()); args.push_back(compute.c_str());
std::string cccl_include; std::string cccl_include = cccl_dir();
if (get_cccl_include(&cccl_include)) { if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str()); args.push_back(cccl_include.c_str());
} }
std::string cuda_include = std::string cuda_include =

View File

@ -39,6 +39,14 @@ target_sources(
linalg_tests.cpp linalg_tests.cpp
${METAL_TEST_SOURCES}) ${METAL_TEST_SOURCES})
if(MLX_BUILD_CUDA)
# Find the CCCL headers in install dir.
target_compile_definitions(
mlx
PRIVATE
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
endif()
target_link_libraries(tests PRIVATE mlx doctest) target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests) doctest_discover_tests(tests)
add_test(NAME tests COMMAND tests) add_test(NAME tests COMMAND tests)