From e1303f616012562e9680a38ac8c159b7dc00d056 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 21 Aug 2025 01:29:43 -0700
Subject: [PATCH] Reset cutlass gemm to working state again

---
 mlx/backend/cuda/gemms/cutlass_gemm.cu | 671 +++++++++++++------------
 1 file changed, 363 insertions(+), 308 deletions(-)

diff --git a/mlx/backend/cuda/gemms/cutlass_gemm.cu b/mlx/backend/cuda/gemms/cutlass_gemm.cu
index 0ab60a689..c58d7e62e 100644
--- a/mlx/backend/cuda/gemms/cutlass_gemm.cu
+++ b/mlx/backend/cuda/gemms/cutlass_gemm.cu
@@ -10,219 +10,315 @@ namespace mlx::core::cu {
 
 namespace {
 
+template <typename Kernel>
+void configure_matmul(Kernel kernel, int smem_size) {
+  static bool initialized = false;
+  if (!initialized) {
+    initialized = true;
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  }
+}
+
+template <class TA, class TB, class ALayout, class BLayout>
+struct SharedStorage {
+  cute::ArrayEngine<TA, cute::cosize_v<ALayout>> A;
+  cute::ArrayEngine<TB, cute::cosize_v<BLayout>> B;
+};
+
 template <
-    class T,
+    class ProblemShape,
     class CtaTiler,
-    class SmemLayoutA,
-    class SmemLayoutB,
-    class TiledMMA,
-    class G2STiledCopyA,
-    class G2STiledCopyB,
-    class S2RCopyAtomA,
-    class S2RCopyAtomB>
-__global__ void cute_gemm_v02(
-    unsigned int M,
-    unsigned int N,
-    unsigned int K,
-    const T* A,
-    size_t lda,
-    const T* B,
-    size_t ldb,
-    T* C,
-    size_t ldc) {
+    class TA,
+    class AStride,
+    class ASmemLayout,
+    class TiledCopyA,
+    class S2RAtomA,
+    class TB,
+    class BStride,
+    class BSmemLayout,
+    class TiledCopyB,
+    class S2RAtomB,
+    class TC,
+    class CStride,
+    class CSmemLayout,
+    class TiledMma>
+__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
+    ProblemShape shape_MNK,
+    CtaTiler cta_tiler,
+    TA const* A,
+    AStride dA,
+    ASmemLayout sA_layout,
+    TiledCopyA copy_a,
+    S2RAtomA s2r_atom_a,
+    TB const* B,
+    BStride dB,
+    BSmemLayout sB_layout,
+    TiledCopyB copy_b,
+    S2RAtomB s2r_atom_b,
+    TC* C,
+    CStride dC,
+    CSmemLayout,
+    TiledMma mma) {
   using namespace cute;
 
-  // global full tensor
-  // shape
-  auto shape_MNK = make_shape(M, N, K);
+  // Preconditions
+  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
+  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
 
-  // stride
-  // cublas covenience for TN gemm
-  // all matrices are in column major
-  // A (m,k) --> transpose --> A(k, m) --> cute layout: A (m, k) : (k, 1) -->
-  // lda = k B (k,n) --> cute layout: B (n, k) : (k, 1) --> ldb = k C (m,n) -->
-  // cute layout: C (m, n) : (1, m) --> ldc = m
-  auto dA = make_stride(lda, _1{});
-  auto dB = make_stride(ldb, _1{});
-  auto dC = make_stride(_1{}, ldc);
+  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
+  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
+
+  static_assert(is_static<ASmemLayout>::value);
+  static_assert(is_static<BSmemLayout>::value);
+  static_assert(is_static<CSmemLayout>::value);
+
+  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
+
+  //
+  // Full and Tiled Tensors
+  //
+
+  // Represent the full tensors
   Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // M x K
+      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
   Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // N x K
+      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
   Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // M x N
+      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
 
-  // global tile tensor
-  auto cta_tiler = CtaTiler{};
-  auto cta_coord = make_coord(blockIdx.y, blockIdx.x, _);
+  // Get the appropriate blocks for this thread block
+  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
   Tensor gA = local_tile(
-      mA,
-      cta_tiler,
-      cta_coord,
-      Step<_1, X, _1>{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_TILES_K
+      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
   Tensor gB = local_tile(
-      mB,
-      cta_tiler,
-      cta_coord,
-      Step<X, _1, _1>{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_TILES_K
-  Tensor gC = local_tile(
-      mC,
-      cta_tiler,
-      cta_coord,
-      Step<_1, _1, X>{}); // BLOCK_SIZE_M x BLOCK_SIZE_N
-
-  // shared memory
-  //   __shared__ T Asmem[cosize_v<SmemLayoutA>];
-  //   __shared__ T Bsmem[cosize_v<SmemLayoutB>];
-
-  extern __shared__ T smem[];
-  T* Asmem = smem;
-  T* Bsmem = smem + cosize_v<SmemLayoutA>;
+      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
+  Tensor gC =
+      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
 
+  // Shared memory buffers
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
   Tensor sA = make_tensor(
-      make_smem_ptr(Asmem),
-      SmemLayoutA{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_STAGES
+      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
   Tensor sB = make_tensor(
-      make_smem_ptr(Bsmem),
-      SmemLayoutB{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_STAGES
+      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
 
-  // MMA
-  // use TiledMMA --> get one thread work
-  auto tiled_mma = TiledMMA{};
-  ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
-  auto tCgC = thr_mma.partition_C(gC); // MMA x MMA_M x MMA_N
+  //
+  // Partition the copying of A and B tiles across the threads
+  //
 
-  // thread private memory for MMA
-  auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // MMA x MMA_M x MMA_K
-  auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // MMA x MMA_N x MMA_K
+  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
+  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
 
-  // thread private memory for accumulator for MMA
-  Tensor tCrC = thr_mma.partition_fragment_C(gC); // MMA x MMA_M x MMA_N
+  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
+  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
 
-  clear(tCrC);
+  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
+  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
+  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
+  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
 
-  // initiate copy from global memory to shared memory
-  // use G2S TiledCopy --> get one thread copy work
-  auto g2s_tiled_copy_a = G2STiledCopyA{};
-  auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tAgA =
-      g2s_thr_copy_a.partition_S(gA); // CPY x CPY_M x CPY_K x NUM_TILES_K
-  auto tAsA =
-      g2s_thr_copy_a.partition_D(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
+  //
+  // PREFETCH
+  //
 
-  auto g2s_tiled_copy_b = G2STiledCopyB{};
-  auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tBgB =
-      g2s_thr_copy_b.partition_S(gB); // CPY x CPY_N x CPY_K x NUM_TILES_K
-  auto tBsB =
-      g2s_thr_copy_b.partition_D(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
+  auto K_PIPE_MAX = size<3>(tAsA);
 
-  // initiate copy from shared memory to thread private memory
-  // use S2R TiledCopy --> get one thread copy work
-  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
-  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tCsA =
-      s2r_thr_copy_a.partition_S(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
-  auto tCrA_copy_view = s2r_thr_copy_a.retile_D(tCrA); // CPY x CPY_M x CPY_K
-
-  auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
-  auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tCsB =
-      s2r_thr_copy_b.partition_S(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
-  auto tCrB_copy_view = s2r_thr_copy_b.retile_D(tCrB); // CPY x CPY_N x CPY_K
-
-  // pipeline
-  // counter
-  int itile_to_read = 0; // read index of the next tile
-  // 2 pointers of the buffer
-  int ismem_write = 0;
-  int ismem_read = 0;
-
-  // NUM_STAGES = 5 --> Prefetech NUM_STAGES-1 = 4 tiles first
-  auto NUM_STAGES = size<3>(tAsA);
+  // Total count of tiles
+  int k_tile_count = size<3>(tAgA);
+  // Current tile index in gmem to read from
+  int k_tile_next = 0;
 
+  // Start async loads for all pipes but the last
   CUTE_UNROLL
-  for (int stage = 0; stage < NUM_STAGES - 1; ++stage) {
-    // prefetch
-    // issue copy
-    copy(g2s_tiled_copy_a, tAgA(_, _, _, itile_to_read), tAsA(_, _, _, stage));
-    copy(g2s_tiled_copy_b, tBgB(_, _, _, itile_to_read), tBsB(_, _, _, stage));
-
-    // commit
+  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
+    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
+    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
     cp_async_fence();
-
-    ismem_write++;
-    itile_to_read++;
-  }
-
-  // wait for first tile to be prefetched: G^0 -> S^0
-  cp_async_wait<NUM_STAGES - 2>();
-  __syncthreads();
-
-  // Having S^0, copy from S^0,0 to R^0
-  int k = 0;
-  copy(s2r_tiled_copy_a, tCsA(_, _, k, ismem_read), tCrA_copy_view(_, _, k));
-  copy(s2r_tiled_copy_b, tCsB(_, _, k, ismem_read), tCrB_copy_view(_, _, k));
-
-  // loop over tiles
-  auto NUM_TILES_K = size<3>(tAgA);
-
-  CUTE_NO_UNROLL
-  for (int tile = 0; tile < NUM_TILES_K; ++tile) {
-    auto MMA_K = size<2>(tCrA);
-    // loop over MMAs in direction of K
-
-    CUTE_UNROLL
-    for (int k = 0; k < MMA_K; ++k) {
-      int k_next = (k + 1) % MMA_K;
-
-      // if this is the second last MMA, wait the next tile to be fetched
-      if (k == MMA_K - 1) {
-        cp_async_wait<NUM_STAGES - 2>();
-        __syncthreads();
-
-        ismem_read = (ismem_read + 1) % NUM_STAGES;
-      }
-
-      // load data for the next MMA, from S^tile to registers
-      copy(
-          s2r_tiled_copy_a,
-          tCsA(_, _, k_next, ismem_read),
-          tCrA_copy_view(_, _, k_next));
-      copy(
-          s2r_tiled_copy_b,
-          tCsB(_, _, k_next, ismem_read),
-          tCrB_copy_view(_, _, k_next));
-
-      if (k == 0) {
-        // prefetch the next tile
-        // issue copy
-        if (itile_to_read < NUM_TILES_K) {
-          copy(
-              g2s_tiled_copy_a,
-              tAgA(_, _, _, itile_to_read),
-              tAsA(_, _, _, ismem_write));
-          copy(
-              g2s_tiled_copy_b,
-              tBgB(_, _, _, itile_to_read),
-              tBsB(_, _, _, ismem_write));
-
-          itile_to_read++;
-          ismem_write = (ismem_write + 1) % NUM_STAGES;
-        }
-        // commit
-        cp_async_fence();
-      }
-
-      // mma
-      gemm(tiled_mma, tCrC, tCrA(_, _, k), tCrB(_, _, k), tCrC);
+    --k_tile_count;
+    if (k_tile_count > 0) {
+      ++k_tile_next;
     }
   }
 
+  //
+  // Define A/B partitioning and C accumulators
+  //
+
+  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
+  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
+
+  // Allocate registers for pipelining
+  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
+  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
+  // Allocate the accumulators -- same size as the projected data
+  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
+
+  CUTE_STATIC_ASSERT_V(
+      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
+  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
+  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
+
+  // Clear the accumulators
+  clear(tCrC);
+
+  //
+  // Copy Atom retiling
+  //
+
+  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
+  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
+
+  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
+  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
+
+#if 0
+  if(thread0()) {
+    print(" mA : "); print( mA); print("\n");
+    print(" gA : "); print( gA); print("\n");
+    print(" sA : "); print( sA); print("\n");
+    print("tAgA : "); print(tAgA); print("\n");
+    print("tAsA : "); print(tAsA); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print(" mB : "); print( mB); print("\n");
+    print(" gB : "); print( gB); print("\n");
+    print(" sB : "); print( sB); print("\n");
+    print("tBgB : "); print(tBgB); print("\n");
+    print("tBsB : "); print(tBsB); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print(" mC : "); print( mC); print("\n");
+    print(" gC : "); print( gC); print("\n");
+    print("tCgC : "); print(tCgC); print("\n");
+    print("tCrA : "); print(tCrA); print("\n");
+    print("tCrB : "); print(tCrB); print("\n");
+    print("tCrC : "); print(tCrC); print("\n");
+
+    print("tXsA : "); print(tXsA); print("\n");
+    print("tXrA : "); print(tXrA); print("\n");
+    print("tXsB : "); print(tXsB); print("\n");
+    print("tXrB : "); print(tXrB); print("\n");
+  }
+#endif
+
+#if 1
+
+  // Current pipe index in smem to read from
+  int smem_pipe_read = 0;
+  // Current pipe index in smem to write to
+  int smem_pipe_write = K_PIPE_MAX - 1;
+
+  // Pipe slice
+  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
+  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+  // Size of the register pipeline
+  auto K_BLOCK_MAX = size<2>(tCrA);
+  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
+
+  // PREFETCH register pipeline
+  if (K_BLOCK_MAX > 1) {
+    // Wait until our first prefetched tile is loaded in
+    cp_async_wait<K_PIPE_MAX - 2>();
+    __syncthreads();
+
+    // Prefetch the first rmem from the first k-tile
+    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
+    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+  }
+
+  //
+  // PIPELINED MAIN LOOP
+  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
+  //           cp.async instructions
+  //           and explicit pipelines in shared memory.
+  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
+  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
+  //   Data is computed on registers(k_block).
+  //
+  //   This allows all copies and compute to overlap:
+  //     Copy from gmem->smem can overlap with copies from smem->rmem and
+  //     compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
+  //
+
+  CUTE_NO_UNROLL
+  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
+      if (k_block == K_BLOCK_MAX - 1) {
+        // Slice the smem_pipe_read smem
+        tXsA_p = tXsA(_, _, _, smem_pipe_read);
+        tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+        // Commit the smem for smem_pipe_read
+        cp_async_wait<K_PIPE_MAX - 2>();
+        __syncthreads();
+      }
+
+      // Load A, B shmem->regs for k_block+1
+      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
+      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
+      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
+      // Copy gmem to smem before computing gemm on each k-pipe
+      if (k_block == 0) {
+        copy(
+            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+        copy(
+            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+        cp_async_fence();
+
+        // Advance the gmem tile
+        --k_tile_count;
+        if (k_tile_count > 0) {
+          ++k_tile_next;
+        }
+
+        // Advance the smem pipe
+        smem_pipe_write = smem_pipe_read;
+        smem_pipe_read =
+            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+      }
+      // Thread-level register gemm for k_block
+      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+    }
+  }
+
+#endif
+
+  //
+  // Epilogue
+  //
+
   copy(tCrC, tCgC);
-
-  // constexpr T alpha{1.0f};
-  // constexpr T beta{0.0f};
-  // axpby(alpha, tCrC, beta, tCgC);
 }
 
 } // namespace
@@ -243,144 +339,103 @@ void cutlass_gemm(
   if constexpr (std::is_same_v<DataType, bfloat16_t>) {
     using namespace cute;
 
-    // block shape and cta tiler
-    // additional dim: NUM_STAGES --> This is for later pipelining the k-slice
-    // GEMM
-    auto BLOCK_SIZE_M = _128{};
-    auto BLOCK_SIZE_N = _128{};
-    auto BLOCK_SIZE_K = _32{};
-    auto NUM_STAGES = _5{};
-    using CtaTiler =
-        decltype(make_shape(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K));
+    // Define shapes (dynamic)
+    auto prob_shape = make_shape(M, N, K);
 
-    // smem layout
-    // Swizzle parameters need to be chosen right
-    static constexpr int kShmLoadSwizzleM = 3;
-    static constexpr int kShmLoadSwizzleS = 3;
-    static constexpr int kShmLoadSwizzleB = 3;
+    // Define TN strides (mixed)
+    auto dA = make_stride(K, Int<1>{});
+    auto dB = make_stride(K, Int<1>{});
+    auto dC = make_stride(N, Int<1>{});
 
-    using SmemLayoutAtom = decltype(composition(
-        Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
-        make_layout(make_shape(_8{}, BLOCK_SIZE_K), LayoutRight{})));
+    // Define CTA tile sizes (static)
+    auto bM = Int<128>{};
+    auto bN = Int<128>{};
+    auto bK = Int<64>{};
+    auto cta_tiler = make_shape(bM, bN, bK);
+    auto bP = Int<3>{};
 
-    // what does this do?
-    // with BLOCK_SIZE_K = 32, shape: (8, 32)
-    // 2^M = 8, 1 new unit = 8 units --> 1 row contains 32/8 = 4 new units
-    // 2^S = 8, it will treat 1 row = 8 new units --> do 8-unit swizzle
-    // 2^B = 8, it will reset the swizzle pattern after 8 rows
-    // print_layout(SmemLayoutAtom{});
+    // Define the smem layouts (static)
+    // Swizzles for LDSM and 128b k-major loads
+    auto swizzle_atom = composition(
+        Swizzle<3, 3, 3>{},
+        Layout<
+            cute::Shape<_8, cute::Shape<_8, _8>>,
+            cute::Stride<_8, cute::Stride<_1, _64>>>{});
 
-    // tile_to_shape extends the layout in LayoutLeft order
-    using SmemLayoutA = decltype(tile_to_shape(
-        SmemLayoutAtom{},
-        make_shape(BLOCK_SIZE_M, BLOCK_SIZE_K, NUM_STAGES)));
+    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
+    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
+    auto sC = make_layout(make_shape(bM, bN));
 
-    using SmemLayoutB = decltype(tile_to_shape(
-        SmemLayoutAtom{},
-        make_shape(BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES)));
+    // Define the thread layouts (static)
 
-    // TiledMMA
-    using mma_op = SM80_16x8x16_F32BF16BF16F32_TN;
-    using mma_traits = MMA_Traits<mma_op>;
-    using mma_atom = MMA_Atom<mma_traits>;
+    TiledCopy copyA = make_tiled_copy(
+        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
+        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                              // 16x8 k-major
+        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
+    TiledCopy copyB = make_tiled_copy(
+        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
+        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                              // 16x8 k-major
+        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
 
-    static constexpr int kMmaEURepeatM = 2;
-    static constexpr int kMmaEURepeatN = 2;
-    static constexpr int kMmaEURepeatK = 1;
-    // 32 x 2 x 2 = 128 threads
+    TiledMMA mmaC = make_tiled_mma(
+        SM80_16x8x16_F32BF16BF16F32_TN{},
+        Layout<cute::Shape<_2, _2, _1>>{}, // 2x2x1 MMA Atoms
+        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
 
-    using mma_atom_shape = mma_traits::Shape_MNK;
-    static constexpr int MmaVM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
-    static constexpr int MmaVN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
-    static constexpr int MmaVK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
+    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
+    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
 
-    // this is for problem shape (16x2) x (8x2x2) x (16x1) = 32x32x16
-    using MMA_EU_RepeatT = decltype(make_layout(make_shape(
-        Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
-    using MMA_V_T = Tile<Int<MmaVM>, Int<MmaVN>, Int<MmaVK>>;
+    int smem_size = int(sizeof(SharedStorage<
+        cute::bfloat16_t,
+        cute::bfloat16_t,
+        decltype(sA),
+        decltype(sB)>));
+    dim3 dimBlock(size(mmaC));
+    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
 
-    using TiledMMA =
-        decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_V_T{}));
+    auto kernel = gemm_device<
+        decltype(prob_shape),
+        decltype(cta_tiler),
+        cute::bfloat16_t,
+        decltype(dA),
+        decltype(sA),
+        decltype(copyA),
+        decltype(s2r_atom_A),
+        cute::bfloat16_t,
+        decltype(dB),
+        decltype(sB),
+        decltype(copyB),
+        decltype(s2r_atom_B),
+        cute::bfloat16_t,
+        decltype(dC),
+        decltype(sC),
+        decltype(mmaC)>;
 
-    // TiledCopy from global memory to shared memory
-    // uint128_t is 16 bytes = 4 floats = 8 halfs
-    static constexpr int NUM_VECTOR_UNITS =
-        sizeof(cute::uint128_t) / sizeof(DataType);
-
-    using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
-    using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
-    using g2s_copy_atom = Copy_Atom<g2s_copy_traits, DataType>;
-
-    // one block contains 128 threads
-    // --> find the compatible thread layout
-    using G2S_Copy_Thread_Layout = decltype(make_layout(
-        make_shape(_32{}, _4{}), // 32x4 = 128 threads
-        LayoutRight{} // A is in row-major
-        ));
-
-    using G2S_Copy_Value_Layout =
-        decltype(make_layout(make_shape(_1{}, Int<NUM_VECTOR_UNITS>{})));
-
-    // This is for copy shape 32x4 of uint128_t
-    using G2STiledCopyA = decltype(make_tiled_copy(
-        g2s_copy_atom{}, G2S_Copy_Thread_Layout{}, G2S_Copy_Value_Layout{}));
-
-    // Both A and B are in row-major so use the same TiledCopy for B
-    using G2STiledCopyB = G2STiledCopyA;
-
-    // CopyAtom from shared memory to registers
-    // Why no need to do tiling atom here? Because we will do it later with
-    // the information from TiledMMA
-    using s2r_copy_op = SM75_U32x4_LDSM_N;
-    using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
-    using s2r_copy_atom = Copy_Atom<s2r_copy_traits, DataType>;
-
-    using S2RCopyAtomA = s2r_copy_atom;
-    using S2RCopyAtomB = s2r_copy_atom;
-
-    // print_latex(
-    //     make_tiled_copy(S2RCopyAtomA{}, make_layout(Shape<_32>{}),
-    //     make_layout(Shape<_1, _8>{}))
-    // );
-
-    // grid, block
-    dim3 block{size(TiledMMA{}), 1U, 1U};
-    dim3 grid{
-        size(ceil_div(static_cast<unsigned int>(N), BLOCK_SIZE_N)),
-        size(ceil_div(static_cast<unsigned int>(M), BLOCK_SIZE_M)),
-        1U};
-
-    static constexpr int smem_size =
-        (cosize_v<SmemLayoutA> + cosize_v<SmemLayoutB>)*sizeof(DataType);
-
-    auto kernel = cute_gemm_v02<
-        DataType,
-        CtaTiler,
-        SmemLayoutA,
-        SmemLayoutB,
-        TiledMMA,
-        G2STiledCopyA,
-        G2STiledCopyB,
-        S2RCopyAtomA,
-        S2RCopyAtomB>;
-
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    configure_matmul(kernel, smem_size);
 
     enc.add_kernel_node(
         kernel,
-        grid,
-        block,
+        dimGrid,
+        dimBlock,
         smem_size,
-        M,
-        N,
-        K,
-        a.data<DataType>(),
-        K,
-        b.data<DataType>(),
-        K,
-        out.data<DataType>(),
-        N);
+        prob_shape,
+        cta_tiler,
+        a.data<cute::bfloat16_t>(),
+        dA,
+        sA,
+        copyA,
+        s2r_atom_A,
+        b.data<cute::bfloat16_t>(),
+        dB,
+        sB,
+        copyB,
+        s2r_atom_B,
+        out.data<cute::bfloat16_t>(),
+        dC,
+        sC,
+        mmaC);
   } else {
     throw std::runtime_error("Only bfloat16 supported");
  }