Reset cutlass gemm to working state again

Angelos Katharopoulos 2025-08-21 01:29:43 -07:00
parent cf5eef095d
commit e1303f6160


@@ -10,219 +10,315 @@ namespace mlx::core::cu {
namespace {
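// Kernels that request more than 48 KB of dynamic shared memory must opt in
// explicitly, so raise cudaFuncAttributeMaxDynamicSharedMemorySize and ask
// for a full shared-memory carveout (prefer smem over L1). The static flag
// makes this a one-time setup per kernel instantiation.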
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
}
}
template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
struct SharedStorage {
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
};
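// Per-CTA shared memory for the A and B tile pipelines; each ArrayEngine is
// sized by the cosize of the corresponding multi-stage smem layout.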
template <
class ProblemShape,
class CtaTiler,
class TA,
class AStride,
class ASmemLayout,
class TiledCopyA,
class S2RAtomA,
class TB,
class BStride,
class BSmemLayout,
class TiledCopyB,
class S2RAtomB,
class TC,
class CStride,
class CSmemLayout,
class TiledMma>
__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
ProblemShape shape_MNK,
CtaTiler cta_tiler,
TA const* A,
AStride dA,
ASmemLayout sA_layout,
TiledCopyA copy_a,
S2RAtomA s2r_atom_a,
TB const* B,
BStride dB,
BSmemLayout sB_layout,
TiledCopyB copy_b,
S2RAtomB s2r_atom_b,
TC* C,
CStride dC,
CSmemLayout,
TiledMma mma) {
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(
congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(
congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(
congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA =
make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
Tensor mB =
make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
Tensor mC =
make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(
mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(
mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
Tensor gC =
local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(
make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(
make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles across the threads
//
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
//
// PREFETCH
//
auto K_PIPE_MAX = size<3>(tAsA);
// Total count of tiles
int k_tile_count = size<3>(tAgA);
// Current tile index in gmem to read from
int k_tile_next = 0;
// Start async loads for all pipes but the last
CUTE_UNROLL
for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
cp_async_fence();
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}
}
//
// Define A/B partitioning and C accumulators
//
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(
(shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
// Clear the accumulators
clear(tCrC);
//
// Copy Atom retiling
//
TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");
print("tXsA : "); print(tXsA); print("\n");
print("tXrA : "); print(tXrA); print("\n");
print("tXsB : "); print(tXsB); print("\n");
print("tXrB : "); print(tXrB); print("\n");
}
#endif
#if 1
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = K_PIPE_MAX - 1;
// Pipe slice
Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
// Size of the register pipeline
auto K_BLOCK_MAX = size<2>(tCrA);
CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
// PREFETCH register pipeline
if (K_BLOCK_MAX > 1) {
// Wait until our first prefetched tile is loaded in
cp_async_wait<K_PIPE_MAX - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
}
//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
// cp.async instructions
// and explicit pipelines in shared memory.
// Data is read from global(k_tile_next) to shared(smem_pipe_write).
// Data is read from shared(smem_pipe_read) to registers(k_block_next).
// Data is computed on registers(k_block).
//
// This allows all copies and compute to overlap:
// Copy from gmem->smem can overlap with copies from smem->rmem and
// compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
//
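// A rough sketch of the schedule, assuming the bP = 3 smem stages configured
// in the host code below: two k-tiles are prefetched before the loop, so
// iteration t computes mainly out of smem stage t % 3 while the cp.async
// issued at k_block == 0 refills the stage consumed in the previous
// iteration with k-tile t + 2.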
CUTE_NO_UNROLL
while (k_tile_count > -(K_PIPE_MAX - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
if (k_block == K_BLOCK_MAX - 1) {
// Slice the smem_pipe_read smem
tXsA_p = tXsA(_, _, _, smem_pipe_read);
tXsB_p = tXsB(_, _, _, smem_pipe_read);
// Commit the smem for smem_pipe_read
cp_async_wait<K_PIPE_MAX - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0) {
copy(
copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
copy(
copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
cp_async_fence();
// Advance the gmem tile
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}
// Advance the smem pipe
smem_pipe_write = smem_pipe_read;
smem_pipe_read =
(smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
}
}
#endif
//
// Epilogue
//
copy(tCrC, tCgC);
// constexpr T alpha{1.0f};
// constexpr T beta{0.0f};
// axpby(alpha, tCrC, beta, tCgC);
}
} // namespace
@@ -243,144 +339,103 @@ void cutlass_gemm(
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Define shapes (dynamic)
auto prob_shape = make_shape(M, N, K);
// Define TN strides (mixed)
auto dA = make_stride(K, Int<1>{});
auto dB = make_stride(K, Int<1>{});
auto dC = make_stride(N, Int<1>{});
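// TN convention implied by these strides: A (M,K) and B (N,K) are K-major in
// memory (innermost stride 1 along K), while C (M,N) is row-major with
// leading dimension N.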
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int<64>{};
auto cta_tiler = make_shape(bM, bN, bK);
auto bP = Int<3>{};
// Define the smem layouts (static)
// Swizzles for LDSM and 128b k-major loads
auto swizzle_atom = composition(
Swizzle<3, 3, 3>{},
Layout<
cute::Shape<_8, cute::Shape<_8, _8>>,
cute::Stride<_8, cute::Stride<_1, _64>>>{});
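// Roughly: the atom spans an 8x64 tile of bf16 and Swizzle<3, 3, 3> XORs the
// 16-byte (8-element) chunk index with the row index, repeating every 8
// rows, so LDSM / 128-bit shared memory accesses avoid bank conflicts.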
// tile_to_shape extends the layout in LayoutLeft order
auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN));
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major (8 bf16 = 16 B, one uint128_t cp.async)
TiledCopy copyB = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
// 2x2 atoms x 32 threads each = 128 threads per block
TiledMMA mmaC = make_tiled_mma(
SM80_16x8x16_F32BF16BF16F32_TN{},
Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
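// The 2x2 arrangement of 16x8x16 atoms natively covers 32x16x16; the
// Tile<_32, _32, _16> permutation repeats values along N so one tiled MMA
// pass covers a full 32x32x16 block for the LDSM-fed operands.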
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
int smem_size = int(sizeof(SharedStorage<
cute::bfloat16_t,
cute::bfloat16_t,
decltype(sA),
decltype(sB)>));
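// With bM = bN = 128, bK = 64, bP = 3 and 2-byte bf16 elements this comes to
// 2 * 128 * 64 * 3 * 2 B = 98304 B (96 KB) of dynamic shared memory, which
// is why configure_matmul must raise the kernel's dynamic smem limit.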
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
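// One CTA per 128x128 output tile: blockIdx.x walks the M tiles and
// blockIdx.y the N tiles, matching cta_coord inside gemm_device.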
auto kernel = gemm_device<
decltype(prob_shape),
decltype(cta_tiler),
cute::bfloat16_t,
decltype(dA),
decltype(sA),
decltype(copyA),
decltype(s2r_atom_A),
cute::bfloat16_t,
decltype(dB),
decltype(sB),
decltype(copyB),
decltype(s2r_atom_B),
cute::bfloat16_t,
decltype(dC),
decltype(sC),
decltype(mmaC)>;
configure_matmul(kernel, smem_size);
enc.add_kernel_node(
kernel,
dimGrid,
dimBlock,
smem_size,
prob_shape,
cta_tiler,
a.data<cute::bfloat16_t>(),
dA,
sA,
copyA,
s2r_atom_A,
b.data<cute::bfloat16_t>(),
dB,
sB,
copyB,
s2r_atom_B,
out.data<cute::bfloat16_t>(),
dC,
sC,
mmaC);
} else {
throw std::runtime_error("Only bfloat16 supported");
}