From 4987e7615ad1286405771a9d5ea8352aad903829 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 25 Aug 2025 18:18:19 -0700
Subject: [PATCH] Improve the cutlass gemm

---
 mlx/backend/cuda/gemms/cutlass_gemm.cu | 639 ++++++++++++-------------
 1 file changed, 295 insertions(+), 344 deletions(-)

diff --git a/mlx/backend/cuda/gemms/cutlass_gemm.cu b/mlx/backend/cuda/gemms/cutlass_gemm.cu
index c58d7e62e..f912ad897 100644
--- a/mlx/backend/cuda/gemms/cutlass_gemm.cu
+++ b/mlx/backend/cuda/gemms/cutlass_gemm.cu
@@ -5,11 +5,21 @@
 #include "mlx/dtype_utils.h"
 
 #include
+#include
+#include
+#include
+#include
+#include
+
+#include
 
 namespace mlx::core::cu {
 
 namespace {
 
+using namespace cute;
+using bf16 = cute::bfloat16_t;
+
 template <typename Kernel>
 void configure_matmul(Kernel kernel, int smem_size) {
   static bool initialized = false;
@@ -17,308 +27,278 @@ void configure_matmul(Kernel kernel, int smem_size) {
     initialized = true;
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
   }
 }
 
-template <class TA, class TB, class ASmemLayout, class BSmemLayout>
-struct SharedStorage {
-  cute::ArrayEngine<TA, cute::cosize_v<ASmemLayout>> A;
-  cute::ArrayEngine<TB, cute::cosize_v<BSmemLayout>> B;
-};
+template <bool transpose, typename Tiler>
+constexpr int get_feature_size(Tiler smem) {
+  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
+  return (feature_size >= 64) ? 64 : feature_size;
+}
+
+constexpr int constexpr_log2(int x) {
+  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
+}
+
+template <int feature_size, int itemsize, int copy_bits>
+constexpr int get_swizzle_bits() {
+  constexpr int swizzle_bits =
+      constexpr_log2(feature_size * itemsize / copy_bits);
+  return (swizzle_bits > 3) ? 3 : swizzle_bits;
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled =
+      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_result_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled = make_composed_layout(
+      Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
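The helpers above replace the old fixed Swizzle<3, 3, 3> atom with shared-memory layouts whose swizzle width is derived from the tile shape: the contiguous "feature" extent (capped at 64 elements) and the 128-bit copy width decide how many address bits can be XOR-folded, and get_swizzle_bits() clamps that to 3. The goal is the usual one: threads that touch the same column of consecutive rows should not land on the same shared-memory banks. A rough standalone illustration of that XOR idea, assuming bf16 elements moved in 8-element (128-bit) chunks; this shows the generic pattern only, not CuTe's exact Swizzle index math:

#include <cstdio>

// Illustrative only: map (row, col) of a 64-wide bf16 tile to a swizzled
// column so that the 8-element chunk a thread touches moves to a different
// bank group on every row. Hypothetical helper, not part of the patch.
constexpr int chunk = 8; // 8 x bf16 = 16 bytes = one 128-bit copy
constexpr int swizzled_col(int row, int col) {
  int chunk_id = col / chunk;          // which 128-bit chunk in the row
  int permuted = chunk_id ^ (row % 8); // XOR with the low row bits
  return permuted * chunk + (col % chunk);
}

int main() {
  // Column 0 of successive rows no longer maps to the same bank group.
  for (int r = 0; r < 8; ++r) {
    std::printf("row %d: col 0 -> col %d\n", r, swizzled_col(r, 0));
  }
}

Without the swizzle every row would put chunk 0 on the same banks, which serializes both the cp.async stores into shared memory and the later ldmatrix reads out of it.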
 
 template <
-    class ProblemShape,
-    class CtaTiler,
-    class TA,
-    class AStride,
-    class ASmemLayout,
-    class TiledCopyA,
-    class S2RAtomA,
-    class TB,
-    class BStride,
-    class BSmemLayout,
-    class TiledCopyB,
-    class S2RAtomB,
-    class TC,
-    class CStride,
-    class CSmemLayout,
-    class TiledMma>
-__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
-    ProblemShape shape_MNK,
-    CtaTiler cta_tiler,
-    TA const* A,
-    AStride dA,
-    ASmemLayout sA_layout,
-    TiledCopyA copy_a,
-    S2RAtomA s2r_atom_a,
-    TB const* B,
-    BStride dB,
-    BSmemLayout sB_layout,
-    TiledCopyB copy_b,
-    S2RAtomB s2r_atom_b,
-    TC* C,
-    CStride dC,
-    CSmemLayout,
-    TiledMma mma) {
-  using namespace cute;
+    int num_threads,
+    int itemsize,
+    bool transpose,
+    int copy_bits,
+    typename Copier,
+    typename Tiler>
+constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
+  constexpr int num_elements = copy_bits / itemsize;
+  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
+  constexpr int copies_per_feature = feature_size / num_elements;
 
-  // Preconditions
-  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
-  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
+  using E = Int<num_elements>;
+  using C = Int<copies_per_feature>;
+  using R = Int<num_threads / copies_per_feature>;
 
-  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
-  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
+  using ThreadLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
+      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
+  using ValueLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<E, _1>>,
+      Layout<cute::Shape<_1, E>>>;
 
-  static_assert(is_static<ASmemLayout>::value);
-  static_assert(is_static<BSmemLayout>::value);
-  static_assert(is_static<CSmemLayout>::value);
+  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
+}
 
-  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
-  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+template <int rasterization_factor>
+__device__ inline int2 raster_tile(int x, int y) {
+  return {
+      x / rasterization_factor,
+      (x % rasterization_factor) + y * rasterization_factor};
+}
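raster_tile above is the other scheduling change in this patch: it folds blockIdx back onto the 2-D grid of output tiles so that groups of rasterization_factor consecutive x-blocks share one A row-tile and walk neighbouring B column-tiles, which keeps a bounded working set of B resident in L2 while A streams. A host-side sketch of the same mapping, using the GM = 8 factor configured later in the patch, purely to make the remapping visible:

#include <cstdio>

struct Tile { int x, y; };

// Same index arithmetic as raster_tile, written for the host.
constexpr Tile raster(int bx, int by, int factor = 8) {
  return {bx / factor, (bx % factor) + by * factor};
}

int main() {
  // Blocks 0..7 (with by = 0) all keep tile.x = 0, i.e. they reuse the same
  // A row-tile while covering eight neighbouring B column-tiles.
  for (int bx = 0; bx < 10; ++bx) {
    Tile t = raster(bx, 0);
    std::printf("blockIdx.x=%d -> tile (%d, %d)\n", bx, t.x, t.y);
  }
}

Blocks whose remapped tile falls outside the matrix are still launched and simply return through the out-of-bounds check at the top of the kernel below.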
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
+template <
+    typename T,
+    typename SLayoutA,
+    typename SLayoutB,
+    typename SLayoutC,
+    typename CopyA,
+    typename CopyB,
+    typename CopyC,
+    typename MMA,
+    int rasterization_factor>
+__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
+    const T* __restrict__ A,
+    const T* __restrict__ B,
+    T* __restrict__ C,
+    SLayoutA SA,
+    SLayoutB SB,
+    SLayoutC SC,
+    CopyA copy_a,
+    CopyB copy_b,
+    CopyC copy_c,
+    MMA mma,
+    int M,
+    int N,
+    int K) {
+  constexpr auto BM = size<0>(SA);
+  constexpr auto BN = size<0>(SB);
+  constexpr auto BK = size<1>(SA);
+  constexpr auto PIPE = size<2>(SA);
 
-  //
-  // Full and Tiled Tensors
-  //
+  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
+  const int blocks_m = ceil_div(M, BM);
+  const int blocks_n = ceil_div(N, BN);
 
-  // Represent the full tensors
-  Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
-  Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
-  Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
+  // Exit early if the tile is OOB
+  if (tile.x >= blocks_m || tile.y >= blocks_n) {
+    return;
+  }
 
-  // Get the appropriate blocks for this thread block
-  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
-  Tensor gA = local_tile(
-      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
-  Tensor gB = local_tile(
-      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
-  Tensor gC =
-      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+  // Make the full tensors
+  Tensor full_A =
+      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
+  Tensor full_B =
+      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
+  Tensor full_C =
+      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
 
-  // Shared memory buffers
+  // Partition the tensors into tiles and select the ones for this threadblock
+  Tensor local_A =
+      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
+  Tensor local_B =
+      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
+  Tensor local_C =
+      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
+
+  // Make shared memory tensors
   extern __shared__ char shared_memory[];
-  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
-  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
-  Tensor sA = make_tensor(
-      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(
-      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
+  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
+  T* shared_B_ptr =
+      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
+  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
+  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
+  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
+  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
 
-  //
-  // Partition the copying of A and B tiles across the threads
-  //
+  // Get the copies that correspond to this thread
+  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor local_A_src = thread_copy_a.partition_S(local_A);
+  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
+  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor local_B_src = thread_copy_b.partition_S(local_B);
+  Tensor local_B_dst = thread_copy_b.partition_D(shared_B);
+  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
+  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
+  Tensor local_C_dst = thread_copy_c.partition_D(local_C);
 
-  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
-  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
-  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
-
-  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
-  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
-  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
-
-  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
-  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
-  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
-  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
-
-  //
-  // PREFETCH
-  //
-
-  auto K_PIPE_MAX = size<3>(tAsA);
-
-  // Total count of tiles
-  int k_tile_count = size<3>(tAgA);
-  // Current tile index in gmem to read from
+  // Start fetches
+  int k_tile_count = size<2>(local_A);
   int k_tile_next = 0;
-
-  // Start async loads for all pipes but the last
   CUTE_UNROLL
-  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
-    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
-    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
+  for (int k = 0; k < PIPE - 1; k++) {
+    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
+    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
     cp_async_fence();
-    --k_tile_count;
-    if (k_tile_count > 0) {
-      ++k_tile_next;
-    }
+    k_tile_count--;
+    k_tile_next += (k_tile_count > 0);
   }
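The loop that just ended issues the gmem to smem copies for the first PIPE - 1 stages, with one cp_async_fence() per stage, so the MMAs on stage 0 can start while the later stages are still in flight; the main loop keeps exactly PIPE - 2 stages outstanding via the cp_async_wait further down. The same discipline written with the raw CUDA pipeline primitives, as a minimal sketch of the pattern rather than the MLX kernel (STAGES and TILE are made-up values, and it uses 4-byte copies to stay simple where the real kernel moves 16 bytes per thread):

#include <cuda_pipeline.h>

constexpr int STAGES = 3;  // matches the patch's BP / PIPE
constexpr int TILE = 2048; // floats per stage, illustrative only

// Assumes n_tiles >= STAGES and TILE % blockDim.x == 0.
__global__ void pipelined_sum(const float* __restrict__ in, float* out, int n_tiles) {
  __shared__ float stage[STAGES][TILE];

  // Prefetch STAGES - 1 tiles, one commit (fence) per stage.
  for (int s = 0; s < STAGES - 1; ++s) {
    for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
      __pipeline_memcpy_async(&stage[s][i], &in[s * TILE + i], sizeof(float));
    }
    __pipeline_commit();
  }

  float acc = 0.f;
  for (int t = 0; t < n_tiles; ++t) {
    int read = t % STAGES;
    int write = (t + STAGES - 1) % STAGES;
    // Kick off the copy that is STAGES - 1 tiles ahead of the compute.
    if (t + STAGES - 1 < n_tiles) {
      for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
        __pipeline_memcpy_async(
            &stage[write][i], &in[(t + STAGES - 1) * TILE + i], sizeof(float));
      }
    }
    __pipeline_commit();
    // Wait until all but the STAGES - 2 most recent commits are done, so the
    // stage we are about to read is guaranteed to have landed.
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();
    for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
      acc += stage[read][i];
    }
    __syncthreads();
  }
  out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}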
 
-  //
-  // Define A/B partitioning and C accumulators
-  //
+  // Get the MMA that corresponds to this thread and allocate registers
+  auto thread_mma = mma.get_slice(threadIdx.x);
+  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
+  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
+  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
+  Tensor mma_global_C = thread_mma.partition_C(local_C);
+  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
+  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
+  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
+  clear(mma_frag_C);
 
-  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
-  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
+  // Make shared to register copies
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
+  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
+  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
+  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
+  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
 
-  // Allocate registers for pipelining
-  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
-  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
-  // Allocate the accumulators -- same size as the projected data
-  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
+  constexpr auto RPIPE = size<2>(mma_shared_A);
+  int smem_read = 0;
+  int smem_write = PIPE - 1;
+  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
+  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
 
-  CUTE_STATIC_ASSERT_V(
-      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
-  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
-  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
-
-  // Clear the accumulators
-  clear(tCrC);
-
-  //
-  // Copy Atom retiling
-  //
-
-  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
-  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
-  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
-  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
-
-  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
-  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
-  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
-  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
-
-#if 0
-  if(thread0()) {
-    print("  mA : "); print(  mA); print("\n");
-    print("  gA : "); print(  gA); print("\n");
-    print("  sA : "); print(  sA); print("\n");
-    print("tAgA : "); print(tAgA); print("\n");
-    print("tAsA : "); print(tAsA); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print("  mB : "); print(  mB); print("\n");
-    print("  gB : "); print(  gB); print("\n");
-    print("  sB : "); print(  sB); print("\n");
-    print("tBgB : "); print(tBgB); print("\n");
-    print("tBsB : "); print(tBsB); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print("  mC : "); print(  mC); print("\n");
-    print("  gC : "); print(  gC); print("\n");
-    print("tCgC : "); print(tCgC); print("\n");
-    print("tCrA : "); print(tCrA); print("\n");
-    print("tCrB : "); print(tCrB); print("\n");
-    print("tCrC : "); print(tCrC); print("\n");
-
-    print("tXsA : "); print(tXsA); print("\n");
-    print("tXrA : "); print(tXrA); print("\n");
-    print("tXsB : "); print(tXsB); print("\n");
-    print("tXrB : "); print(tXrB); print("\n");
-  }
-#endif
-
-#if 1
-
-  // Current pipe index in smem to read from
-  int smem_pipe_read = 0;
-  // Current pipe index in smem to write to
-  int smem_pipe_write = K_PIPE_MAX - 1;
-
-  // Pipe slice
-  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
-  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-  // Size of the register pipeline
-  auto K_BLOCK_MAX = size<2>(tCrA);
-  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
-
-  // PREFETCH register pipeline
-  if (K_BLOCK_MAX > 1) {
-    // Wait until our first prefetched tile is loaded in
-    cp_async_wait<K_PIPE_MAX - 2>();
+  // Start the register pipeline
+  if constexpr (RPIPE > 1) {
+    cp_async_wait<PIPE - 2>();
     __syncthreads();
-
-    // Prefetch the first rmem from the first k-tile
-    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
-    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
+    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
   }
 
-  //
-  // PIPELINED MAIN LOOP
-  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
-  // cp.async instructions
-  // and explicit pipelines in shared memory.
-  // Data is read from global(k_tile_next) to shared(smem_pipe_write).
-  // Data is read from shared(smem_pipe_read) to registers(k_block_next).
-  // Data is computed on registers(b_block).
-  //
-  // This allows all copies and compute to overlap:
-  // Copy from gmem->smem can overlap with copies from smem->rmem and
-  // compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
-  //
   CUTE_NO_UNROLL
-  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+  while (k_tile_count > -(PIPE - 1)) {
     CUTE_UNROLL
-    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
-      if (k_block == K_BLOCK_MAX - 1) {
-        // Slice the smem_pipe_read smem
-        tXsA_p = tXsA(_, _, _, smem_pipe_read);
-        tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-        // Commit the smem for smem_pipe_read
-        cp_async_wait<K_PIPE_MAX - 2>();
+    for (int k_block = 0; k_block < RPIPE; k_block++) {
+      if (k_block == RPIPE - 1) {
+        mma_A_src_p = mma_A_src(_, _, _, smem_read);
+        mma_B_src_p = mma_B_src(_, _, _, smem_read);
+        cp_async_wait<PIPE - 2>();
         __syncthreads();
       }
 
-      // Load A, B shmem->regs for k_block+1
-      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
-      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
-      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
-      // Copy gmem to smem before computing gemm on each k-pipe
+      // Load the next register tile
+      auto k_block_next = (k_block + 1) % RPIPE;
+      copy(
+          s2r_copy_a,
+          mma_A_src_p(_, _, k_block_next),
+          mma_A_dst(_, _, k_block_next));
+      copy(
+          s2r_copy_b,
+          mma_B_src_p(_, _, k_block_next),
+          mma_B_dst(_, _, k_block_next));
+
+      if (k_block == 0) {
         copy(
-            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+            copy_a,
+            local_A_src(_, _, _, k_tile_next),
+            local_A_dst(_, _, _, smem_write));
         copy(
-            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+            copy_b,
+            local_B_src(_, _, _, k_tile_next),
+            local_B_dst(_, _, _, smem_write));
         cp_async_fence();
-
-      // Advance the gmem tile
-      --k_tile_count;
-      if (k_tile_count > 0) {
-        ++k_tile_next;
-      }
-
-      // Advance the smem pipe
-      smem_pipe_write = smem_pipe_read;
-      smem_pipe_read =
-          (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+        k_tile_count--;
+        k_tile_next += (k_tile_count > 0);
+        smem_write = smem_read;
+        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
+      }
 
-      // Thread-level register gemm for k_block
-      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+      gemm(
+          mma,
+          mma_frag_A(_, _, k_block),
+          mma_frag_B(_, _, k_block),
+          mma_frag_C);
     }
   }
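The schedule of the loop that just closed, stripped of the CuTe tensors: stage smem_read feeds the MMAs while the copy issued at k_block == 0 refills stage smem_write, and the wait at k_block == RPIPE - 1 guarantees the next stage has landed before the registers start reading it. The index bookkeeping on its own, with PIPE = 3 as configured below and a made-up six k-tiles, prints the overlap explicitly:

#include <cstdio>

// Stripped-down index bookkeeping of the main loop above, for PIPE = 3
// shared-memory stages. Illustrative only: the real kernel also runs the
// inner RPIPE register pipeline and the CuTe copies and MMAs.
int main() {
  constexpr int PIPE = 3;
  const int k_tiles = 6;                  // K / BK, assumed >= PIPE
  int k_tile_next = PIPE - 1;             // prefetch already fetched 0..PIPE-2
  int k_tile_count = k_tiles - (PIPE - 1);
  int smem_read = 0;
  int smem_write = PIPE - 1;

  for (int computing = 0; k_tile_count > -(PIPE - 1); computing++) {
    if (k_tile_count > 0) {
      std::printf("compute tile %d (stage %d), fetch tile %d into stage %d\n",
                  computing, smem_read, k_tile_next, smem_write);
    } else {
      std::printf("compute tile %d (stage %d), nothing left to fetch\n",
                  computing, smem_read);
    }
    k_tile_count--;
    k_tile_next += (k_tile_count > 0);
    smem_write = smem_read;
    smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
  }
}

Note that the loop runs PIPE - 1 extra iterations past the last fetch, draining the stages that were prefetched up front.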
-#endif
-
+  copy(mma_frag_C, mma_shared_C);
+  __syncthreads();
+  copy(copy_c, local_C_src, local_C_dst);
+  // if (threadIdx.x == 0) {
+  //   print("fC: "); print(mma_frag_C); print("\n");
+  //   print("sC: "); print(mma_shared_C); print("\n");
+  //   print("dC: "); print(local_C_dst); print("\n");
   //
-  // Epilogue
-  //
-
-  copy(tCrC, tCgC);
+  //   print(s2r_atom_a); print("\n");
+  // }
 }
 
 } // namespace
@@ -339,103 +319,74 @@ void cutlass_gemm(
 
   if constexpr (std::is_same_v) {
     using namespace cute;
-    // Define shapes (dynamic)
-    auto prob_shape = make_shape(M, N, K);
+    // Tile definitions
+    auto BM = Int<128>{};
+    auto BN = Int<128>{};
+    auto BK = Int<64>{};
+    auto BP = Int<3>{};
+    auto GM = Int<8>{};
 
-    // Define TN strides (mixed)
-    auto dA = make_stride(K, Int<1>{});
-    auto dB = make_stride(K, Int<1>{});
-    auto dC = make_stride(N, Int<1>{});
+    // Thread definitions
+    using TM = Int<2>;
+    using TN = Int<2>;
+    using TK = Int<1>;
+    constexpr int num_threads = TM::value * TN::value * 32;
 
-    // Define CTA tile sizes (static)
-    auto bM = Int<128>{};
-    auto bN = Int<128>{};
-    auto bK = Int<64>{};
-    auto cta_tiler = make_shape(bM, bN, bK);
-    auto bP = Int<3>{};
+    auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
+    auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
+    auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
 
-    // Define the smem layouts (static)
-    // Swizzles for LDSM and 128b k-major loads
-    auto swizzle_atom = composition(
-        Swizzle<3, 3, 3>{},
-        Layout<
-            cute::Shape<_8, cute::Shape<_8, _8>>,
-            cute::Stride<_8, cute::Stride<_1, _64>>>{});
+    constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
 
-    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
-    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
-    auto sC = make_layout(make_shape(bM, bN));
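The tile choices above fix the shared-memory budget, and the arithmetic is what forces the cudaFuncAttributeMaxDynamicSharedMemorySize call in configure_matmul: assuming the swizzled layouts stay compact (cosize equal to the tile size), each operand stages BM x BK x BP bf16 values, i.e. 48 KB, for 96 KB of dynamic shared memory in total, well past the default 48 KB limit, while the C staging in the epilogue needs only 32 KB and reuses the start of the same allocation:

#include <cstdio>

int main() {
  constexpr int BM = 128, BN = 128, BK = 64, BP = 3;
  constexpr int itemsize = 2; // sizeof(bf16)

  constexpr int a_bytes = BM * BK * BP * itemsize; // 49152
  constexpr int b_bytes = BN * BK * BP * itemsize; // 49152
  constexpr int c_bytes = BM * BN * itemsize;      // 32768, reuses the A/B space

  std::printf("A stages: %d KB, B stages: %d KB, total: %d KB\n",
              a_bytes >> 10, b_bytes >> 10, (a_bytes + b_bytes) >> 10);
  std::printf("C staging needs %d KB and fits in the same allocation\n",
              c_bytes >> 10);
}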
 
-    // Define the thread layouts (static)
+    auto async_copy_op =
+        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bf16>{};
+    auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BM, BK));
+    auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BN, BK));
 
-    TiledCopy copyA = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
-    TiledCopy copyB = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major
+    auto sync_copy_op = Copy_Atom<UniversalCopy<cute::uint128_t>, bf16>{};
+    auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
+        sync_copy_op, make_shape(BM, BN));
 
-    TiledMMA mmaC = make_tiled_mma(
-        SM80_16x8x16_F32BF16BF16F32_TN{},
-        Layout<cute::Shape<_2, _2, _1>>{}, // 2x2x1 MMA Atoms
-        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
-
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
-
-    int smem_size = int(sizeof(SharedStorage<
-        cute::bfloat16_t,
-        cute::bfloat16_t,
-        decltype(sA),
-        decltype(sB)>));
-    dim3 dimBlock(size(mmaC));
-    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
-
-    auto kernel = gemm_device<
-        decltype(prob_shape),
-        decltype(cta_tiler),
-        cute::bfloat16_t,
-        decltype(dA),
-        decltype(sA),
-        decltype(copyA),
-        decltype(s2r_atom_A),
-        cute::bfloat16_t,
-        decltype(dB),
-        decltype(sB),
-        decltype(copyB),
-        decltype(s2r_atom_B),
-        cute::bfloat16_t,
-        decltype(dC),
-        decltype(sC),
-        decltype(mmaC)>;
+    auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
+    auto tiled_mma = make_tiled_mma(
+        mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
+    auto kernel = matmul_kernel<
+        bf16,
+        decltype(SA),
+        decltype(SB),
+        decltype(SC),
+        decltype(tiled_copy_a),
+        decltype(tiled_copy_b),
+        decltype(tiled_copy_c),
+        decltype(tiled_mma),
+        GM.value>;
 
     configure_matmul(kernel, smem_size);
+
+    dim3 block(size(tiled_mma));
+    dim3 grid(
+        size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
+
     enc.add_kernel_node(
         kernel,
-        dimGrid,
-        dimBlock,
+        grid,
+        block,
         smem_size,
-        prob_shape,
-        cta_tiler,
-        a.data(),
-        dA,
-        sA,
-        copyA,
-        s2r_atom_A,
-        b.data(),
-        dB,
-        sB,
-        copyB,
-        s2r_atom_B,
-        out.data(),
-        dC,
-        sC,
-        mmaC);
+        a.data(),
+        b.data(),
+        out.data(),
+        SA,
+        SB,
+        SC,
+        tiled_copy_a,
+        tiled_copy_b,
+        tiled_copy_c,
+        tiled_mma,
+        M,
+        N,
+        K);
   } else {
     throw std::runtime_error("Only bfloat16 supported");
   }
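To make the launch configuration concrete, here is the arithmetic for example sizes M = K = 4096 and N = 3968 (the problem sizes are assumptions, everything else comes from the constants above): 2 x 2 x 32 = 128 threads per block, a 32 x 31 grid of output tiles, and with GM = 8 a launch grid of (32 * 8, ceil(31 / 8)) = (256, 4) blocks, where the surplus blocks exit immediately through the kernel's OOB check:

#include <cstdio>

constexpr int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  // BM, BN, GM, TM, TN come from the patch; M and N are example sizes.
  constexpr int BM = 128, BN = 128, GM = 8;
  constexpr int TM = 2, TN = 2;
  constexpr int M = 4096, N = 3968;

  int block = TM * TN * 32;            // threads per block
  int blocks_m = ceil_div(M, BM);      // 32
  int blocks_n = ceil_div(N, BN);      // 31
  int grid_x = blocks_m * GM;          // 256
  int grid_y = ceil_div(blocks_n, GM); // 4

  std::printf("block=%d threads, tiles=%dx%d, grid=(%d, %d)\n",
              block, blocks_m, blocks_n, grid_x, grid_y);
  std::printf("launched CTAs=%d, useful CTAs=%d (extras exit on the OOB check)\n",
              grid_x * grid_y, blocks_m * blocks_n);
}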