Reset cutlass gemm to working state again

Angelos Katharopoulos 2025-08-21 01:29:43 -07:00
parent cf5eef095d
commit e1303f6160


@@ -10,219 +10,315 @@ namespace mlx::core::cu {

namespace {
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
  static bool initialized = false;
  if (!initialized) {
    initialized = true;
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
  }
}
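With the tile sizes picked below (bM = bN = 128, bK = 64, bP = 3 stages of bf16), the kernel asks for 2 * (128*64 + 128*64) * 3 = 98304 bytes (96 KB) of dynamic shared memory, well above the 48 KB a kernel gets by default, so this opt-in must run before the first launch. A standalone back-of-the-envelope check (the constants mirror the values chosen later in this file):

// Standalone sketch: dynamic smem demand for the bf16 tiles used below.
constexpr int bM = 128, bN = 128, bK = 64, bP = 3;
constexpr int smem_bytes = 2 /* sizeof bf16 */ * (bM * bK + bN * bK) * bP;
static_assert(smem_bytes == 98304, "96 KB, beyond the 48 KB default limit");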
template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
struct SharedStorage {
  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
};
template <
    class ProblemShape,
    class CtaTiler,
    class TA,
    class AStride,
    class ASmemLayout,
    class TiledCopyA,
    class S2RAtomA,
    class TB,
    class BStride,
    class BSmemLayout,
    class TiledCopyB,
    class S2RAtomB,
    class TC,
    class CStride,
    class CSmemLayout,
    class TiledMma>
__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
    ProblemShape shape_MNK,
    CtaTiler cta_tiler,
    TA const* A,
    AStride dA,
    ASmemLayout sA_layout,
    TiledCopyA copy_a,
    S2RAtomA s2r_atom_a,
    TB const* B,
    BStride dB,
    BSmemLayout sB_layout,
    TiledCopyB copy_b,
    S2RAtomB s2r_atom_b,
    TC* C,
    CStride dC,
    CSmemLayout,
    TiledMma mma) {
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K

  CUTE_STATIC_ASSERT_V(
      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(
      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(
      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA =
      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
  Tensor mB =
      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
  Tensor mC =
      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
  Tensor gA = local_tile(
      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(
      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
  Tensor gC =
      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)

  // Shared memory buffers
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
  Tensor sA = make_tensor(
      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(
      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
  //
  // Partition the copying of A and B tiles across the threads
  //

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K

  //
  // PREFETCH
  //

  auto K_PIPE_MAX = size<3>(tAsA);

  // Total count of tiles
  int k_tile_count = size<3>(tAgA);
  // Current tile index in gmem to read from
  int k_tile_next = 0;

  // Start async loads for all pipes but the last
  CUTE_UNROLL
  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
    cp_async_fence();
    --k_tile_count;
    if (k_tile_count > 0) {
      ++k_tile_next;
    }
  }
  //
  // Define A/B partitioning and C accumulators
  //

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

  // Allocate registers for pipelining
  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V(
      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N

  // Clear the accumulators
  clear(tCrC);

  //
  // Copy Atom retiling
  //

  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)

  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)

#if 0
  if (thread0()) {
    print("  mA : "); print(mA); print("\n");
    print("  gA : "); print(gA); print("\n");
    print("  sA : "); print(sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if (thread0()) {
    print("  mB : "); print(mB); print("\n");
    print("  gB : "); print(gB); print("\n");
    print("  sB : "); print(sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if (thread0()) {
    print("  mC : "); print(mC); print("\n");
    print("  gC : "); print(gC); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrA : "); print(tCrA); print("\n");
    print("tCrB : "); print(tCrB); print("\n");
    print("tCrC : "); print(tCrC); print("\n");

    print("tXsA : "); print(tXsA); print("\n");
    print("tXrA : "); print(tXrA); print("\n");
    print("tXsB : "); print(tXsB); print("\n");
    print("tXrB : "); print(tXrB); print("\n");
  }
#endif
#if 1

  // Current pipe index in smem to read from
  int smem_pipe_read = 0;
  // Current pipe index in smem to write to
  int smem_pipe_write = K_PIPE_MAX - 1;

  // Pipe slice
  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);

  // Size of the register pipeline
  auto K_BLOCK_MAX = size<2>(tCrA);
  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));

  // PREFETCH register pipeline
  if (K_BLOCK_MAX > 1) {
    // Wait until our first prefetched tile is loaded in
    cp_async_wait<K_PIPE_MAX - 2>();
    __syncthreads();

    // Prefetch the first rmem from the first k-tile
    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
  }

  //
  // PIPELINED MAIN LOOP
  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
  //   cp.async instructions and explicit pipelines in shared memory.
  //   Data is read from global (k_tile_next) to shared (smem_pipe_write).
  //   Data is read from shared (smem_pipe_read) to registers (k_block_next).
  //   Data is computed on registers (k_block).
  //
  //   This allows all copies and compute to overlap:
  //     Copy from gmem->smem can overlap with copies from smem->rmem and
  //     compute on rmem.
  //     Copy from smem->rmem can overlap with compute on rmem.
  //
  CUTE_NO_UNROLL
  while (k_tile_count > -(K_PIPE_MAX - 1)) {
    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
      if (k_block == K_BLOCK_MAX - 1) {
        // Slice the smem_pipe_read smem
        tXsA_p = tXsA(_, _, _, smem_pipe_read);
        tXsB_p = tXsB(_, _, _, smem_pipe_read);

        // Commit the smem for smem_pipe_read
        cp_async_wait<K_PIPE_MAX - 2>();
        __syncthreads();
      }

      // Load A, B shmem->regs for k_block+1
      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));

      // Copy gmem to smem before computing gemm on each k-pipe
      if (k_block == 0) {
        copy(
            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
        copy(
            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
        cp_async_fence();

        // Advance the gmem tile
        --k_tile_count;
        if (k_tile_count > 0) {
          ++k_tile_next;
        }

        // Advance the smem pipe
        smem_pipe_write = smem_pipe_read;
        smem_pipe_read =
            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
      }

      // Thread-level register gemm for k_block
      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
    }
  }

#endif

  //
  // Epilogue
  //

  copy(tCrC, tCgC);
}

} // namespace
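The smem_pipe_read/smem_pipe_write rotation above is easiest to see in isolation. A minimal host-side sketch of the index schedule for K_PIPE_MAX = 3 (the bP value chosen below); the write index always trails into the slot the compute stage just drained:

#include <cstdio>

int main() {
  const int K_PIPE_MAX = 3;
  int smem_pipe_read = 0;
  int smem_pipe_write = K_PIPE_MAX - 1; // stages 0..K_PIPE_MAX-2 prefetched
  for (int iter = 0; iter < 6; ++iter) {
    std::printf("iter %d: read stage %d, write stage %d\n",
                iter, smem_pipe_read, smem_pipe_write);
    // Same advance as the kernel's k_block == 0 branch
    smem_pipe_write = smem_pipe_read;
    smem_pipe_read =
        (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
  }
}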
@@ -243,144 +339,103 @@ void cutlass_gemm(
  if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
    using namespace cute;

    // Define shapes (dynamic)
    auto prob_shape = make_shape(M, N, K);

    // Define TN strides (mixed)
    auto dA = make_stride(K, Int<1>{});
    auto dB = make_stride(K, Int<1>{});
    auto dC = make_stride(N, Int<1>{});

    // Define CTA tile sizes (static)
    auto bM = Int<128>{};
    auto bN = Int<128>{};
    auto bK = Int<64>{};
    auto cta_tiler = make_shape(bM, bN, bK);
    auto bP = Int<3>{};
    // Define the smem layouts (static)
    // Swizzles for LDSM and 128b k-major loads
    auto swizzle_atom = composition(
        Swizzle<3, 3, 3>{},
        Layout<
            cute::Shape<_8, cute::Shape<_8, _8>>,
            cute::Stride<_8, cute::Stride<_1, _64>>>{});
    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
    auto sC = make_layout(make_shape(bM, bN));
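The Swizzle<3, 3, 3> atom XORs offset bits [8:6] into bits [5:3], so each 64-element slab of the tile visits its eight 8-element chunks in a different order, avoiding smem bank conflicts for both the cp.async stores and the LDSM loads. A standalone sketch of just the XOR pattern (illustrative only; it prints chunk ^ slab rather than the exact (m,k) coordinate mapping of the atom above):

#include <cstdint>
#include <cstdio>

// Swizzle<3, 3, 3>: fold offset bits [8:6] into bits [5:3] by XOR.
uint32_t swizzle_333(uint32_t offset) {
  return offset ^ ((offset & (0x7u << 6)) >> 3);
}

int main() {
  for (uint32_t slab = 0; slab < 8; ++slab) { // 64-element slabs
    std::printf("slab %u -> chunk order:", slab);
    for (uint32_t chunk = 0; chunk < 8; ++chunk) { // 8-element chunks
      std::printf(" %u", (swizzle_333(slab * 64 + chunk * 8) >> 3) & 0x7u);
    }
    std::printf("\n"); // each slab gets a distinct chunk permutation
  }
}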
    // Define the thread layouts (static)

    TiledCopy copyA = make_tiled_copy(
        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
                                                              // 16x8 k-major
        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
    TiledCopy copyB = make_tiled_copy(
        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
                                                              // 16x8 k-major
        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
    TiledMMA mmaC = make_tiled_mma(
        SM80_16x8x16_F32BF16BF16F32_TN{},
        Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM

    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;

    int smem_size = int(sizeof(SharedStorage<
                               cute::bfloat16_t,
                               cute::bfloat16_t,
                               decltype(sA),
                               decltype(sB)>));
    dim3 dimBlock(size(mmaC));
    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
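For a concrete sense of the launch shape (hypothetical sizes, not from this commit): size(mmaC) is one warp per MMA atom in the 2x2x1 layout, i.e. 4 warps = 128 threads, matching the 16x8 thread layouts of copyA/copyB, and M = N = 4096 would give a 32x32 grid:

#include <cstdio>

int main() {
  int M = 4096, N = 4096;   // hypothetical problem size
  int bM = 128, bN = 128;   // CTA tile, as above
  int threads = 32 * 2 * 2; // one warp per atom, 2x2x1 MMA atoms
  std::printf("grid = (%d, %d), block = %d threads\n",
              (M + bM - 1) / bM, (N + bN - 1) / bN, threads);
}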
    auto kernel = gemm_device<
        decltype(prob_shape),
        decltype(cta_tiler),
        cute::bfloat16_t,
        decltype(dA),
        decltype(sA),
        decltype(copyA),
        decltype(s2r_atom_A),
        cute::bfloat16_t,
        decltype(dB),
        decltype(sB),
        decltype(copyB),
        decltype(s2r_atom_B),
        cute::bfloat16_t,
        decltype(dC),
        decltype(sC),
        decltype(mmaC)>;
    configure_matmul(kernel, smem_size);
    enc.add_kernel_node(
        kernel,
        dimGrid,
        dimBlock,
        smem_size,
        prob_shape,
        cta_tiler,
        a.data<cute::bfloat16_t>(),
        dA,
        sA,
        copyA,
        s2r_atom_A,
        b.data<cute::bfloat16_t>(),
        dB,
        sB,
        copyB,
        s2r_atom_B,
        out.data<cute::bfloat16_t>(),
        dC,
        sC,
        mmaC);
  } else {
    throw std::runtime_error("Only bfloat16 supported");
  }
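Outside of the MLX command-encoder path, the same kernel pointer can be launched directly with the triple-chevron syntax; a minimal sketch, assuming a_ptr, b_ptr, c_ptr are hypothetical bf16 device allocations laid out with the TN strides above:

// Hypothetical direct launch equivalent to the add_kernel_node call above.
configure_matmul(kernel, smem_size);
kernel<<<dimGrid, dimBlock, smem_size>>>(
    prob_shape, cta_tiler,
    a_ptr, dA, sA, copyA, s2r_atom_A,
    b_ptr, dB, sB, copyB, s2r_atom_B,
    c_ptr, dC, sC, mmaC);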