diff --git a/mlx/backend/cuda/steel/defines.cuh b/mlx/backend/cuda/steel/defines.cuh
new file mode 100644
index 000000000..bf920428f
--- /dev/null
+++ b/mlx/backend/cuda/steel/defines.cuh
@@ -0,0 +1,9 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#define MLX_UNROLL _Pragma("unroll")
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+#define MLX_CUDA_SM_80_ENABLED
+#endif
diff --git a/mlx/backend/cuda/steel/mma.cuh b/mlx/backend/cuda/steel/mma.cuh
index 42d3c9040..94e314909 100644
--- a/mlx/backend/cuda/steel/mma.cuh
+++ b/mlx/backend/cuda/steel/mma.cuh
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/steel/defines.cuh"
 #include "mlx/backend/cuda/steel/tiles.cuh"
 
 namespace mlx::core::cu {
@@ -26,6 +27,7 @@ __device__ __forceinline__ void mma_t(
     Tile16x16<float>& C,
     Tile16x16<__nv_bfloat16>& A,
     Tile16x16<__nv_bfloat16>& B) {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   asm volatile(
       "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
       "{%0, %1, %2, %3}, "
@@ -82,6 +84,7 @@ __device__ __forceinline__ void mma_t(
       "f"(C.values[2].y),
       "f"(C.values[3].x),
       "f"(C.values[3].y));
+#endif
 }
 
 /**
diff --git a/mlx/backend/cuda/steel/utils.cuh b/mlx/backend/cuda/steel/utils.cuh
index cfa8c0ad5..0957c09d0 100644
--- a/mlx/backend/cuda/steel/utils.cuh
+++ b/mlx/backend/cuda/steel/utils.cuh
@@ -3,8 +3,7 @@
 #pragma once
 
 #include "mlx/backend/cuda/device/utils.cuh"
-
-#define MLX_UNROLL _Pragma("unroll")
+#include "mlx/backend/cuda/steel/defines.cuh"
 
 namespace mlx::core::cu {
 
@@ -19,7 +18,7 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
   static_assert(
       N == 16 || N == 8 || N == 4,
       "cp.async is only supported for N in {4, 8, 16}.");
-
+#if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
         "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
@@ -33,13 +32,16 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
         "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast<int64_t>(x)));
   }
+#endif
 }
 
 /**
  * Submit all the previous async copies to be executed.
  */
 __device__ inline void cp_async_commit() {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   asm volatile("cp.async.commit_group;\n" ::);
+#endif
 }
 
 /**
@@ -47,11 +49,13 @@ __device__ inline void cp_async_commit() {
  */
 template <int N>
 __device__ inline void cp_async_wait() {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 0) {
     asm volatile("cp.async.wait_all;\n" ::);
   } else {
     asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
   }
+#endif
 }
 
 /**
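
For context, here is a minimal, hypothetical sketch of how the now-guarded copy primitives compose in a kernel (the names load_row_async and BN are illustrative, not part of this change or of MLX). It stages BN floats into shared memory via cp_async, then fences with cp_async_commit/cp_async_wait. Note that below SM 80 the #if-guarded asm bodies compile away to no-ops, so a caller like this would need a plain-load fallback on older architectures.

// Hypothetical usage sketch (illustrative only, not part of the diff above).
#include "mlx/backend/cuda/steel/utils.cuh"

namespace mlx::core::cu {

template <int BN> // BN assumed divisible by 4 (each cp.async moves 16 bytes)
__device__ void load_row_async(float* smem, const float* gmem) {
  // cp.async expects a 32-bit shared-memory window address for its [%0] operand.
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem));
  MLX_UNROLL
  for (int i = 0; i < BN; i += 4) {
    cp_async<16>(dst + i * sizeof(float), gmem + i); // 4 floats = 16 bytes
  }
  cp_async_commit(); // close the current async copy group
  cp_async_wait<0>(); // N == 0 takes the cp.async.wait_all branch
}

} // namespace mlx::core::cu

Guarding inside the helpers, rather than at every call site, keeps kernel code free of arch conditionals; the trade-off is that the primitives silently become no-ops pre-Ampere, so callers are responsible for compile-time dispatch where older GPUs must be supported.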