From 31fc530c76e2b439e2526e34a30abd544858644f Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 18 Jul 2025 07:25:34 +0900 Subject: [PATCH] [CUDA] Add more ways finding CCCL headers in JIT (#2382) --- mlx/backend/cuda/jit_module.cpp | 35 ++++++++++++++++++++++++--------- tests/CMakeLists.txt | 8 ++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index a9e5631de..6585c452a 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -52,13 +52,29 @@ const std::string& cuda_home() { } // Return the location of CCCL headers shipped with the distribution. -bool get_cccl_include(std::string* out) { - auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl"; - if (!std::filesystem::exists(cccl_headers)) { - return false; - } - *out = fmt::format("--include-path={}", cccl_headers.string()); - return true; +const std::string& cccl_dir() { + static std::string dir = []() { + std::filesystem::path path; +#if defined(MLX_CCCL_DIR) + // First search the install dir if defined. + path = MLX_CCCL_DIR; + if (std::filesystem::exists(path)) { + return path.string(); + } +#endif + // Then search dynamically from the dir of libmlx.so file. + path = current_binary_dir().parent_path() / "include" / "cccl"; + if (std::filesystem::exists(path)) { + return path.string(); + } + // Finally check the environment variable. + path = std::getenv("MLX_CCCL_DIR"); + if (!path.empty() && std::filesystem::exists(path)) { + return path.string(); + } + return std::string(); + }(); + return dir; } // Get the cache directory for storing compiled results. @@ -238,8 +254,9 @@ JitModule::JitModule( device.compute_capability_major(), device.compute_capability_minor()); args.push_back(compute.c_str()); - std::string cccl_include; - if (get_cccl_include(&cccl_include)) { + std::string cccl_include = cccl_dir(); + if (!cccl_include.empty()) { + cccl_include = fmt::format("--include-path={}", cccl_include); args.push_back(cccl_include.c_str()); } std::string cuda_include = diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cb174865d..2c44bf4f6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,6 +39,14 @@ target_sources( linalg_tests.cpp ${METAL_TEST_SOURCES}) +if(MLX_BUILD_CUDA) + # Find the CCCL headers in install dir. + target_compile_definitions( + mlx + PRIVATE + MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl") +endif() + target_link_libraries(tests PRIVATE mlx doctest) doctest_discover_tests(tests) add_test(NAME tests COMMAND tests)