Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-28 14:59:22 +08:00
Reset cutlass gemm to working state again
This commit is contained in:
parent
cf5eef095d
commit
e1303f6160
@@ -10,219 +10,315 @@ namespace mlx::core::cu {
 namespace {
 
+template <typename Kernel>
+void configure_matmul(Kernel kernel, int smem_size) {
+  static bool initialized = false;
+  if (!initialized) {
+    initialized = true;
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  }
+}
+
+template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
+struct SharedStorage {
+  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
+  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
+};
 
-template <
-    class T,
-    class CtaTiler,
-    class SmemLayoutA,
-    class SmemLayoutB,
-    class TiledMMA,
-    class G2STiledCopyA,
-    class G2STiledCopyB,
-    class S2RCopyAtomA,
-    class S2RCopyAtomB>
-__global__ void cute_gemm_v02(
-    unsigned int M,
-    unsigned int N,
-    unsigned int K,
-    const T* A,
-    size_t lda,
-    const T* B,
-    size_t ldb,
-    T* C,
-    size_t ldc) {
-  using namespace cute;
-
-  // global full tensor
-  // shape
-  auto shape_MNK = make_shape(M, N, K);
-
-  // stride
-  // cublas covenience for TN gemm
-  // all matrices are in column major
-  // A (m,k) --> transpose --> A(k, m) --> cute layout: A (m, k) : (k, 1) -->
-  // lda = k B (k,n) --> cute layout: B (n, k) : (k, 1) --> ldb = k C (m,n) -->
-  // cute layout: C (m, n) : (1, m) --> ldc = m
-  auto dA = make_stride(lda, _1{});
-  auto dB = make_stride(ldb, _1{});
-  auto dC = make_stride(_1{}, ldc);
-
-  Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // M x K
-  Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // N x K
-  Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // M x N
-
-  // global tile tensor
-  auto cta_tiler = CtaTiler{};
-  auto cta_coord = make_coord(blockIdx.y, blockIdx.x, _);
-  Tensor gA = local_tile(
-      mA,
-      cta_tiler,
-      cta_coord,
-      Step<_1, X, _1>{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_TILES_K
-  Tensor gB = local_tile(
-      mB,
-      cta_tiler,
-      cta_coord,
-      Step<X, _1, _1>{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_TILES_K
-  Tensor gC = local_tile(
-      mC,
-      cta_tiler,
-      cta_coord,
-      Step<_1, _1, X>{}); // BLOCK_SIZE_M x BLOCK_SIZE_N
-
-  // shared memory
-  // __shared__ T Asmem[cosize_v<SmemLayoutA>];
-  // __shared__ T Bsmem[cosize_v<SmemLayoutB>];
-
-  extern __shared__ T smem[];
-  T* Asmem = smem;
-  T* Bsmem = smem + cosize_v<SmemLayoutA>;
-
-  Tensor sA = make_tensor(
-      make_smem_ptr(Asmem),
-      SmemLayoutA{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_STAGES
-  Tensor sB = make_tensor(
-      make_smem_ptr(Bsmem),
-      SmemLayoutB{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_STAGES
-
-  // MMA
-  // use TiledMMA --> get one thread work
-  auto tiled_mma = TiledMMA{};
-  ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
-  auto tCgC = thr_mma.partition_C(gC); // MMA x MMA_M x MMA_N
-
-  // thread private memory for MMA
-  auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // MMA x MMA_M x MMA_K
-  auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // MMA x MMA_N x MMA_K
-
-  // thread private memory for accumulator for MMA
-  Tensor tCrC = thr_mma.partition_fragment_C(gC); // MMA x MMA_M x MMA_N
-
-  clear(tCrC);
-
-  // initiate copy from global memory to shared memory
-  // use G2S TiledCopy --> get one thread copy work
-  auto g2s_tiled_copy_a = G2STiledCopyA{};
-  auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tAgA =
-      g2s_thr_copy_a.partition_S(gA); // CPY x CPY_M x CPY_K x NUM_TILES_K
-  auto tAsA =
-      g2s_thr_copy_a.partition_D(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
-
-  auto g2s_tiled_copy_b = G2STiledCopyB{};
-  auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tBgB =
-      g2s_thr_copy_b.partition_S(gB); // CPY x CPY_N x CPY_K x NUM_TILES_K
-  auto tBsB =
-      g2s_thr_copy_b.partition_D(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
-
-  // initiate copy from shared memory to thread private memory
-  // use S2R TiledCopy --> get one thread copy work
-  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
-  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tCsA =
-      s2r_thr_copy_a.partition_S(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
-  auto tCrA_copy_view = s2r_thr_copy_a.retile_D(tCrA); // CPY x CPY_M x CPY_K
-
-  auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
-  auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tCsB =
-      s2r_thr_copy_b.partition_S(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
-  auto tCrB_copy_view = s2r_thr_copy_b.retile_D(tCrB); // CPY x CPY_N x CPY_K
-
-  // pipeline
-  // counter
-  int itile_to_read = 0; // read index of the next tile
-  // 2 pointers of the buffer
-  int ismem_write = 0;
-  int ismem_read = 0;
-
-  // NUM_STAGES = 5 --> Prefetech NUM_STAGES-1 = 4 tiles first
-  auto NUM_STAGES = size<3>(tAsA);
-
-  CUTE_UNROLL
-  for (int stage = 0; stage < NUM_STAGES - 1; ++stage) {
-    // prefetch
-    // issue copy
-    copy(g2s_tiled_copy_a, tAgA(_, _, _, itile_to_read), tAsA(_, _, _, stage));
-    copy(g2s_tiled_copy_b, tBgB(_, _, _, itile_to_read), tBsB(_, _, _, stage));
-
-    // commit
-    cp_async_fence();
-
-    ismem_write++;
-    itile_to_read++;
-  }
-
-  // wait for first tile to be prefetched: G^0 -> S^0
-  cp_async_wait<NUM_STAGES - 2>();
-  __syncthreads();
-
-  // Having S^0, copy from S^0,0 to R^0
-  int k = 0;
-  copy(s2r_tiled_copy_a, tCsA(_, _, k, ismem_read), tCrA_copy_view(_, _, k));
-  copy(s2r_tiled_copy_b, tCsB(_, _, k, ismem_read), tCrB_copy_view(_, _, k));
-
-  // loop over tiles
-  auto NUM_TILES_K = size<3>(tAgA);
-
-  CUTE_NO_UNROLL
-  for (int tile = 0; tile < NUM_TILES_K; ++tile) {
-    auto MMA_K = size<2>(tCrA);
-    // loop over MMAs in direction of K
-
-    CUTE_UNROLL
-    for (int k = 0; k < MMA_K; ++k) {
-      int k_next = (k + 1) % MMA_K;
-
-      // if this is the second last MMA, wait the next tile to be fetched
-      if (k == MMA_K - 1) {
-        cp_async_wait<NUM_STAGES - 2>();
-        __syncthreads();
-
-        ismem_read = (ismem_read + 1) % NUM_STAGES;
-      }
-
-      // load data for the next MMA, from S^tile to registers
-      copy(
-          s2r_tiled_copy_a,
-          tCsA(_, _, k_next, ismem_read),
-          tCrA_copy_view(_, _, k_next));
-      copy(
-          s2r_tiled_copy_b,
-          tCsB(_, _, k_next, ismem_read),
-          tCrB_copy_view(_, _, k_next));
-
-      if (k == 0) {
-        // prefetch the next tile
-        // issue copy
-        if (itile_to_read < NUM_TILES_K) {
-          copy(
-              g2s_tiled_copy_a,
-              tAgA(_, _, _, itile_to_read),
-              tAsA(_, _, _, ismem_write));
-          copy(
-              g2s_tiled_copy_b,
-              tBgB(_, _, _, itile_to_read),
-              tBsB(_, _, _, ismem_write));
-
-          itile_to_read++;
-          ismem_write = (ismem_write + 1) % NUM_STAGES;
-        }
-        // commit
-        cp_async_fence();
-      }
-
-      // mma
-      gemm(tiled_mma, tCrC, tCrA(_, _, k), tCrB(_, _, k), tCrC);
-    }
-  }
-
-  copy(tCrC, tCgC);
-  // constexpr T alpha{1.0f};
-  // constexpr T beta{0.0f};
-  // axpby(alpha, tCrC, beta, tCgC);
-}
+template <
+    class ProblemShape,
+    class CtaTiler,
+    class TA,
+    class AStride,
+    class ASmemLayout,
+    class TiledCopyA,
+    class S2RAtomA,
+    class TB,
+    class BStride,
+    class BSmemLayout,
+    class TiledCopyB,
+    class S2RAtomB,
+    class TC,
+    class CStride,
+    class CSmemLayout,
+    class TiledMma>
+__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
+    ProblemShape shape_MNK,
+    CtaTiler cta_tiler,
+    TA const* A,
+    AStride dA,
+    ASmemLayout sA_layout,
+    TiledCopyA copy_a,
+    S2RAtomA s2r_atom_a,
+    TB const* B,
+    BStride dB,
+    BSmemLayout sB_layout,
+    TiledCopyB copy_b,
+    S2RAtomB s2r_atom_b,
+    TC* C,
+    CStride dC,
+    CSmemLayout,
+    TiledMma mma) {
+  using namespace cute;
+
+  // Preconditions
+  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
+  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
+
+  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
+  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
+
+  static_assert(is_static<ASmemLayout>::value);
+  static_assert(is_static<BSmemLayout>::value);
+  static_assert(is_static<CSmemLayout>::value);
+
+  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
+
+  //
+  // Full and Tiled Tensors
+  //
+
+  // Represent the full tensors
+  Tensor mA =
+      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
+  Tensor mB =
+      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
+  Tensor mC =
+      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
+
+  // Get the appropriate blocks for this thread block
+  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
+  Tensor gA = local_tile(
+      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
+  Tensor gB = local_tile(
+      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
+  Tensor gC =
+      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+
+  // Shared memory buffers
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
+  Tensor sA = make_tensor(
+      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
+  Tensor sB = make_tensor(
+      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
+
+  //
+  // Partition the copying of A and B tiles across the threads
+  //
+
+  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
+  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
+
+  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
+  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
+
+  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
+  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
+  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
+  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
+
+  //
+  // PREFETCH
+  //
+
+  auto K_PIPE_MAX = size<3>(tAsA);
+
+  // Total count of tiles
+  int k_tile_count = size<3>(tAgA);
+  // Current tile index in gmem to read from
+  int k_tile_next = 0;
+
+  // Start async loads for all pipes but the last
+  CUTE_UNROLL
+  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
+    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
+    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
+    cp_async_fence();
+    --k_tile_count;
+    if (k_tile_count > 0) {
+      ++k_tile_next;
+    }
+  }
+
+  //
+  // Define A/B partitioning and C accumulators
+  //
+
+  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
+  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
+
+  // Allocate registers for pipelining
+  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
+  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
+  // Allocate the accumulators -- same size as the projected data
+  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
+
+  CUTE_STATIC_ASSERT_V(
+      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
+  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
+  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
+
+  // Clear the accumulators
+  clear(tCrC);
+
+  //
+  // Copy Atom retiling
+  //
+
+  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
+  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
+
+  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
+  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
+
+#if 0
+  if(thread0()) {
+    print("  mA : "); print(  mA); print("\n");
+    print("  gA : "); print(  gA); print("\n");
+    print("  sA : "); print(  sA); print("\n");
+    print("tAgA : "); print(tAgA); print("\n");
+    print("tAsA : "); print(tAsA); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print("  mB : "); print(  mB); print("\n");
+    print("  gB : "); print(  gB); print("\n");
+    print("  sB : "); print(  sB); print("\n");
+    print("tBgB : "); print(tBgB); print("\n");
+    print("tBsB : "); print(tBsB); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print("  mC : "); print(  mC); print("\n");
+    print("  gC : "); print(  gC); print("\n");
+    print("tCgC : "); print(tCgC); print("\n");
+    print("tCrA : "); print(tCrA); print("\n");
+    print("tCrB : "); print(tCrB); print("\n");
+    print("tCrC : "); print(tCrC); print("\n");
+
+    print("tXsA : "); print(tXsA); print("\n");
+    print("tXrA : "); print(tXrA); print("\n");
+    print("tXsB : "); print(tXsB); print("\n");
+    print("tXrB : "); print(tXrB); print("\n");
+  }
+#endif
+
+#if 1
+
+  // Current pipe index in smem to read from
+  int smem_pipe_read = 0;
+  // Current pipe index in smem to write to
+  int smem_pipe_write = K_PIPE_MAX - 1;
+
+  // Pipe slice
+  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
+  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+  // Size of the register pipeline
+  auto K_BLOCK_MAX = size<2>(tCrA);
+  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
+
+  // PREFETCH register pipeline
+  if (K_BLOCK_MAX > 1) {
+    // Wait until our first prefetched tile is loaded in
+    cp_async_wait<K_PIPE_MAX - 2>();
+    __syncthreads();
+
+    // Prefetch the first rmem from the first k-tile
+    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
+    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+  }
+
+  //
+  // PIPELINED MAIN LOOP
+  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
+  // cp.async instructions and explicit pipelines in shared memory.
+  // Data is read from global(k_tile_next) to shared(smem_pipe_write).
+  // Data is read from shared(smem_pipe_read) to registers(k_block_next).
+  // Data is computed on registers(b_block).
+  //
+  // This allows all copies and compute to overlap:
+  // Copy from gmem->smem can overlap with copies from smem->rmem and
+  // compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
+  //
+
+  CUTE_NO_UNROLL
+  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
+      if (k_block == K_BLOCK_MAX - 1) {
+        // Slice the smem_pipe_read smem
+        tXsA_p = tXsA(_, _, _, smem_pipe_read);
+        tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+        // Commit the smem for smem_pipe_read
+        cp_async_wait<K_PIPE_MAX - 2>();
+        __syncthreads();
+      }
+
+      // Load A, B shmem->regs for k_block+1
+      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
+      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
+      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
+      // Copy gmem to smem before computing gemm on each k-pipe
+      if (k_block == 0) {
+        copy(
+            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+        copy(
+            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+        cp_async_fence();
+
+        // Advance the gmem tile
+        --k_tile_count;
+        if (k_tile_count > 0) {
+          ++k_tile_next;
+        }
+
+        // Advance the smem pipe
+        smem_pipe_write = smem_pipe_read;
+        smem_pipe_read =
+            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+      }
+      // Thread-level register gemm for k_block
+      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+    }
+  }
+
+#endif
+
+  //
+  // Epilogue
+  //
+
+  copy(tCrC, tCgC);
+}
 
 } // namespace
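Note (not part of the diff): the main loop in gemm_device interleaves three levels of pipelining: gmem to smem via cp.async into a circular set of smem stages, smem to registers one k_block ahead, and MMAs on the current k_block. The host-compilable C++ sketch below only simulates the index bookkeeping (prologue prefetch, read/write stage rotation, and the drain condition of the while loop); the tile count and stage count are made-up example values, and the per-k_block granularity of the real kernel is collapsed into one step per k-tile.

#include <cstdio>

int main() {
  const int K_PIPE_MAX = 3; // smem pipeline stages (bP in the launch code)
  int k_tile_count = 8;     // assumed number of k-tiles for illustration
  int k_tile_next = 0;      // next gmem tile to read

  // Prologue: fill all stages but the last, like the CUTE_UNROLL prefetch loop.
  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
    std::printf("prefetch tile %d -> stage %d\n", k_tile_next, k_pipe);
    --k_tile_count;
    if (k_tile_count > 0) {
      ++k_tile_next;
    }
  }

  int smem_pipe_read = 0;
  int smem_pipe_write = K_PIPE_MAX - 1;

  // Main loop: runs K_PIPE_MAX - 1 extra iterations past the last tile so the
  // pipeline drains, matching `while (k_tile_count > -(K_PIPE_MAX - 1))`.
  // Once k_tile_count reaches zero the "load" re-reads the last tile, which is
  // redundant but harmless, exactly as in the kernel.
  while (k_tile_count > -(K_PIPE_MAX - 1)) {
    std::printf("compute stage %d, load tile %d -> stage %d\n",
                smem_pipe_read, k_tile_next, smem_pipe_write);
    --k_tile_count;
    if (k_tile_count > 0) {
      ++k_tile_next;
    }
    smem_pipe_write = smem_pipe_read;
    smem_pipe_read =
        (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
  }
  return 0;
}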
@@ -243,144 +339,103 @@ void cutlass_gemm(
   if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
     using namespace cute;
 
-    // block shape and cta tiler
-    // additional dim: NUM_STAGES --> This is for later pipelining the k-slice
-    // GEMM
-    auto BLOCK_SIZE_M = _128{};
-    auto BLOCK_SIZE_N = _128{};
-    auto BLOCK_SIZE_K = _32{};
-    auto NUM_STAGES = _5{};
-    using CtaTiler =
-        decltype(make_shape(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K));
-
-    // smem layout
-    // Swizzle parameters need to be chosen right
-    static constexpr int kShmLoadSwizzleM = 3;
-    static constexpr int kShmLoadSwizzleS = 3;
-    static constexpr int kShmLoadSwizzleB = 3;
-
-    using SmemLayoutAtom = decltype(composition(
-        Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
-        make_layout(make_shape(_8{}, BLOCK_SIZE_K), LayoutRight{})));
-
-    // what does this do?
-    // with BLOCK_SIZE_K = 32, shape: (8, 32)
-    // 2^M = 8, 1 new unit = 8 units --> 1 row contains 32/8 = 4 new units
-    // 2^S = 8, it will treat 1 row = 8 new units --> do 8-unit swizzle
-    // 2^B = 8, it will reset the swizzle pattern after 8 rows
-    // print_layout(SmemLayoutAtom{});
-
-    // tile_to_shape extends the layout in LayoutLeft order
-    using SmemLayoutA = decltype(tile_to_shape(
-        SmemLayoutAtom{},
-        make_shape(BLOCK_SIZE_M, BLOCK_SIZE_K, NUM_STAGES)));
-    using SmemLayoutB = decltype(tile_to_shape(
-        SmemLayoutAtom{},
-        make_shape(BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES)));
-
-    // TiledMMA
-    using mma_op = SM80_16x8x16_F32BF16BF16F32_TN;
-    using mma_traits = MMA_Traits<mma_op>;
-    using mma_atom = MMA_Atom<mma_traits>;
-
-    static constexpr int kMmaEURepeatM = 2;
-    static constexpr int kMmaEURepeatN = 2;
-    static constexpr int kMmaEURepeatK = 1;
-    // 32 x 2 x 2 = 128 threads
-
-    using mma_atom_shape = mma_traits::Shape_MNK;
-    static constexpr int MmaVM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
-    static constexpr int MmaVN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
-    static constexpr int MmaVK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
-
-    // this is for problem shape (16x2) x (8x2x2) x (16x1) = 32x32x16
-    using MMA_EU_RepeatT = decltype(make_layout(make_shape(
-        Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
-    using MMA_V_T = Tile<Int<MmaVM>, Int<MmaVN>, Int<MmaVK>>;
-
-    using TiledMMA =
-        decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_V_T{}));
-
-    // TiledCopy from global memory to shared memory
-    // uint128_t is 16 bytes = 4 floats = 8 halfs
-    static constexpr int NUM_VECTOR_UNITS =
-        sizeof(cute::uint128_t) / sizeof(DataType);
-
-    using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
-    using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
-    using g2s_copy_atom = Copy_Atom<g2s_copy_traits, DataType>;
-
-    // one block contains 128 threads
-    // --> find the compatible thread layout
-    using G2S_Copy_Thread_Layout = decltype(make_layout(
-        make_shape(_32{}, _4{}), // 32x4 = 128 threads
-        LayoutRight{} // A is in row-major
-        ));
-
-    using G2S_Copy_Value_Layout =
-        decltype(make_layout(make_shape(_1{}, Int<NUM_VECTOR_UNITS>{})));
-
-    // This is for copy shape 32x4 of uint128_t
-    using G2STiledCopyA = decltype(make_tiled_copy(
-        g2s_copy_atom{}, G2S_Copy_Thread_Layout{}, G2S_Copy_Value_Layout{}));
-
-    // Both A and B are in row-major so use the same TiledCopy for B
-    using G2STiledCopyB = G2STiledCopyA;
-
-    // CopyAtom from shared memory to registers
-    // Why no need to do tiling atom here? Because we will do it later with
-    // the information from TiledMMA
-    using s2r_copy_op = SM75_U32x4_LDSM_N;
-    using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
-    using s2r_copy_atom = Copy_Atom<s2r_copy_traits, DataType>;
-
-    using S2RCopyAtomA = s2r_copy_atom;
-    using S2RCopyAtomB = s2r_copy_atom;
-
-    // print_latex(
-    //     make_tiled_copy(S2RCopyAtomA{}, make_layout(Shape<_32>{}),
-    //     make_layout(Shape<_1, _8>{}))
-    // );
-
-    // grid, block
-    dim3 block{size(TiledMMA{}), 1U, 1U};
-    dim3 grid{
-        size(ceil_div(static_cast<unsigned int>(N), BLOCK_SIZE_N)),
-        size(ceil_div(static_cast<unsigned int>(M), BLOCK_SIZE_M)),
-        1U};
-
-    static constexpr int smem_size =
-        (cosize_v<SmemLayoutA> + cosize_v<SmemLayoutB>)*sizeof(DataType);
-
-    auto kernel = cute_gemm_v02<
-        DataType,
-        CtaTiler,
-        SmemLayoutA,
-        SmemLayoutB,
-        TiledMMA,
-        G2STiledCopyA,
-        G2STiledCopyB,
-        S2RCopyAtomA,
-        S2RCopyAtomB>;
-
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-
-    enc.add_kernel_node(
-        kernel,
-        grid,
-        block,
-        smem_size,
-        M,
-        N,
-        K,
-        a.data<DataType>(),
-        K,
-        b.data<DataType>(),
-        K,
-        out.data<DataType>(),
-        N);
+    // Define shapes (dynamic)
+    auto prob_shape = make_shape(M, N, K);
+
+    // Define TN strides (mixed)
+    auto dA = make_stride(K, Int<1>{});
+    auto dB = make_stride(K, Int<1>{});
+    auto dC = make_stride(N, Int<1>{});
+
+    // Define CTA tile sizes (static)
+    auto bM = Int<128>{};
+    auto bN = Int<128>{};
+    auto bK = Int<64>{};
+    auto cta_tiler = make_shape(bM, bN, bK);
+    auto bP = Int<3>{};
+
+    // Define the smem layouts (static)
+    // Swizzles for LDSM and 128b k-major loads
+    auto swizzle_atom = composition(
+        Swizzle<3, 3, 3>{},
+        Layout<
+            cute::Shape<_8, cute::Shape<_8, _8>>,
+            cute::Stride<_8, cute::Stride<_1, _64>>>{});
+    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
+    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
+    auto sC = make_layout(make_shape(bM, bN));
+
+    // Define the thread layouts (static)
+    TiledCopy copyA = make_tiled_copy(
+        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
+        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                              // 16x8 k-major
+        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
+    TiledCopy copyB = make_tiled_copy(
+        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
+        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                              // 16x8 k-major
+        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major
+
+    TiledMMA mmaC = make_tiled_mma(
+        SM80_16x8x16_F32BF16BF16F32_TN{},
+        Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
+        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
+
+    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
+    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
+
+    int smem_size = int(sizeof(SharedStorage<
+        cute::bfloat16_t,
+        cute::bfloat16_t,
+        decltype(sA),
+        decltype(sB)>));
+    dim3 dimBlock(size(mmaC));
+    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
+
+    auto kernel = gemm_device<
+        decltype(prob_shape),
+        decltype(cta_tiler),
+        cute::bfloat16_t,
+        decltype(dA),
+        decltype(sA),
+        decltype(copyA),
+        decltype(s2r_atom_A),
+        cute::bfloat16_t,
+        decltype(dB),
+        decltype(sB),
+        decltype(copyB),
+        decltype(s2r_atom_B),
+        cute::bfloat16_t,
+        decltype(dC),
+        decltype(sC),
+        decltype(mmaC)>;
+
+    configure_matmul(kernel, smem_size);
+
+    enc.add_kernel_node(
+        kernel,
+        dimGrid,
+        dimBlock,
+        smem_size,
+        prob_shape,
+        cta_tiler,
+        a.data<cute::bfloat16_t>(),
+        dA,
+        sA,
+        copyA,
+        s2r_atom_A,
+        b.data<cute::bfloat16_t>(),
+        dB,
+        sB,
+        copyB,
+        s2r_atom_B,
+        out.data<cute::bfloat16_t>(),
+        dC,
+        sC,
+        mmaC);
   } else {
     throw std::runtime_error("Only bfloat16 supported");
   }
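Note (not part of the diff): with the tile sizes chosen above, the SharedStorage instantiation needs roughly (128 x 64 + 128 x 64) elements x 3 stages x 2 bytes (bfloat16) = 98304 bytes, i.e. 96 KB of dynamic shared memory per block, which is above the default 48 KB limit; that is why configure_matmul raises cudaFuncAttributeMaxDynamicSharedMemorySize before the kernel node is added. A minimal standalone sketch of that opt-in using the CUDA runtime API is below; the kernel name is a made-up placeholder and the call may fail on GPUs whose per-block shared-memory limit is smaller than 96 KB, which the printed error string would show.

#include <cuda_runtime.h>
#include <cstdio>

// Toy kernel standing in for gemm_device; only the attribute handling matters.
__global__ void dummy_kernel(char* out) {
  extern __shared__ char smem[];
  if (threadIdx.x == 0) {
    out[0] = smem[0];
  }
}

int main() {
  // Shared-storage footprint implied by the tile sizes in this commit:
  // A: 128 x 64, B: 128 x 64, 3 pipeline stages, bfloat16 (2 bytes each).
  const int bM = 128, bN = 128, bK = 64, bP = 3;
  const int smem_size = (bM * bK + bN * bK) * bP * 2; // 98304 bytes = 96 KB

  // Opt the kernel in to more than 48 KB of dynamic shared memory, as
  // configure_matmul does for gemm_device.
  cudaError_t err = cudaFuncSetAttribute(
      dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  std::printf("opt-in for %d bytes: %s\n", smem_size, cudaGetErrorString(err));
  return 0;
}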
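Note (not part of the diff): the launch geometry follows directly from the CTA tile. Each 128-thread block (one SM80 16x8x16 MMA atom of 32 threads, tiled 2x2 in mmaC) produces one 128 x 128 tile of C, dimGrid covers the output with ceil_div(M, bM) x ceil_div(N, bN) blocks, and inside the kernel cta_coord = (blockIdx.x, blockIdx.y, _) picks the tile. A small sketch of that arithmetic, with an assumed example problem size:

#include <cstdio>

int main() {
  const int M = 4096, N = 4096;   // example problem size (assumed, not from the diff)
  const int bM = 128, bN = 128;   // CTA tile from this commit
  const int threads = 32 * 2 * 2; // 32-thread MMA atom, 2x2 atom layout

  const int grid_x = (M + bM - 1) / bM; // ceil_div(M, bM)
  const int grid_y = (N + bN - 1) / bN; // ceil_div(N, bN)
  std::printf("dimGrid = (%d, %d), dimBlock = %d threads\n",
              grid_x, grid_y, threads);
  return 0;
}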