mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
[CUDA] Add more ways finding CCCL headers in JIT (#2382)
This commit is contained in:
parent
fbb3f65a1a
commit
31fc530c76
@ -52,13 +52,29 @@ const std::string& cuda_home() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of CCCL headers shipped with the distribution.
|
// Return the location of CCCL headers shipped with the distribution.
|
||||||
bool get_cccl_include(std::string* out) {
|
const std::string& cccl_dir() {
|
||||||
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
|
static std::string dir = []() {
|
||||||
if (!std::filesystem::exists(cccl_headers)) {
|
std::filesystem::path path;
|
||||||
return false;
|
#if defined(MLX_CCCL_DIR)
|
||||||
}
|
// First search the install dir if defined.
|
||||||
*out = fmt::format("--include-path={}", cccl_headers.string());
|
path = MLX_CCCL_DIR;
|
||||||
return true;
|
if (std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
// Then search dynamically from the dir of libmlx.so file.
|
||||||
|
path = current_binary_dir().parent_path() / "include" / "cccl";
|
||||||
|
if (std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
// Finally check the environment variable.
|
||||||
|
path = std::getenv("MLX_CCCL_DIR");
|
||||||
|
if (!path.empty() && std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
return std::string();
|
||||||
|
}();
|
||||||
|
return dir;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the cache directory for storing compiled results.
|
// Get the cache directory for storing compiled results.
|
||||||
@ -238,8 +254,9 @@ JitModule::JitModule(
|
|||||||
device.compute_capability_major(),
|
device.compute_capability_major(),
|
||||||
device.compute_capability_minor());
|
device.compute_capability_minor());
|
||||||
args.push_back(compute.c_str());
|
args.push_back(compute.c_str());
|
||||||
std::string cccl_include;
|
std::string cccl_include = cccl_dir();
|
||||||
if (get_cccl_include(&cccl_include)) {
|
if (!cccl_include.empty()) {
|
||||||
|
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||||
args.push_back(cccl_include.c_str());
|
args.push_back(cccl_include.c_str());
|
||||||
}
|
}
|
||||||
std::string cuda_include =
|
std::string cuda_include =
|
||||||
|
@ -39,6 +39,14 @@ target_sources(
|
|||||||
linalg_tests.cpp
|
linalg_tests.cpp
|
||||||
${METAL_TEST_SOURCES})
|
${METAL_TEST_SOURCES})
|
||||||
|
|
||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
# Find the CCCL headers in install dir.
|
||||||
|
target_compile_definitions(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_libraries(tests PRIVATE mlx doctest)
|
target_link_libraries(tests PRIVATE mlx doctest)
|
||||||
doctest_discover_tests(tests)
|
doctest_discover_tests(tests)
|
||||||
add_test(NAME tests COMMAND tests)
|
add_test(NAME tests COMMAND tests)
|
||||||
|
Loading…
Reference in New Issue
Block a user