Add more CUDA architectures for PyPI package (#2427)

* add cuda sm 90

* add more archs
This commit is contained in:
Awni Hannun 2025-07-28 12:35:15 -07:00 committed by GitHub
parent ab0e608862
commit 641be9463b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 13 additions and 30 deletions

View File

@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>") mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif() endif()
# Compute capability 7 is required for synchronization between CPU/GPU with # Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain. # managed memory.
set(MLX_CUDA_ARCHITECTURES if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
"70;80" set(MLX_CUDA_ARCHITECTURES "native")
CACHE STRING "CUDA architectures") endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}") "${MLX_CUDA_ARCHITECTURES}")

View File

@ -97,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
if (discard) { if (discard) {
return; return;
} }
// Extract and add as single kernel node when possible.
size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) {
cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
if (type == cudaGraphNodeTypeKernel) {
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
enc.add_kernel_node(params);
return;
}
}
// Otherwise add the captured graph as subgraph.
enc.add_graph_node(graph); enc.add_graph_node(graph);
} }

View File

@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
} }
inline __device__ void atomic_add(complex64_t* out, complex64_t val) { inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val); atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
} }
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

View File

@ -44,6 +44,8 @@ def get_version():
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0)) build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
# A CMakeExtension needs a sourcedir instead of a file list. # A CMakeExtension needs a sourcedir instead of a file list.
@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
"-DMLX_BUILD_EXAMPLES=OFF", "-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}", f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
] ]
if build_stage == 2 and build_cuda:
# The last arch has no "-real" suffix, so both real (SASS) and virtual (PTX) code are generated for it, for forward-compatibility
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
# Some generators require explicitly passing config when building. # Some generators require explicitly passing config when building.
build_args = ["--config", cfg] build_args = ["--config", cfg]
# Adding CMake arguments set as environment variable # Adding CMake arguments set as environment variable
@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
# Pass version to C++ # Pass version to C++
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined] cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
if platform.system() == "Darwin": if build_macos:
# Cross-compile support for macOS - respect ARCHFLAGS if set # Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs: if archs:
@ -202,9 +209,6 @@ if __name__ == "__main__":
], ],
) )
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
version = get_version() version = get_version()
_setup = partial( _setup = partial(