mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Guard PTX with architecture defines
This commit is contained in:
9
mlx/backend/cuda/steel/defines.cuh
Normal file
9
mlx/backend/cuda/steel/defines.cuh
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define MLX_UNROLL _Pragma("unroll")
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||||
|
#define MLX_CUDA_SM_80_ENABLED
|
||||||
|
#endif
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||||
#include "mlx/backend/cuda/steel/tiles.cuh"
|
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
@@ -26,6 +27,7 @@ __device__ __forceinline__ void mma_t(
|
|||||||
Tile16x16<float>& C,
|
Tile16x16<float>& C,
|
||||||
Tile16x16<__nv_bfloat16>& A,
|
Tile16x16<__nv_bfloat16>& A,
|
||||||
Tile16x16<__nv_bfloat16>& B) {
|
Tile16x16<__nv_bfloat16>& B) {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
asm volatile(
|
asm volatile(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||||
"{%0, %1, %2, %3}, "
|
"{%0, %1, %2, %3}, "
|
||||||
@@ -82,6 +84,7 @@ __device__ __forceinline__ void mma_t(
|
|||||||
"f"(C.values[2].y),
|
"f"(C.values[2].y),
|
||||||
"f"(C.values[3].x),
|
"f"(C.values[3].x),
|
||||||
"f"(C.values[3].y));
|
"f"(C.values[3].y));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -3,8 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||||
#define MLX_UNROLL _Pragma("unroll")
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@@ -19,7 +18,7 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
|
|||||||
static_assert(
|
static_assert(
|
||||||
N == 16 || N == 8 || N == 4,
|
N == 16 || N == 8 || N == 4,
|
||||||
"cp.async is only supported for N in {4, 8, 16}.");
|
"cp.async is only supported for N in {4, 8, 16}.");
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
if constexpr (N == 16) {
|
if constexpr (N == 16) {
|
||||||
asm volatile(
|
asm volatile(
|
||||||
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||||
@@ -33,13 +32,16 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
|
|||||||
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
||||||
"l"(reinterpret_cast<const int*>(x)));
|
"l"(reinterpret_cast<const int*>(x)));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Submit all the previous async copies to be executed.
|
* Submit all the previous async copies to be executed.
|
||||||
*/
|
*/
|
||||||
__device__ inline void cp_async_commit() {
|
__device__ inline void cp_async_commit() {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
asm volatile("cp.async.commit_group;\n" ::);
|
asm volatile("cp.async.commit_group;\n" ::);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -47,11 +49,13 @@ __device__ inline void cp_async_commit() {
|
|||||||
*/
|
*/
|
||||||
template <int N>
|
template <int N>
|
||||||
__device__ inline void cp_async_wait() {
|
__device__ inline void cp_async_wait() {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
if constexpr (N == 0) {
|
if constexpr (N == 0) {
|
||||||
asm volatile("cp.async.wait_all;\n" ::);
|
asm volatile("cp.async.wait_all;\n" ::);
|
||||||
} else {
|
} else {
|
||||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Reference in New Issue
Block a user