Guard PTX with architecture defines

Angelos Katharopoulos
2025-08-01 11:59:05 -07:00
parent c456d59e9f
commit 29d78af759
3 changed files with 19 additions and 3 deletions

mlx/backend/cuda/steel/defines.cuh (new file)

@@ -0,0 +1,9 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#define MLX_UNROLL _Pragma("unroll")
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+#define MLX_CUDA_SM_80_ENABLED
+#endif
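The effect of this define, in a minimal sketch (illustrative, not part of the commit): __CUDA_ARCH__ is defined only during device compilation and reflects the architecture of the current pass, so in a multi-arch build only the sm_80+ passes see MLX_CUDA_SM_80_ENABLED and emit the guarded PTX, while older passes compile an empty body that ptxas can still assemble.

// Sketch of the guard pattern applied to a toy helper (illustrative only).
#include "mlx/backend/cuda/steel/defines.cuh"

__device__ inline void commit_async_copies() {
#if defined(MLX_CUDA_SM_80_ENABLED)
  // Emitted only when the current compilation pass targets sm_80 or newer;
  // ptxas for older targets never sees the sm_80-only instruction.
  asm volatile("cp.async.commit_group;\n" ::);
#endif
  // For passes below sm_80 the function body is empty: a no-op.
}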

Second changed file

@@ -2,6 +2,7 @@
 #pragma once
 
+#include "mlx/backend/cuda/steel/defines.cuh"
 #include "mlx/backend/cuda/steel/tiles.cuh"
 
 namespace mlx::core::cu {
@@ -26,6 +27,7 @@ __device__ __forceinline__ void mma_t(
     Tile16x16<float>& C,
     Tile16x16<__nv_bfloat16>& A,
     Tile16x16<__nv_bfloat16>& B) {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   asm volatile(
       "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
       "{%0, %1, %2, %3}, "
@@ -82,6 +84,7 @@ __device__ __forceinline__ void mma_t(
       "f"(C.values[2].y),
       "f"(C.values[3].x),
       "f"(C.values[3].y));
+#endif
 }
 
 /**
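Because the guard compiles mma_t to a no-op below sm_80, a caller has to keep these kernels off older devices at runtime. A hypothetical host-side check (illustrative; this is not MLX's actual dispatch code):

#include <cuda_runtime.h>

// Hypothetical dispatch sketch (not MLX's actual code): only launch
// kernels built on mma_t / cp_async when the sm_80 PTX was emitted
// for the active device.
inline bool device_has_sm80() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int major = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  return major >= 8;  // below sm_80 the guarded helpers are no-ops
}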

Third changed file

@@ -3,8 +3,7 @@
 #pragma once
 
 #include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/steel/defines.cuh"
 
-#define MLX_UNROLL _Pragma("unroll")
 
 namespace mlx::core::cu {
@@ -19,7 +18,7 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
   static_assert(
       N == 16 || N == 8 || N == 4,
       "cp.async is only supported for N in {4, 8, 16}.");
+#if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
         "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
@@ -33,13 +32,16 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
         "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast<const int*>(x)));
   }
+#endif
 }
 
 /**
  * Submit all the previous async copies to be executed.
  */
 __device__ inline void cp_async_commit() {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   asm volatile("cp.async.commit_group;\n" ::);
+#endif
 }
 
 /**
@@ -47,11 +49,13 @@ __device__ inline void cp_async_commit() {
  */
 template <int N>
 __device__ inline void cp_async_wait() {
+#if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 0) {
     asm volatile("cp.async.wait_all;\n" ::);
   } else {
     asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
   }
+#endif
 }
 
 /**
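For orientation, a sketch of how the three helpers compose into the usual async-copy pipeline (illustrative; the cp_async call shape, assumed here as cp_async<N>(dst, src), and the helper name load_tile are assumptions, not code from this commit):

// Each thread stages 16 bytes from global into shared memory, then the
// block waits for all outstanding async copy groups.
__device__ void load_tile(float* smem_tile, const float* gmem_tile) {
  // cp.async takes a 32-bit shared-space address as its destination;
  // __cvta_generic_to_shared converts a generic shared-memory pointer.
  uint32_t dst = static_cast<uint32_t>(
      __cvta_generic_to_shared(smem_tile + 4 * threadIdx.x));
  cp_async<16>(dst, gmem_tile + 4 * threadIdx.x);  // 16 bytes = 4 floats
  cp_async_commit();   // submit the copies issued so far as one group
  cp_async_wait<0>();  // wait for every outstanding group
  __syncthreads();
  // Below sm_80, after this commit all three calls are no-ops, so this
  // path must be restricted to sm_80+ devices at dispatch time.
}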