mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-05 19:06:44 +08:00
Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90 * add more archs
This commit is contained in:
parent
ab0e608862
commit
641be9463b
@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
|||||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
# managed memory.
|
||||||
set(MLX_CUDA_ARCHITECTURES
|
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||||
"70;80"
|
set(MLX_CUDA_ARCHITECTURES "native")
|
||||||
CACHE STRING "CUDA architectures")
|
endif()
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
"${MLX_CUDA_ARCHITECTURES}")
|
"${MLX_CUDA_ARCHITECTURES}")
|
||||||
|
@ -97,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
|
|||||||
if (discard) {
|
if (discard) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and add as single kernel node when possible.
|
|
||||||
size_t num_nodes;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
|
||||||
if (num_nodes == 1) {
|
|
||||||
cudaGraphNode_t captured_node;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
|
||||||
cudaGraphNodeType type;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
|
|
||||||
if (type == cudaGraphNodeTypeKernel) {
|
|
||||||
CUDA_KERNEL_NODE_PARAMS params;
|
|
||||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
|
||||||
enc.add_kernel_node(params);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Otherwise add the captured graph as subgraph.
|
|
||||||
enc.add_graph_node(graph);
|
enc.add_graph_node(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||||
#if __CUDA_ARCH__ < 900
|
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
|
||||||
atomicAdd(out, val);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||||
|
12
setup.py
12
setup.py
@ -44,6 +44,8 @@ def get_version():
|
|||||||
|
|
||||||
|
|
||||||
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
|
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
|
||||||
|
build_macos = platform.system() == "Darwin"
|
||||||
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
|
||||||
|
|
||||||
# A CMakeExtension needs a sourcedir instead of a file list.
|
# A CMakeExtension needs a sourcedir instead of a file list.
|
||||||
@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
|
|||||||
"-DMLX_BUILD_EXAMPLES=OFF",
|
"-DMLX_BUILD_EXAMPLES=OFF",
|
||||||
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
||||||
]
|
]
|
||||||
|
if build_stage == 2 and build_cuda:
|
||||||
|
# Last arch is always real and virtual for forward-compatibility
|
||||||
|
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
|
||||||
|
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
|
||||||
|
|
||||||
# Some generators require explcitly passing config when building.
|
# Some generators require explcitly passing config when building.
|
||||||
build_args = ["--config", cfg]
|
build_args = ["--config", cfg]
|
||||||
# Adding CMake arguments set as environment variable
|
# Adding CMake arguments set as environment variable
|
||||||
@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
|
|||||||
# Pass version to C++
|
# Pass version to C++
|
||||||
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
|
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
|
||||||
|
|
||||||
if platform.system() == "Darwin":
|
if build_macos:
|
||||||
# Cross-compile support for macOS - respect ARCHFLAGS if set
|
# Cross-compile support for macOS - respect ARCHFLAGS if set
|
||||||
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
|
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
|
||||||
if archs:
|
if archs:
|
||||||
@ -202,9 +209,6 @@ if __name__ == "__main__":
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
build_macos = platform.system() == "Darwin"
|
|
||||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
|
||||||
|
|
||||||
version = get_version()
|
version = get_version()
|
||||||
|
|
||||||
_setup = partial(
|
_setup = partial(
|
||||||
|
Loading…
Reference in New Issue
Block a user