mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[CUDA] Add more ways finding CCCL headers in JIT (#2382)
This commit is contained in:
		| @@ -52,13 +52,29 @@ const std::string& cuda_home() { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Return the location of CCCL headers shipped with the distribution. | // Return the location of CCCL headers shipped with the distribution. | ||||||
| bool get_cccl_include(std::string* out) { | const std::string& cccl_dir() { | ||||||
|   auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl"; |   static std::string dir = []() { | ||||||
|   if (!std::filesystem::exists(cccl_headers)) { |     std::filesystem::path path; | ||||||
|     return false; | #if defined(MLX_CCCL_DIR) | ||||||
|  |     // First search the install dir if defined. | ||||||
|  |     path = MLX_CCCL_DIR; | ||||||
|  |     if (std::filesystem::exists(path)) { | ||||||
|  |       return path.string(); | ||||||
|     } |     } | ||||||
|   *out = fmt::format("--include-path={}", cccl_headers.string()); | #endif | ||||||
|   return true; |     // Then search dynamically from the dir of libmlx.so file. | ||||||
|  |     path = current_binary_dir().parent_path() / "include" / "cccl"; | ||||||
|  |     if (std::filesystem::exists(path)) { | ||||||
|  |       return path.string(); | ||||||
|  |     } | ||||||
|  |     // Finally check the environment variable. | ||||||
|  |     path = std::getenv("MLX_CCCL_DIR"); | ||||||
|  |     if (!path.empty() && std::filesystem::exists(path)) { | ||||||
|  |       return path.string(); | ||||||
|  |     } | ||||||
|  |     return std::string(); | ||||||
|  |   }(); | ||||||
|  |   return dir; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Get the cache directory for storing compiled results. | // Get the cache directory for storing compiled results. | ||||||
| @@ -238,8 +254,9 @@ JitModule::JitModule( | |||||||
|         device.compute_capability_major(), |         device.compute_capability_major(), | ||||||
|         device.compute_capability_minor()); |         device.compute_capability_minor()); | ||||||
|     args.push_back(compute.c_str()); |     args.push_back(compute.c_str()); | ||||||
|     std::string cccl_include; |     std::string cccl_include = cccl_dir(); | ||||||
|     if (get_cccl_include(&cccl_include)) { |     if (!cccl_include.empty()) { | ||||||
|  |       cccl_include = fmt::format("--include-path={}", cccl_include); | ||||||
|       args.push_back(cccl_include.c_str()); |       args.push_back(cccl_include.c_str()); | ||||||
|     } |     } | ||||||
|     std::string cuda_include = |     std::string cuda_include = | ||||||
|   | |||||||
| @@ -39,6 +39,14 @@ target_sources( | |||||||
|           linalg_tests.cpp |           linalg_tests.cpp | ||||||
|           ${METAL_TEST_SOURCES}) |           ${METAL_TEST_SOURCES}) | ||||||
|  |  | ||||||
|  | if(MLX_BUILD_CUDA) | ||||||
|  |   # Find the CCCL headers in install dir. | ||||||
|  |   target_compile_definitions( | ||||||
|  |     mlx | ||||||
|  |     PRIVATE | ||||||
|  |       MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl") | ||||||
|  | endif() | ||||||
|  |  | ||||||
| target_link_libraries(tests PRIVATE mlx doctest) | target_link_libraries(tests PRIVATE mlx doctest) | ||||||
| doctest_discover_tests(tests) | doctest_discover_tests(tests) | ||||||
| add_test(NAME tests COMMAND tests) | add_test(NAME tests COMMAND tests) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Cheng
					Cheng