Improve the cutlass gemm

Angelos Katharopoulos 2025-08-25 18:18:19 -07:00
parent e1303f6160
commit 4987e7615a


@@ -5,11 +5,21 @@
#include "mlx/dtype_utils.h"
#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <iostream>
namespace mlx::core::cu {
namespace {
using namespace cute;
using bf16 = cute::bfloat16_t;
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
@@ -17,308 +27,278 @@ void configure_matmul(Kernel kernel, int smem_size) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
}
}
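For reference: on SM80-class GPUs a kernel must opt in through cudaFuncSetAttribute, as above, before it may use more than 48 KB of dynamic shared memory, and the tile configuration below needs roughly double that. Illustrative arithmetic (the name smem_bytes is ours):
// A and B staging buffers, bf16, 128x64 tiles, 3 pipeline stages:
constexpr int smem_bytes = (128 * 64 + 128 * 64) * 2 * 3;
static_assert(smem_bytes == 96 * 1024); // 96 KB > the 48 KB default limit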
template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
struct SharedStorage {
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
};
template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
return (feature_size >= 64) ? 64 : feature_size;
}
constexpr int constexpr_log2(int x) {
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}
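A quick sanity check of the recursion: it returns floor(log2(x)) for x > 0 and -1 for x == 0.
static_assert(constexpr_log2(1) == 0);
static_assert(constexpr_log2(8) == 3);
static_assert(constexpr_log2(6) == 2); // floor, not rounding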
template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
constexpr int swizzle_bits =
constexpr_log2(feature_size * itemsize / copy_bits);
return (swizzle_bits > 3) ? 3 : swizzle_bits;
}
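Worked through for the bf16 path used below (itemsize = 16, copy_bits = 128): a 64-element feature dimension yields 64 * 16 / 128 = 8 vectors per row, and log2(8) = 3 is already at the clamp.
static_assert(get_swizzle_bits<64, 16, 128>() == 3); // at the clamp
static_assert(get_swizzle_bits<32, 16, 128>() == 2); // 4 vectors per row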
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled =
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
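For illustration, the host code below instantiates the A staging layout like this (the name sA_staging is ours); cosize() of the result is what sizes the shared memory request:
// (BM, BK, PIPE) = (128, 64, 3), k-major bf16 tile, 128-bit copies
auto sA_staging = make_smem_layout<16, false, 128>(
make_shape(Int<128>{}, Int<64>{}, Int<3>{}));
// cosize(sA_staging) == 128 * 64 * 3 elements across the three stages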
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled = make_composed_layout(
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <
class ProblemShape,
class CtaTiler,
class TA,
class AStride,
class ASmemLayout,
class TiledCopyA,
class S2RAtomA,
class TB,
class BStride,
class BSmemLayout,
class TiledCopyB,
class S2RAtomB,
class TC,
class CStride,
class CSmemLayout,
class TiledMma>
__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
ProblemShape shape_MNK,
CtaTiler cta_tiler,
TA const* A,
AStride dA,
ASmemLayout sA_layout,
TiledCopyA copy_a,
S2RAtomA s2r_atom_a,
TB const* B,
BStride dB,
BSmemLayout sB_layout,
TiledCopyB copy_b,
S2RAtomB s2r_atom_b,
TC* C,
CStride dC,
CSmemLayout,
TiledMma mma) {
using namespace cute;
template <
int num_threads,
int itemsize,
bool transpose,
int copy_bits,
typename Copier,
typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
constexpr int num_elements = copy_bits / itemsize;
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
constexpr int copies_per_feature = feature_size / num_elements;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
using E = Int<num_elements>;
using C = Int<copies_per_feature>;
using R = Int<num_threads / copies_per_feature>;
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
using ThreadLayout = std::conditional_t<
transpose,
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
using ValueLayout = std::conditional_t<
transpose,
Layout<cute::Shape<E, _1>>,
Layout<cute::Shape<_1, E>>>;
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}
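Plugging in the configuration used below (num_threads = 128, bf16, 128-bit copies, a k-major (128, 64) tile) reproduces the hand-written tiled copy from the previous version:
// num_elements       = 128 / 16 = 8   (bf16 values per 128-bit copy)
// copies_per_feature = 64 / 8   = 8   (copies to cover one 64-wide row)
// R                  = 128 / 8  = 16  (rows of threads)
// -> ThreadLayout (_16, _8):(_8, _1) and ValueLayout (_1, _8), i.e. the
//    old 16x8 k-major thread layout with 1x8 values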
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
return {
x / rasterization_factor,
(x % rasterization_factor) + y * rasterization_factor};
}
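The effect, illustrated for rasterization_factor = 8 (the GM used below) with blockIdx.y = 0:
// blockIdx.x = 0..7  -> tiles (0, 0), (0, 1), ..., (0, 7)
// blockIdx.x = 8..15 -> tiles (1, 0), (1, 1), ..., (1, 7)
// Consecutive CTAs share one A row-panel and neighboring B panels, so the
// tiles they read stay resident in L2.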
CUTE_STATIC_ASSERT_V(
congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(
congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(
congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
template <
typename T,
typename SLayoutA,
typename SLayoutB,
typename SLayoutC,
typename CopyA,
typename CopyB,
typename CopyC,
typename MMA,
int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
const T* __restrict__ A,
const T* __restrict__ B,
T* __restrict__ C,
SLayoutA SA,
SLayoutB SB,
SLayoutC SC,
CopyA copy_a,
CopyB copy_b,
CopyC copy_c,
MMA mma,
int M,
int N,
int K) {
constexpr auto BM = size<0>(SA);
constexpr auto BN = size<0>(SB);
constexpr auto BK = size<1>(SA);
constexpr auto PIPE = size<2>(SA);
//
// Full and Tiled Tensors
//
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
const int blocks_m = ceil_div(M, BM);
const int blocks_n = ceil_div(N, BN);
// Represent the full tensors
Tensor mA =
make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
Tensor mB =
make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
Tensor mC =
make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
// Exit early if the tile is OOB
if (tile.x >= blocks_m || tile.y >= blocks_n) {
return;
}
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(
mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(
mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
Tensor gC =
local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
// Make the full tensors
Tensor full_A =
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
Tensor full_B =
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
Tensor full_C =
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
// Shared memory buffers
// Partition the tensors into tiles and select the ones for this threadblock
Tensor local_A =
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
Tensor local_B =
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
Tensor local_C =
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
// Make shared memory tensors
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(
make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(
make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
T* shared_B_ptr =
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
//
// Partition the copying of A and B tiles across the threads
//
// Get the copies that correspond to this thread
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
Tensor local_A_src = thread_copy_a.partition_S(local_A);
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
Tensor local_B_src = thread_copy_b.partition_S(local_B);
Tensor local_B_dst = thread_copy_b.partition_D(shared_B);
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
//
// PREFETCH
//
auto K_PIPE_MAX = size<3>(tAsA);
// Total count of tiles
int k_tile_count = size<3>(tAgA);
// Current tile index in gmem to read from
// Start fetches
int k_tile_count = size<2>(local_A);
int k_tile_next = 0;
// Start async loads for all pipes but the last
CUTE_UNROLL
for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
for (int k = 0; k < PIPE - 1; k++) {
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
cp_async_fence();
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}
k_tile_count--;
k_tile_next += (k_tile_count > 0);
}
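This prologue is the standard SM80 multi-stage pattern: fill PIPE - 1 stages up front and fence each one so the main loop can count outstanding cp.async groups. Schematically, for PIPE = 3:
// issue k-tile 0 -> cp_async_fence()   (group 0)
// issue k-tile 1 -> cp_async_fence()   (group 1)
// main loop: cp_async_wait<PIPE - 2>() == wait<1> leaves at most one group
// in flight, so the oldest stage is complete and safe to read.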
//
// Define A/B partitioning and C accumulators
//
// Get the MMA that corresponds to this thread and allocate registers
auto thread_mma = mma.get_slice(threadIdx.x);
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
Tensor mma_global_C = thread_mma.partition_C(local_C);
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
clear(mma_frag_C);
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Make shared to register copies
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
// Allocate registers for pipelining
Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
constexpr auto RPIPE = size<2>(mma_shared_A);
int smem_read = 0;
int smem_write = PIPE - 1;
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
CUTE_STATIC_ASSERT_V(
(shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
// Clear the accumulators
clear(tCrC);
//
// Copy Atom retiling
//
TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");
print("tXsA : "); print(tXsA); print("\n");
print("tXrA : "); print(tXrA); print("\n");
print("tXsB : "); print(tXsB); print("\n");
print("tXrB : "); print(tXrB); print("\n");
}
#endif
#if 1
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = K_PIPE_MAX - 1;
// Pipe slice
Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
// Size of the register pipeline
auto K_BLOCK_MAX = size<2>(tCrA);
CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
// PREFETCH register pipeline
if (K_BLOCK_MAX > 1) {
// Wait until our first prefetched tile is loaded in
cp_async_wait<K_PIPE_MAX - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
}
// Start the register pipeline
if constexpr (RPIPE > 1) {
cp_async_wait<PIPE - 2>();
__syncthreads();
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
}
//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
// cp.async instructions
// and explicit pipelines in shared memory.
// Data is read from global(k_tile_next) to shared(smem_pipe_write).
// Data is read from shared(smem_pipe_read) to registers(k_block_next).
// Data is computed on registers(b_block).
//
// This allows all copies and compute to overlap:
// Copy from gmem->smem can overlap with copies from smem->rmem and
// compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
//
CUTE_NO_UNROLL
while (k_tile_count > -(K_PIPE_MAX - 1)) {
while (k_tile_count > -(PIPE - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
if (k_block == K_BLOCK_MAX - 1) {
// Slice the smem_pipe_read smem
tXsA_p = tXsA(_, _, _, smem_pipe_read);
tXsB_p = tXsB(_, _, _, smem_pipe_read);
// Commit the smem for smem_pipe_read
cp_async_wait<K_PIPE_MAX - 2>();
for (int k_block = 0; k_block < RPIPE; k_block++) {
if (k_block == RPIPE - 1) {
mma_A_src_p = mma_A_src(_, _, _, smem_read);
mma_B_src_p = mma_B_src(_, _, _, smem_read);
cp_async_wait<PIPE - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
// Load the next register tile
auto k_block_next = (k_block + 1) % RPIPE;
copy(
s2r_copy_a,
mma_A_src_p(_, _, k_block_next),
mma_A_dst(_, _, k_block_next));
copy(
s2r_copy_b,
mma_B_src_p(_, _, k_block_next),
mma_B_dst(_, _, k_block_next));
if (k_block == 0) {
copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
copy(
copy_a,
local_A_src(_, _, _, k_tile_next),
local_A_dst(_, _, _, smem_write));
copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
copy(
copy_b,
local_B_src(_, _, _, k_tile_next),
local_B_dst(_, _, _, smem_write));
cp_async_fence();
// Advance the gmem tile
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}
// Advance the smem pipe
smem_pipe_write = smem_pipe_read;
smem_pipe_read =
(smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
k_tile_count--;
k_tile_next += (k_tile_count > 0);
smem_write = smem_read;
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
gemm(
mma,
mma_frag_A(_, _, k_block),
mma_frag_B(_, _, k_block),
mma_frag_C);
}
}
#endif
copy(mma_frag_C, mma_shared_C);
__syncthreads();
copy(copy_c, local_C_src, local_C_dst);
//
// Epilogue
//
copy(tCrC, tCgC);
// if (threadIdx.x == 0) {
// print("fC: "); print(mma_frag_C); print("\n");
// print("sC: "); print(mma_shared_C); print("\n");
// print("dC: "); print(local_C_dst); print("\n");
// print(s2r_atom_a); print("\n");
// }
}
} // namespace
@@ -339,103 +319,74 @@ void cutlass_gemm(
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Define shapes (dynamic)
auto prob_shape = make_shape(M, N, K);
// Tile definitions
auto BM = Int<128>{};
auto BN = Int<128>{};
auto BK = Int<64>{};
auto BP = Int<3>{};
auto GM = Int<8>{};
// Define TN strides (mixed)
auto dA = make_stride(K, Int<1>{});
auto dB = make_stride(K, Int<1>{});
auto dC = make_stride(N, Int<1>{});
// Thread definitions
using TM = Int<2>;
using TN = Int<2>;
using TK = Int<1>;
constexpr int num_threads = TM::value * TN::value * 32;
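That is a 2x2 arrangement of warps:
// TM * TN * TK * 32 = 2 * 2 * 1 * 32 = 128 threads per CTA, the same value
// the launch code recovers below via dim3 block(size(tiled_mma)).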
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int<64>{};
auto cta_tiler = make_shape(bM, bN, bK);
auto bP = Int<3>{};
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
// Define the smem layouts (static)
// Swizzles for LDSM and 128b k-major loads
auto swizzle_atom = composition(
Swizzle<3, 3, 3>{},
Layout<
cute::Shape<_8, cute::Shape<_8, _8>>,
cute::Stride<_8, cute::Stride<_1, _64>>>{});
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN));
auto async_copy_op =
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BM, BK));
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BN, BK));
// Define the thread layouts (static)
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
sync_copy_op, make_shape(BM, BN));
TiledCopy copyA = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
TiledCopy copyB = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
TiledMMA mmaC = make_tiled_mma(
SM80_16x8x16_F32BF16BF16F32_TN{},
Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
int smem_size = int(sizeof(SharedStorage<
cute::bfloat16_t,
cute::bfloat16_t,
decltype(sA),
decltype(sB)>));
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
auto kernel = gemm_device<
decltype(prob_shape),
decltype(cta_tiler),
cute::bfloat16_t,
decltype(dA),
decltype(sA),
decltype(copyA),
decltype(s2r_atom_A),
cute::bfloat16_t,
decltype(dB),
decltype(sB),
decltype(copyB),
decltype(s2r_atom_B),
cute::bfloat16_t,
decltype(dC),
decltype(sC),
decltype(mmaC)>;
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
auto tiled_mma = make_tiled_mma(
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
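A note on the permutation argument, consistent with the old code's "32x32x16 Tiled MMA for LDSM" comment:
// 2x2x1 atoms of SM80_16x8x16 span 32x16x16 natively; Tile<_32, _32, _16>
// widens N to 32 so the shared-to-register loads can use the four-fragment
// ldmatrix atom (SM75_U32x4_LDSM_N) declared inside the kernel.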
auto kernel = matmul_kernel<
bf16,
decltype(SA),
decltype(SB),
decltype(SC),
decltype(tiled_copy_a),
decltype(tiled_copy_b),
decltype(tiled_copy_c),
decltype(tiled_mma),
GM.value>;
configure_matmul(kernel, smem_size);
dim3 block(size(tiled_mma));
dim3 grid(
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
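A worked example of the rasterized grid, assuming M = N = 4096:
// blocks_m = blocks_n = ceil_div(4096, 128) = 32
// grid = (32 * 8, ceil_div(32, 8)) = (256, 4) -> 1024 CTAs = 32 * 32 tiles
// raster_tile() folds the factor of GM in grid.x back into the N coordinate;
// the kernel's early-exit guard drops padding tiles when blocks_n is not a
// multiple of GM.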
enc.add_kernel_node(
kernel,
dimGrid,
dimBlock,
smem_size,
prob_shape,
cta_tiler,
a.data<cute::bfloat16_t>(),
dA,
sA,
copyA,
s2r_atom_A,
b.data<cute::bfloat16_t>(),
dB,
sB,
copyB,
s2r_atom_B,
out.data<cute::bfloat16_t>(),
dC,
sC,
mmaC);
enc.add_kernel_node(
kernel,
grid,
block,
smem_size,
a.data<bf16>(),
b.data<bf16>(),
out.data<bf16>(),
SA,
SB,
SC,
tiled_copy_a,
tiled_copy_b,
tiled_copy_c,
tiled_mma,
M,
N,
K);
} else {
throw std::runtime_error("Only bfloat16 supported");
}