Improve the cutlass gemm

Angelos Katharopoulos 2025-08-25 18:18:19 -07:00
parent e1303f6160
commit 4987e7615a


@@ -5,11 +5,21 @@
 #include "mlx/dtype_utils.h"
 
 #include <cute/tensor.hpp>
+#include <cutlass/arch/arch.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+
+#include <iostream>
 
 namespace mlx::core::cu {
 
 namespace {
 
+using namespace cute;
+using bf16 = cute::bfloat16_t;
+
 template <typename Kernel>
 void configure_matmul(Kernel kernel, int smem_size) {
   static bool initialized = false;
@@ -17,308 +27,278 @@ void configure_matmul(Kernel kernel, int smem_size) {
     initialized = true;
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
   }
 }
 
-template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
-struct SharedStorage {
-  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
-  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
-};
-
-template <
-    class ProblemShape,
-    class CtaTiler,
-    class TA,
-    class AStride,
-    class ASmemLayout,
-    class TiledCopyA,
-    class S2RAtomA,
-    class TB,
-    class BStride,
-    class BSmemLayout,
-    class TiledCopyB,
-    class S2RAtomB,
-    class TC,
-    class CStride,
-    class CSmemLayout,
-    class TiledMma>
-__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
-    ProblemShape shape_MNK,
-    CtaTiler cta_tiler,
-    TA const* A,
-    AStride dA,
-    ASmemLayout sA_layout,
-    TiledCopyA copy_a,
-    S2RAtomA s2r_atom_a,
-    TB const* B,
-    BStride dB,
-    BSmemLayout sB_layout,
-    TiledCopyB copy_b,
-    S2RAtomB s2r_atom_b,
-    TC* C,
-    CStride dC,
-    CSmemLayout,
-    TiledMma mma) {
-  using namespace cute;
-
-  // Preconditions
-  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
-  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
-
-  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
-  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
-
-  static_assert(is_static<ASmemLayout>::value);
-  static_assert(is_static<BSmemLayout>::value);
-  static_assert(is_static<CSmemLayout>::value);
-
-  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
-  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
-
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
-
-  //
-  // Full and Tiled Tensors
-  //
-
-  // Represent the full tensors
-  Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
-  Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
-  Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
-
-  // Get the appropriate blocks for this thread block
-  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
-  Tensor gA = local_tile(
-      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
-  Tensor gB = local_tile(
-      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
-  Tensor gC =
-      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
-
-  // Shared memory buffers
-  extern __shared__ char shared_memory[];
-  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
-  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
-  Tensor sA = make_tensor(
-      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(
-      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
-
-  //
-  // Partition the copying of A and B tiles across the threads
-  //
-
-  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
-  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
-  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
-
-  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
-  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
-  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
-
-  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
-  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
-  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
-  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
-
-  //
-  // PREFETCH
-  //
-
-  auto K_PIPE_MAX = size<3>(tAsA);
-
-  // Total count of tiles
-  int k_tile_count = size<3>(tAgA);
-  // Current tile index in gmem to read from
-  int k_tile_next = 0;
-
-  // Start async loads for all pipes but the last
-  CUTE_UNROLL
-  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
-    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
-    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
-    cp_async_fence();
-    --k_tile_count;
-    if (k_tile_count > 0) {
-      ++k_tile_next;
-    }
-  }
-
-  //
-  // Define A/B partitioning and C accumulators
-  //
-
-  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
-  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
-
-  // Allocate registers for pipelining
-  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
-  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
-  // Allocate the accumulators -- same size as the projected data
-  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
-
-  CUTE_STATIC_ASSERT_V(
-      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
-  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
-  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
-
-  // Clear the accumulators
-  clear(tCrC);
-
-  //
-  // Copy Atom retiling
-  //
-
-  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
-  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
-  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
-  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
-
-  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
-  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
-  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
-  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
-
-#if 0
-  if(thread0()) {
-    print(" mA : "); print( mA); print("\n");
-    print(" gA : "); print( gA); print("\n");
-    print(" sA : "); print( sA); print("\n");
-    print("tAgA : "); print(tAgA); print("\n");
-    print("tAsA : "); print(tAsA); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print(" mB : "); print( mB); print("\n");
-    print(" gB : "); print( gB); print("\n");
-    print(" sB : "); print( sB); print("\n");
-    print("tBgB : "); print(tBgB); print("\n");
-    print("tBsB : "); print(tBsB); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print(" mC : "); print( mC); print("\n");
-    print(" gC : "); print( gC); print("\n");
-    print("tCgC : "); print(tCgC); print("\n");
-    print("tCrA : "); print(tCrA); print("\n");
-    print("tCrB : "); print(tCrB); print("\n");
-    print("tCrC : "); print(tCrC); print("\n");
-    print("tXsA : "); print(tXsA); print("\n");
-    print("tXrA : "); print(tXrA); print("\n");
-    print("tXsB : "); print(tXsB); print("\n");
-    print("tXrB : "); print(tXrB); print("\n");
-  }
-#endif
-
-#if 1
-
-  // Current pipe index in smem to read from
-  int smem_pipe_read = 0;
-  // Current pipe index in smem to write to
-  int smem_pipe_write = K_PIPE_MAX - 1;
-
-  // Pipe slice
-  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
-  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-  // Size of the register pipeline
-  auto K_BLOCK_MAX = size<2>(tCrA);
-  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
-
-  // PREFETCH register pipeline
-  if (K_BLOCK_MAX > 1) {
-    // Wait until our first prefetched tile is loaded in
-    cp_async_wait<K_PIPE_MAX - 2>();
-    __syncthreads();
-
-    // Prefetch the first rmem from the first k-tile
-    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
-    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
-  }
-
-  //
-  // PIPELINED MAIN LOOP
-  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
-  // cp.async instructions
-  //           and explicit pipelines in shared memory.
-  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
-  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
-  //   Data is computed on registers(b_block).
-  //
-  //   This allows all copies and compute to overlap:
-  //     Copy from gmem->smem can overlap with copies from smem->rmem and
-  //     compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
-  //
-
-  CUTE_NO_UNROLL
-  while (k_tile_count > -(K_PIPE_MAX - 1)) {
-    CUTE_UNROLL
-    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
-      if (k_block == K_BLOCK_MAX - 1) {
-        // Slice the smem_pipe_read smem
-        tXsA_p = tXsA(_, _, _, smem_pipe_read);
-        tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-        // Commit the smem for smem_pipe_read
-        cp_async_wait<K_PIPE_MAX - 2>();
-        __syncthreads();
-      }
-
-      // Load A, B shmem->regs for k_block+1
-      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
-      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
-      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
-
-      // Copy gmem to smem before computing gemm on each k-pipe
-      if (k_block == 0) {
-        copy(
-            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
-        copy(
-            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
-        cp_async_fence();
-
-        // Advance the gmem tile
-        --k_tile_count;
-        if (k_tile_count > 0) {
-          ++k_tile_next;
-        }
-
-        // Advance the smem pipe
-        smem_pipe_write = smem_pipe_read;
-        smem_pipe_read =
-            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
-      }
-
-      // Thread-level register gemm for k_block
-      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
-    }
-  }
-
-#endif
-
-  //
-  // Epilogue
-  //
-
-  copy(tCrC, tCgC);
+template <bool transpose, typename Tiler>
+constexpr int get_feature_size(Tiler smem) {
+  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
+  return (feature_size >= 64) ? 64 : feature_size;
+}
+
+constexpr int constexpr_log2(int x) {
+  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
+}
+
+template <int feature_size, int itemsize, int copy_bits>
+constexpr int get_swizzle_bits() {
+  constexpr int swizzle_bits =
+      constexpr_log2(feature_size * itemsize / copy_bits);
+  return (swizzle_bits > 3) ? 3 : swizzle_bits;
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled =
+      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_result_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled = make_composed_layout(
+      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
+template <
+    int num_threads,
+    int itemsize,
+    bool transpose,
+    int copy_bits,
+    typename Copier,
+    typename Tiler>
+constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
+  constexpr int num_elements = copy_bits / itemsize;
+  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
+  constexpr int copies_per_feature = feature_size / num_elements;
+
+  using E = Int<num_elements>;
+  using C = Int<copies_per_feature>;
+  using R = Int<num_threads / copies_per_feature>;
+  using ThreadLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
+      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
+  using ValueLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<E, _1>>,
+      Layout<cute::Shape<_1, E>>>;
+
+  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
+}
+
+template <int rasterization_factor>
+__device__ inline int2 raster_tile(int x, int y) {
+  return {
+      x / rasterization_factor,
+      (x % rasterization_factor) + y * rasterization_factor};
+}
+
+template <
+    typename T,
+    typename SLayoutA,
+    typename SLayoutB,
+    typename SLayoutC,
+    typename CopyA,
+    typename CopyB,
+    typename CopyC,
+    typename MMA,
+    int rasterization_factor>
+__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
+    const T* __restrict__ A,
+    const T* __restrict__ B,
+    T* __restrict__ C,
+    SLayoutA SA,
+    SLayoutB SB,
+    SLayoutC SC,
+    CopyA copy_a,
+    CopyB copy_b,
+    CopyC copy_c,
+    MMA mma,
+    int M,
+    int N,
+    int K) {
+  constexpr auto BM = size<0>(SA);
+  constexpr auto BN = size<0>(SB);
+  constexpr auto BK = size<1>(SA);
+  constexpr auto PIPE = size<2>(SA);
+
+  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
+  const int blocks_m = ceil_div(M, BM);
+  const int blocks_n = ceil_div(N, BN);
+
+  // Exit early if the tile is OOB
+  if (tile.x >= blocks_m || tile.y >= blocks_n) {
+    return;
+  }
+
+  // Make the full tensors
+  Tensor full_A =
+      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
+  Tensor full_B =
+      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
+  Tensor full_C =
+      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
+
+  // Partition the tensors into tiles and select the ones for this threadblock
+  Tensor local_A =
+      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
+  Tensor local_B =
+      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
+  Tensor local_C =
+      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
+
+  // Make shared memory tensors
+  extern __shared__ char shared_memory[];
+  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
+  T* shared_B_ptr =
+      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
+  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
+  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
+  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
+  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
+
+  // Get the copies that correspond to this thread
+  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor local_A_src = thread_copy_a.partition_S(local_A);
+  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
+  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor local_B_src = thread_copy_a.partition_S(local_B);
+  Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
+  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
+  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
+  Tensor local_C_dst = thread_copy_c.partition_D(local_C);
+
+  // Start fetches
+  int k_tile_count = size<2>(local_A);
+  int k_tile_next = 0;
+  CUTE_UNROLL
+  for (int k = 0; k < PIPE - 1; k++) {
+    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
+    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
+    cp_async_fence();
+    k_tile_count--;
+    k_tile_next += (k_tile_count > 0);
+  }
+
+  // Get the MMA that corresponds to this thread and allocate registers
+  auto thread_mma = mma.get_slice(threadIdx.x);
+  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
+  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
+  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
+  Tensor mma_global_C = thread_mma.partition_C(local_C);
+  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
+  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
+  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
+  clear(mma_frag_C);
+
+  // Make shared to register copies
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
+  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
+  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
+  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
+  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
+
+  constexpr auto RPIPE = size<2>(mma_shared_A);
+  int smem_read = 0;
+  int smem_write = PIPE - 1;
+  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
+  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
+
+  // Start the register pipeline
+  if constexpr (RPIPE > 1) {
+    cp_async_wait<PIPE - 2>();
+    __syncthreads();
+
+    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
+    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
+  }
+
+  CUTE_NO_UNROLL
+  while (k_tile_count > -(PIPE - 1)) {
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < RPIPE; k_block++) {
+      if (k_block == RPIPE - 1) {
+        mma_A_src_p = mma_A_src(_, _, _, smem_read);
+        mma_B_src_p = mma_B_src(_, _, _, smem_read);
+        cp_async_wait<PIPE - 2>();
+        __syncthreads();
+      }
+
+      // Load the next register tile
+      auto k_block_next = (k_block + 1) % RPIPE;
+      copy(
+          s2r_copy_a,
+          mma_A_src_p(_, _, k_block_next),
+          mma_A_dst(_, _, k_block_next));
+      copy(
+          s2r_copy_b,
+          mma_B_src_p(_, _, k_block_next),
+          mma_B_dst(_, _, k_block_next));
+
+      if (k_block == 0) {
+        copy(
+            copy_a,
+            local_A_src(_, _, _, k_tile_next),
+            local_A_dst(_, _, _, smem_write));
+        copy(
+            copy_b,
+            local_B_src(_, _, _, k_tile_next),
+            local_B_dst(_, _, _, smem_write));
+        cp_async_fence();
+
+        k_tile_count--;
+        k_tile_next += (k_tile_count > 0);
+        smem_write = smem_read;
+        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
+      }
+
+      gemm(
+          mma,
+          mma_frag_A(_, _, k_block),
+          mma_frag_B(_, _, k_block),
+          mma_frag_C);
+    }
+  }
+
+  copy(mma_frag_C, mma_shared_C);
+  __syncthreads();
+  copy(copy_c, local_C_src, local_C_dst);
+
+  // if (threadIdx.x == 0) {
+  //   print("fC: "); print(mma_frag_C); print("\n");
+  //   print("sC: "); print(mma_shared_C); print("\n");
+  //   print("dC: "); print(local_C_dst); print("\n");
+  //
+  //   print(s2r_atom_a); print("\n");
+  // }
 }
 
 } // namespace
@@ -339,103 +319,74 @@ void cutlass_gemm(
   if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
     using namespace cute;
 
-    // Define shapes (dynamic)
-    auto prob_shape = make_shape(M, N, K);
-
-    // Define TN strides (mixed)
-    auto dA = make_stride(K, Int<1>{});
-    auto dB = make_stride(K, Int<1>{});
-    auto dC = make_stride(N, Int<1>{});
-
-    // Define CTA tile sizes (static)
-    auto bM = Int<128>{};
-    auto bN = Int<128>{};
-    auto bK = Int<64>{};
-    auto cta_tiler = make_shape(bM, bN, bK);
-    auto bP = Int<3>{};
-
-    // Define the smem layouts (static)
-    // Swizzles for LDSM and 128b k-major loads
-    auto swizzle_atom = composition(
-        Swizzle<3, 3, 3>{},
-        Layout<
-            cute::Shape<_8, cute::Shape<_8, _8>>,
-            cute::Stride<_8, cute::Stride<_1, _64>>>{});
-    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
-    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
-    auto sC = make_layout(make_shape(bM, bN));
-
-    // Define the thread layouts (static)
-    TiledCopy copyA = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
-    TiledCopy copyB = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major
-    TiledMMA mmaC = make_tiled_mma(
-        SM80_16x8x16_F32BF16BF16F32_TN{},
-        Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
-        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
-
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
-
-    int smem_size = int(sizeof(SharedStorage<
-                                cute::bfloat16_t,
-                                cute::bfloat16_t,
-                                decltype(sA),
-                                decltype(sB)>));
-
-    dim3 dimBlock(size(mmaC));
-    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
-
-    auto kernel = gemm_device<
-        decltype(prob_shape),
-        decltype(cta_tiler),
-        cute::bfloat16_t,
-        decltype(dA),
-        decltype(sA),
-        decltype(copyA),
-        decltype(s2r_atom_A),
-        cute::bfloat16_t,
-        decltype(dB),
-        decltype(sB),
-        decltype(copyB),
-        decltype(s2r_atom_B),
-        cute::bfloat16_t,
-        decltype(dC),
-        decltype(sC),
-        decltype(mmaC)>;
+    // Tile definitions
+    auto BM = Int<128>{};
+    auto BN = Int<128>{};
+    auto BK = Int<64>{};
+    auto BP = Int<3>{};
+    auto GM = Int<8>{};
+
+    // Thread definitions
+    using TM = Int<2>;
+    using TN = Int<2>;
+    using TK = Int<1>;
+    constexpr int num_threads = TM::value * TN::value * 32;
+
+    auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
+    auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
+    auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
+
+    constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
+
+    auto async_copy_op =
+        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
+    auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BM, BK));
+    auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BN, BK));
+
+    auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
+    auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
+        sync_copy_op, make_shape(BM, BN));
+
+    auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
+    auto tiled_mma = make_tiled_mma(
+        mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
+
+    auto kernel = matmul_kernel<
+        bf16,
+        decltype(SA),
+        decltype(SB),
+        decltype(SC),
+        decltype(tiled_copy_a),
+        decltype(tiled_copy_b),
+        decltype(tiled_copy_c),
+        decltype(tiled_mma),
+        GM.value>;
 
     configure_matmul(kernel, smem_size);
 
+    dim3 block(size(tiled_mma));
+    dim3 grid(
+        size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
+
     enc.add_kernel_node(
         kernel,
-        dimGrid,
-        dimBlock,
+        grid,
+        block,
         smem_size,
-        prob_shape,
-        cta_tiler,
-        a.data<cute::bfloat16_t>(),
-        dA,
-        sA,
-        copyA,
-        s2r_atom_A,
-        b.data<cute::bfloat16_t>(),
-        dB,
-        sB,
-        copyB,
-        s2r_atom_B,
-        out.data<cute::bfloat16_t>(),
-        dC,
-        sC,
-        mmaC);
+        a.data<bf16>(),
+        b.data<bf16>(),
+        out.data<bf16>(),
+        SA,
+        SB,
+        SC,
+        tiled_copy_a,
+        tiled_copy_b,
+        tiled_copy_c,
+        tiled_mma,
+        M,
+        N,
+        K);
   } else {
     throw std::runtime_error("Only bfloat16 supported");
   }