From 1bf605d56df3312af6c173547ed4c4c199f11cf6 Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Fri, 14 Nov 2025 20:04:18 -0800
Subject: [PATCH] use arch specific targets when possible (#2771)

---
 mlx/backend/cuda/CMakeLists.txt      |  6 +++++-
 mlx/backend/cuda/detect_cuda_arch.sh | 13 +++++++++++++
 mlx/backend/cuda/jit_module.cpp      |  9 ++++++---
 setup.py                             | 11 ++++++++++-
 4 files changed, 34 insertions(+), 5 deletions(-)
 create mode 100644 mlx/backend/cuda/detect_cuda_arch.sh

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 892d46c51..a4606e5e3 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -126,7 +126,11 @@ endif()
 # Compute capability >= 7.0 is required for synchronization between CPU/GPU with
 # managed memory.
 if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
-  set(MLX_CUDA_ARCHITECTURES "native")
+  execute_process(
+    COMMAND bash detect_cuda_arch.sh
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
 endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
diff --git a/mlx/backend/cuda/detect_cuda_arch.sh b/mlx/backend/cuda/detect_cuda_arch.sh
new file mode 100644
index 000000000..9d7c01a3e
--- /dev/null
+++ b/mlx/backend/cuda/detect_cuda_arch.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+arch=`__nvcc_device_query`
+case "$arch" in
+  "90")
+    echo "90a" ;;
+  "100")
+    echo "100a" ;;
+  "121")
+    echo "121a" ;;
+  *)
+    echo "native" ;;
+esac
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 2801e4a67..69cf6ff66 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -279,11 +279,14 @@ void compile(
   // Compile program.
   std::vector<const char*> args;
   bool use_sass = compiler_supports_device_sass(device);
+  auto cc = device.compute_capability_major();
+  std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
   std::string compute = fmt::format(
-      "--gpu-architecture={}_{}{}",
+      "--gpu-architecture={}_{}{}{}",
       use_sass ? "sm" : "compute",
-      device.compute_capability_major(),
-      device.compute_capability_minor());
+      cc,
+      device.compute_capability_minor(),
+      arch_tag);
   args.push_back(compute.c_str());
   std::string cccl_include = cccl_dir();
   if (!cccl_include.empty()) {
diff --git a/setup.py b/setup.py
index 3b82a2cbd..aeeeb0912 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,16 @@ class CMakeBuild(build_ext):
         ]
         if build_stage == 2 and build_cuda:
             # Last arch is always real and virtual for forward-compatibility
-            cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
+            cuda_archs = ";".join(
+                (
+                    "75-real",
+                    "80-real",
+                    "90a-real",
+                    "100a-real",
+                    "120a-real",
+                    "120-virtual",
+                )
+            )
             cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]

         # Some generators require explcitly passing config when building.