mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Guard PTX with architecture defines
This commit is contained in:
9
mlx/backend/cuda/steel/defines.cuh
Normal file
9
mlx/backend/cuda/steel/defines.cuh
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#define MLX_UNROLL _Pragma("unroll")
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
#define MLX_CUDA_SM_80_ENABLED
|
||||
#endif
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
@@ -26,6 +27,7 @@ __device__ __forceinline__ void mma_t(
|
||||
Tile16x16<float>& C,
|
||||
Tile16x16<__nv_bfloat16>& A,
|
||||
Tile16x16<__nv_bfloat16>& B) {
|
||||
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0, %1, %2, %3}, "
|
||||
@@ -82,6 +84,7 @@ __device__ __forceinline__ void mma_t(
|
||||
"f"(C.values[2].y),
|
||||
"f"(C.values[3].x),
|
||||
"f"(C.values[3].y));
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#define MLX_UNROLL _Pragma("unroll")
|
||||
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
@@ -19,7 +18,7 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
|
||||
static_assert(
|
||||
N == 16 || N == 8 || N == 4,
|
||||
"cp.async is only supported for N in {4, 8, 16}.");
|
||||
|
||||
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||
if constexpr (N == 16) {
|
||||
asm volatile(
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||
@@ -33,13 +32,16 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
||||
"l"(reinterpret_cast<const int*>(x)));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit all the previous async copies to be executed.
|
||||
*/
|
||||
__device__ inline void cp_async_commit() {
|
||||
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -47,11 +49,13 @@ __device__ inline void cp_async_commit() {
|
||||
*/
|
||||
template <int N>
|
||||
__device__ inline void cp_async_wait() {
|
||||
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||
if constexpr (N == 0) {
|
||||
asm volatile("cp.async.wait_all;\n" ::);
|
||||
} else {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user