use arch specific targets when possible (#2771)

This commit is contained in:
Awni Hannun
2025-11-14 20:04:18 -08:00
committed by GitHub
parent 3c622ddd1d
commit 1bf605d56d
4 changed files with 34 additions and 5 deletions

View File

@@ -126,7 +126,11 @@ endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
execute_process(
COMMAND bash detect_cuda_arch.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES

View File

@@ -0,0 +1,13 @@
#!/bin/bash
arch=`__nvcc_device_query`
case "$arch" in
"90")
echo "90a" ;;
"100")
echo "100a" ;;
"121")
echo "121a" ;;
*)
echo "native" ;;
esac

View File

@@ -279,11 +279,14 @@ void compile(
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
auto cc = device.compute_capability_major();
std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
"--gpu-architecture={}_{}{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
cc,
device.compute_capability_minor(),
arch_tag);
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {

View File

@@ -89,7 +89,16 @@ class CMakeBuild(build_ext):
]
if build_stage == 2 and build_cuda:
# Last arch is always real and virtual for forward-compatibility
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
cuda_archs = ";".join(
(
"75-real",
"80-real",
"90a-real",
"100a-real",
"120a-real",
"120-virtual",
)
)
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
# Some generators require explcitly passing config when building.