Guard PTX with architecture defines

This commit is contained in:
Angelos Katharopoulos
2025-08-01 11:59:05 -07:00
parent c456d59e9f
commit 29d78af759
3 changed files with 19 additions and 3 deletions

View File

@@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
#define MLX_UNROLL _Pragma("unroll")
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#define MLX_CUDA_SM_80_ENABLED
#endif

View File

@@ -2,6 +2,7 @@
#pragma once
#include "mlx/backend/cuda/steel/defines.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core::cu {
@@ -26,6 +27,7 @@ __device__ __forceinline__ void mma_t(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
#if defined(MLX_CUDA_SM_80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
@@ -82,6 +84,7 @@ __device__ __forceinline__ void mma_t(
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
#endif
}
/**

View File

@@ -3,8 +3,7 @@
#pragma once
#include "mlx/backend/cuda/device/utils.cuh"
#define MLX_UNROLL _Pragma("unroll")
#include "mlx/backend/cuda/steel/defines.cuh"
namespace mlx::core::cu {
@@ -19,7 +18,7 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
static_assert(
N == 16 || N == 8 || N == 4,
"cp.async is only supported for N in {4, 8, 16}.");
#if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 16) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
@@ -33,13 +32,16 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x)));
}
#endif
}
/**
* Submit all the previous async copies to be executed.
*/
__device__ inline void cp_async_commit() {
#if defined(MLX_CUDA_SM_80_ENABLED)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
/**
@@ -47,11 +49,13 @@ __device__ inline void cp_async_commit() {
*/
template <int N>
__device__ inline void cp_async_wait() {
#if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
#endif
}
/**