mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90 * add more archs
This commit is contained in:
		@@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
 | 
			
		||||
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
# Compute capability 7 is required for synchronization between CPU/GPU with
 | 
			
		||||
# managed memory. TODO: Add more architectures for potential performance gain.
 | 
			
		||||
set(MLX_CUDA_ARCHITECTURES
 | 
			
		||||
    "70;80"
 | 
			
		||||
    CACHE STRING "CUDA architectures")
 | 
			
		||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
 | 
			
		||||
# managed memory.
 | 
			
		||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
 | 
			
		||||
  set(MLX_CUDA_ARCHITECTURES "native")
 | 
			
		||||
endif()
 | 
			
		||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 | 
			
		||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
 | 
			
		||||
                                     "${MLX_CUDA_ARCHITECTURES}")
 | 
			
		||||
 
 | 
			
		||||
@@ -97,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
 | 
			
		||||
  if (discard) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Extract and add as single kernel node when possible.
 | 
			
		||||
  size_t num_nodes;
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
 | 
			
		||||
  if (num_nodes == 1) {
 | 
			
		||||
    cudaGraphNode_t captured_node;
 | 
			
		||||
    CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
 | 
			
		||||
    cudaGraphNodeType type;
 | 
			
		||||
    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
 | 
			
		||||
    if (type == cudaGraphNodeTypeKernel) {
 | 
			
		||||
      CUDA_KERNEL_NODE_PARAMS params;
 | 
			
		||||
      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
 | 
			
		||||
      enc.add_kernel_node(params);
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // Otherwise add the captured graph as subgraph.
 | 
			
		||||
  enc.add_graph_node(graph);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
 | 
			
		||||
#if __CUDA_ARCH__ < 900
 | 
			
		||||
  atomic_add_general(out, val);
 | 
			
		||||
#else
 | 
			
		||||
  atomicAdd(out, val);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										12
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								setup.py
									
									
									
									
									
								
							@@ -44,6 +44,8 @@ def get_version():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
 | 
			
		||||
build_macos = platform.system() == "Darwin"
 | 
			
		||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A CMakeExtension needs a sourcedir instead of a file list.
 | 
			
		||||
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
 | 
			
		||||
            "-DMLX_BUILD_EXAMPLES=OFF",
 | 
			
		||||
            f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
 | 
			
		||||
        ]
 | 
			
		||||
        if build_stage == 2 and build_cuda:
 | 
			
		||||
            # Last arch is always real and virtual for forward-compatibility
 | 
			
		||||
            cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
 | 
			
		||||
            cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
 | 
			
		||||
 | 
			
		||||
        # Some generators require explicitly passing config when building.
 | 
			
		||||
        build_args = ["--config", cfg]
 | 
			
		||||
        # Adding CMake arguments set as environment variable
 | 
			
		||||
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
 | 
			
		||||
        # Pass version to C++
 | 
			
		||||
        cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"]  # type: ignore[attr-defined]
 | 
			
		||||
 | 
			
		||||
        if platform.system() == "Darwin":
 | 
			
		||||
        if build_macos:
 | 
			
		||||
            # Cross-compile support for macOS - respect ARCHFLAGS if set
 | 
			
		||||
            archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
 | 
			
		||||
            if archs:
 | 
			
		||||
@@ -202,9 +209,6 @@ if __name__ == "__main__":
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    build_macos = platform.system() == "Darwin"
 | 
			
		||||
    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
 | 
			
		||||
 | 
			
		||||
    version = get_version()
 | 
			
		||||
 | 
			
		||||
    _setup = partial(
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user