mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
use arch specific targets when possible (#2771)
This commit is contained in:
@@ -126,7 +126,11 @@ endif()
|
|||||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||||
# managed memory.
|
# managed memory.
|
||||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||||
set(MLX_CUDA_ARCHITECTURES "native")
|
execute_process(
|
||||||
|
COMMAND bash detect_cuda_arch.sh
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
|
|||||||
13
mlx/backend/cuda/detect_cuda_arch.sh
Normal file
13
mlx/backend/cuda/detect_cuda_arch.sh
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
arch=`__nvcc_device_query`
|
||||||
|
case "$arch" in
|
||||||
|
"90")
|
||||||
|
echo "90a" ;;
|
||||||
|
"100")
|
||||||
|
echo "100a" ;;
|
||||||
|
"121")
|
||||||
|
echo "121a" ;;
|
||||||
|
*)
|
||||||
|
echo "native" ;;
|
||||||
|
esac
|
||||||
@@ -279,11 +279,14 @@ void compile(
|
|||||||
// Compile program.
|
// Compile program.
|
||||||
std::vector<const char*> args;
|
std::vector<const char*> args;
|
||||||
bool use_sass = compiler_supports_device_sass(device);
|
bool use_sass = compiler_supports_device_sass(device);
|
||||||
|
auto cc = device.compute_capability_major();
|
||||||
|
std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
|
||||||
std::string compute = fmt::format(
|
std::string compute = fmt::format(
|
||||||
"--gpu-architecture={}_{}{}",
|
"--gpu-architecture={}_{}{}{}",
|
||||||
use_sass ? "sm" : "compute",
|
use_sass ? "sm" : "compute",
|
||||||
device.compute_capability_major(),
|
cc,
|
||||||
device.compute_capability_minor());
|
device.compute_capability_minor(),
|
||||||
|
arch_tag);
|
||||||
args.push_back(compute.c_str());
|
args.push_back(compute.c_str());
|
||||||
std::string cccl_include = cccl_dir();
|
std::string cccl_include = cccl_dir();
|
||||||
if (!cccl_include.empty()) {
|
if (!cccl_include.empty()) {
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -89,7 +89,16 @@ class CMakeBuild(build_ext):
|
|||||||
]
|
]
|
||||||
if build_stage == 2 and build_cuda:
|
if build_stage == 2 and build_cuda:
|
||||||
# Last arch is always real and virtual for forward-compatibility
|
# Last arch is always real and virtual for forward-compatibility
|
||||||
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
|
cuda_archs = ";".join(
|
||||||
|
(
|
||||||
|
"75-real",
|
||||||
|
"80-real",
|
||||||
|
"90a-real",
|
||||||
|
"100a-real",
|
||||||
|
"120a-real",
|
||||||
|
"120-virtual",
|
||||||
|
)
|
||||||
|
)
|
||||||
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
|
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
|
||||||
|
|
||||||
# Some generators require explicitly passing config when building.
|
# Some generators require explicitly passing config when building.
|
||||||
|
|||||||
Reference in New Issue
Block a user