Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

Commit: Reset cutlass gemm to working state again
@@ -10,219 +10,315 @@ namespace mlx::core::cu {
 
 namespace {
 
+template <typename Kernel>
+void configure_matmul(Kernel kernel, int smem_size) {
+  static bool initialized = false;
+  if (!initialized) {
+    initialized = true;
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  }
+}
+
+template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
+struct SharedStorage {
+  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
+  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
+};
+
 template <
-    class T,
-    class CtaTiler,
-    class SmemLayoutA,
-    class SmemLayoutB,
-    class TiledMMA,
-    class G2STiledCopyA,
-    class G2STiledCopyB,
-    class S2RCopyAtomA,
-    class S2RCopyAtomB>
-__global__ void cute_gemm_v02(
-    unsigned int M,
-    unsigned int N,
-    unsigned int K,
-    const T* A,
-    size_t lda,
-    const T* B,
-    size_t ldb,
-    T* C,
-    size_t ldc) {
+    class ProblemShape,
+    class CtaTiler,
+    class TA,
+    class AStride,
+    class ASmemLayout,
+    class TiledCopyA,
+    class S2RAtomA,
+    class TB,
+    class BStride,
+    class BSmemLayout,
+    class TiledCopyB,
+    class S2RAtomB,
+    class TC,
+    class CStride,
+    class CSmemLayout,
+    class TiledMma>
+__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
+    ProblemShape shape_MNK,
+    CtaTiler cta_tiler,
+    TA const* A,
+    AStride dA,
+    ASmemLayout sA_layout,
+    TiledCopyA copy_a,
+    S2RAtomA s2r_atom_a,
+    TB const* B,
+    BStride dB,
+    BSmemLayout sB_layout,
+    TiledCopyB copy_b,
+    S2RAtomB s2r_atom_b,
+    TC* C,
+    CStride dC,
+    CSmemLayout,
+    TiledMma mma) {
   using namespace cute;
 
-  // global full tensor
-  // shape
-  auto shape_MNK = make_shape(M, N, K);
-
-  // stride
-  // cublas covenience for TN gemm
-  // all matrices are in column major
-  // A (m,k) --> transpose --> A(k, m) --> cute layout: A (m, k) : (k, 1) -->
-  // lda = k B (k,n) --> cute layout: B (n, k) : (k, 1) --> ldb = k C (m,n) -->
-  // cute layout: C (m, n) : (1, m) --> ldc = m
-  auto dA = make_stride(lda, _1{});
-  auto dB = make_stride(ldb, _1{});
-  auto dC = make_stride(_1{}, ldc);
+  // Preconditions
+  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
+  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
+
+  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
+  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
+
+  static_assert(is_static<ASmemLayout>::value);
+  static_assert(is_static<BSmemLayout>::value);
+  static_assert(is_static<CSmemLayout>::value);
+
+  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
+  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
+  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
+  CUTE_STATIC_ASSERT_V(
+      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
+
+  //
+  // Full and Tiled Tensors
+  //
+
+  // Represent the full tensors
   Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // M x K
+      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
   Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // N x K
+      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
   Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // M x N
+      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
 
-  // global tile tensor
-  auto cta_tiler = CtaTiler{};
-  auto cta_coord = make_coord(blockIdx.y, blockIdx.x, _);
-  Tensor gA = local_tile(
-      mA,
-      cta_tiler,
-      cta_coord,
-      Step<_1, X, _1>{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_TILES_K
-  Tensor gB = local_tile(
-      mB,
-      cta_tiler,
-      cta_coord,
-      Step<X, _1, _1>{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_TILES_K
-  Tensor gC = local_tile(
-      mC,
-      cta_tiler,
-      cta_coord,
-      Step<_1, _1, X>{}); // BLOCK_SIZE_M x BLOCK_SIZE_N
-
-  // shared memory
-  // __shared__ T Asmem[cosize_v<SmemLayoutA>];
-  // __shared__ T Bsmem[cosize_v<SmemLayoutB>];
-
-  extern __shared__ T smem[];
-  T* Asmem = smem;
-  T* Bsmem = smem + cosize_v<SmemLayoutA>;
-
-  Tensor sA = make_tensor(
-      make_smem_ptr(Asmem),
-      SmemLayoutA{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_STAGES
-  Tensor sB = make_tensor(
-      make_smem_ptr(Bsmem),
-      SmemLayoutB{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_STAGES
+  // Get the appropriate blocks for this thread block
+  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
+
+  Tensor gA = local_tile(
+      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
+  Tensor gB = local_tile(
+      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
+  Tensor gC =
+      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+
+  // Shared memory buffers
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
+  Tensor sA = make_tensor(
+      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
+  Tensor sB = make_tensor(
+      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
 
-  // MMA
-  // use TiledMMA --> get one thread work
-  auto tiled_mma = TiledMMA{};
-  ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
-  auto tCgC = thr_mma.partition_C(gC); // MMA x MMA_M x MMA_N
-
-  // thread private memory for MMA
-  auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // MMA x MMA_M x MMA_K
-  auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // MMA x MMA_N x MMA_K
-
-  // thread private memory for accumulator for MMA
-  Tensor tCrC = thr_mma.partition_fragment_C(gC); // MMA x MMA_M x MMA_N
-
-  clear(tCrC);
-
-  // initiate copy from global memory to shared memory
-  // use G2S TiledCopy --> get one thread copy work
-  auto g2s_tiled_copy_a = G2STiledCopyA{};
-  auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tAgA =
-      g2s_thr_copy_a.partition_S(gA); // CPY x CPY_M x CPY_K x NUM_TILES_K
-  auto tAsA =
-      g2s_thr_copy_a.partition_D(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
-
-  auto g2s_tiled_copy_b = G2STiledCopyB{};
-  auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tBgB =
-      g2s_thr_copy_b.partition_S(gB); // CPY x CPY_N x CPY_K x NUM_TILES_K
-  auto tBsB =
-      g2s_thr_copy_b.partition_D(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
-
-  // initiate copy from shared memory to thread private memory
-  // use S2R TiledCopy --> get one thread copy work
-  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
-  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(threadIdx.x);
-  const auto tCsA =
-      s2r_thr_copy_a.partition_S(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
-  auto tCrA_copy_view = s2r_thr_copy_a.retile_D(tCrA); // CPY x CPY_M x CPY_K
-
-  auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
-  auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(threadIdx.x);
-  const auto tCsB =
-      s2r_thr_copy_b.partition_S(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
-  auto tCrB_copy_view = s2r_thr_copy_b.retile_D(tCrB); // CPY x CPY_N x CPY_K
-
-  // pipeline
-  // counter
-  int itile_to_read = 0; // read index of the next tile
-  // 2 pointers of the buffer
-  int ismem_write = 0;
-  int ismem_read = 0;
-
-  // NUM_STAGES = 5 --> Prefetech NUM_STAGES-1 = 4 tiles first
-  auto NUM_STAGES = size<3>(tAsA);
-
-  CUTE_UNROLL
-  for (int stage = 0; stage < NUM_STAGES - 1; ++stage) {
-    // prefetch
-    // issue copy
-    copy(g2s_tiled_copy_a, tAgA(_, _, _, itile_to_read), tAsA(_, _, _, stage));
-    copy(g2s_tiled_copy_b, tBgB(_, _, _, itile_to_read), tBsB(_, _, _, stage));
-
-    // commit
-    cp_async_fence();
-
-    ismem_write++;
-    itile_to_read++;
-  }
-
-  // wait for first tile to be prefetched: G^0 -> S^0
-  cp_async_wait<NUM_STAGES - 2>();
-  __syncthreads();
-
-  // Having S^0, copy from S^0,0 to R^0
-  int k = 0;
-  copy(s2r_tiled_copy_a, tCsA(_, _, k, ismem_read), tCrA_copy_view(_, _, k));
-  copy(s2r_tiled_copy_b, tCsB(_, _, k, ismem_read), tCrB_copy_view(_, _, k));
-
-  // loop over tiles
-  auto NUM_TILES_K = size<3>(tAgA);
-
-  CUTE_NO_UNROLL
-  for (int tile = 0; tile < NUM_TILES_K; ++tile) {
-    auto MMA_K = size<2>(tCrA);
-    // loop over MMAs in direction of K
-
-    CUTE_UNROLL
-    for (int k = 0; k < MMA_K; ++k) {
-      int k_next = (k + 1) % MMA_K;
-
-      // if this is the second last MMA, wait the next tile to be fetched
-      if (k == MMA_K - 1) {
-        cp_async_wait<NUM_STAGES - 2>();
-        __syncthreads();
-
-        ismem_read = (ismem_read + 1) % NUM_STAGES;
-      }
-
-      // load data for the next MMA, from S^tile to registers
-      copy(
-          s2r_tiled_copy_a,
-          tCsA(_, _, k_next, ismem_read),
-          tCrA_copy_view(_, _, k_next));
-      copy(
-          s2r_tiled_copy_b,
-          tCsB(_, _, k_next, ismem_read),
-          tCrB_copy_view(_, _, k_next));
-
-      if (k == 0) {
-        // prefetch the next tile
-        // issue copy
-        if (itile_to_read < NUM_TILES_K) {
-          copy(
-              g2s_tiled_copy_a,
-              tAgA(_, _, _, itile_to_read),
-              tAsA(_, _, _, ismem_write));
-          copy(
-              g2s_tiled_copy_b,
-              tBgB(_, _, _, itile_to_read),
-              tBsB(_, _, _, ismem_write));
-
-          itile_to_read++;
-          ismem_write = (ismem_write + 1) % NUM_STAGES;
-        }
-        // commit
-        cp_async_fence();
-      }
-
-      // mma
-      gemm(tiled_mma, tCrC, tCrA(_, _, k), tCrB(_, _, k), tCrC);
-    }
-  }
+  //
+  // Partition the copying of A and B tiles across the threads
+  //
+
+  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
+  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
+
+  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
+  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
+
+  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
+  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
+  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
+  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
+
+  //
+  // PREFETCH
+  //
+
+  auto K_PIPE_MAX = size<3>(tAsA);
+
+  // Total count of tiles
+  int k_tile_count = size<3>(tAgA);
+  // Current tile index in gmem to read from
+  int k_tile_next = 0;
+
+  // Start async loads for all pipes but the last
+  CUTE_UNROLL
+  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
+    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
+    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
+    cp_async_fence();
+    --k_tile_count;
+    if (k_tile_count > 0) {
+      ++k_tile_next;
+    }
+  }
+
+  //
+  // Define A/B partitioning and C accumulators
+  //
+
+  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
+  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
+
+  // Allocate registers for pipelining
+  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
+  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
+  // Allocate the accumulators -- same size as the projected data
+  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
+
+  CUTE_STATIC_ASSERT_V(
+      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
+  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
+  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
+
+  // Clear the accumulators
+  clear(tCrC);
+
+  //
+  // Copy Atom retiling
+  //
+
+  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
+  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
+
+  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
+  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
+
+#if 0
+  if(thread0()) {
+    print("  mA : "); print(  mA); print("\n");
+    print("  gA : "); print(  gA); print("\n");
+    print("  sA : "); print(  sA); print("\n");
+    print("tAgA : "); print(tAgA); print("\n");
+    print("tAsA : "); print(tAsA); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print("  mB : "); print(  mB); print("\n");
+    print("  gB : "); print(  gB); print("\n");
+    print("  sB : "); print(  sB); print("\n");
+    print("tBgB : "); print(tBgB); print("\n");
+    print("tBsB : "); print(tBsB); print("\n");
+  }
+#endif
+
+#if 0
+  if(thread0()) {
+    print("  mC : "); print(  mC); print("\n");
+    print("  gC : "); print(  gC); print("\n");
+    print("tCgC : "); print(tCgC); print("\n");
+    print("tCrA : "); print(tCrA); print("\n");
+    print("tCrB : "); print(tCrB); print("\n");
+    print("tCrC : "); print(tCrC); print("\n");
+
+    print("tXsA : "); print(tXsA); print("\n");
+    print("tXrA : "); print(tXrA); print("\n");
+    print("tXsB : "); print(tXsB); print("\n");
+    print("tXrB : "); print(tXrB); print("\n");
+  }
+#endif
+
+#if 1
+
+  // Current pipe index in smem to read from
+  int smem_pipe_read = 0;
+  // Current pipe index in smem to write to
+  int smem_pipe_write = K_PIPE_MAX - 1;
+
+  // Pipe slice
+  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
+  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+  // Size of the register pipeline
+  auto K_BLOCK_MAX = size<2>(tCrA);
+  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
+
+  // PREFETCH register pipeline
+  if (K_BLOCK_MAX > 1) {
+    // Wait until our first prefetched tile is loaded in
+    cp_async_wait<K_PIPE_MAX - 2>();
+    __syncthreads();
+
+    // Prefetch the first rmem from the first k-tile
+    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
+    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+  }
+
+  //
+  // PIPELINED MAIN LOOP
+  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
+  //           cp.async instructions and explicit pipelines in shared memory.
+  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
+  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
+  //   Data is computed on registers(b_block).
+  //
+  //   This allows all copies and compute to overlap:
+  //     Copy from gmem->smem can overlap with copies from smem->rmem and
+  //     compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
+  //
+
+  CUTE_NO_UNROLL
+  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
+      if (k_block == K_BLOCK_MAX - 1) {
+        // Slice the smem_pipe_read smem
+        tXsA_p = tXsA(_, _, _, smem_pipe_read);
+        tXsB_p = tXsB(_, _, _, smem_pipe_read);
+
+        // Commit the smem for smem_pipe_read
+        cp_async_wait<K_PIPE_MAX - 2>();
+        __syncthreads();
+      }
+
+      // Load A, B shmem->regs for k_block+1
+      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
+      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
+      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
+      // Copy gmem to smem before computing gemm on each k-pipe
+      if (k_block == 0) {
+        copy(
+            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+        copy(
+            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+        cp_async_fence();
+
+        // Advance the gmem tile
+        --k_tile_count;
+        if (k_tile_count > 0) {
+          ++k_tile_next;
+        }
+
+        // Advance the smem pipe
+        smem_pipe_write = smem_pipe_read;
+        smem_pipe_read =
+            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+      }
+      // Thread-level register gemm for k_block
+      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+    }
+  }
+
+#endif
+
+  //
+  // Epilogue
+  //
 
   copy(tCrC, tCgC);
-  // constexpr T alpha{1.0f};
-  // constexpr T beta{0.0f};
-  // axpby(alpha, tCrC, beta, tCgC);
 }
 
 } // namespace
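Note on the `configure_matmul` helper added in this hunk: CUDA caps dynamic shared memory at 48 KB per block unless the kernel explicitly opts in, so the `cudaFuncAttributeMaxDynamicSharedMemorySize` call is a prerequisite for launching this kernel with its larger smem request, and the carveout attribute additionally hints the driver to favor shared memory over L1. Below is a minimal standalone sketch of the same opt-in pattern; the kernel name and the 96 KB figure are illustrative assumptions, not code from this commit.

// opt_in_smem.cu -- hypothetical sketch of the >48 KB dynamic smem opt-in.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void my_kernel() {
  extern __shared__ char smem[]; // dynamic shared memory, sized at launch
  smem[threadIdx.x] = 0;
}

int main() {
  int smem_size = 96 * 1024; // above the 48 KB default limit
  cudaFuncSetAttribute(
      my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  my_kernel<<<1, 128, smem_size>>>();
  std::printf("launch: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
  return 0;
}

The `static bool initialized` guard in `configure_matmul` simply ensures the attributes are set once per kernel instantiation rather than on every launch.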
@@ -243,144 +339,103 @@ void cutlass_gemm(
     if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
       using namespace cute;
 
-      // block shape and cta tiler
-      // additional dim: NUM_STAGES --> This is for later pipelining the k-slice
-      // GEMM
-      auto BLOCK_SIZE_M = _128{};
-      auto BLOCK_SIZE_N = _128{};
-      auto BLOCK_SIZE_K = _32{};
-      auto NUM_STAGES = _5{};
-      using CtaTiler =
-          decltype(make_shape(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K));
+      // Define shapes (dynamic)
+      auto prob_shape = make_shape(M, N, K);
 
-      // smem layout
-      // Swizzle parameters need to be chosen right
-      static constexpr int kShmLoadSwizzleM = 3;
-      static constexpr int kShmLoadSwizzleS = 3;
-      static constexpr int kShmLoadSwizzleB = 3;
-
-      using SmemLayoutAtom = decltype(composition(
-          Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
-          make_layout(make_shape(_8{}, BLOCK_SIZE_K), LayoutRight{})));
-
-      // what does this do?
-      // with BLOCK_SIZE_K = 32, shape: (8, 32)
-      // 2^M = 8, 1 new unit = 8 units --> 1 row contains 32/8 = 4 new units
-      // 2^S = 8, it will treat 1 row = 8 new units --> do 8-unit swizzle
-      // 2^B = 8, it will reset the swizzle pattern after 8 rows
-      // print_layout(SmemLayoutAtom{});
-
-      // tile_to_shape extends the layout in LayoutLeft order
-      using SmemLayoutA = decltype(tile_to_shape(
-          SmemLayoutAtom{},
-          make_shape(BLOCK_SIZE_M, BLOCK_SIZE_K, NUM_STAGES)));
-
-      using SmemLayoutB = decltype(tile_to_shape(
-          SmemLayoutAtom{},
-          make_shape(BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES)));
-
-      // TiledMMA
-      using mma_op = SM80_16x8x16_F32BF16BF16F32_TN;
-      using mma_traits = MMA_Traits<mma_op>;
-      using mma_atom = MMA_Atom<mma_traits>;
-
-      static constexpr int kMmaEURepeatM = 2;
-      static constexpr int kMmaEURepeatN = 2;
-      static constexpr int kMmaEURepeatK = 1;
-      // 32 x 2 x 2 = 128 threads
-
-      using mma_atom_shape = mma_traits::Shape_MNK;
-      static constexpr int MmaVM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
-      static constexpr int MmaVN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
-      static constexpr int MmaVK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
-
-      // this is for problem shape (16x2) x (8x2x2) x (16x1) = 32x32x16
-      using MMA_EU_RepeatT = decltype(make_layout(make_shape(
-          Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
-      using MMA_V_T = Tile<Int<MmaVM>, Int<MmaVN>, Int<MmaVK>>;
-
-      using TiledMMA =
-          decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_V_T{}));
-
-      // TiledCopy from global memory to shared memory
-      // uint128_t is 16 bytes = 4 floats = 8 halfs
-      static constexpr int NUM_VECTOR_UNITS =
-          sizeof(cute::uint128_t) / sizeof(DataType);
-
-      using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
-      using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
-      using g2s_copy_atom = Copy_Atom<g2s_copy_traits, DataType>;
-
-      // one block contains 128 threads
-      // --> find the compatible thread layout
-      using G2S_Copy_Thread_Layout = decltype(make_layout(
-          make_shape(_32{}, _4{}), // 32x4 = 128 threads
-          LayoutRight{} // A is in row-major
-          ));
-
-      using G2S_Copy_Value_Layout =
-          decltype(make_layout(make_shape(_1{}, Int<NUM_VECTOR_UNITS>{})));
-
-      // This is for copy shape 32x4 of uint128_t
-      using G2STiledCopyA = decltype(make_tiled_copy(
-          g2s_copy_atom{}, G2S_Copy_Thread_Layout{}, G2S_Copy_Value_Layout{}));
-
-      // Both A and B are in row-major so use the same TiledCopy for B
-      using G2STiledCopyB = G2STiledCopyA;
-
-      // CopyAtom from shared memory to registers
-      // Why no need to do tiling atom here? Because we will do it later with
-      // the information from TiledMMA
-      using s2r_copy_op = SM75_U32x4_LDSM_N;
-      using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
-      using s2r_copy_atom = Copy_Atom<s2r_copy_traits, DataType>;
-
-      using S2RCopyAtomA = s2r_copy_atom;
-      using S2RCopyAtomB = s2r_copy_atom;
-
-      // print_latex(
-      //     make_tiled_copy(S2RCopyAtomA{}, make_layout(Shape<_32>{}),
-      //     make_layout(Shape<_1, _8>{}))
-      // );
-
-      // grid, block
-      dim3 block{size(TiledMMA{}), 1U, 1U};
-      dim3 grid{
-          size(ceil_div(static_cast<unsigned int>(N), BLOCK_SIZE_N)),
-          size(ceil_div(static_cast<unsigned int>(M), BLOCK_SIZE_M)),
-          1U};
-
-      static constexpr int smem_size =
-          (cosize_v<SmemLayoutA> + cosize_v<SmemLayoutB>)*sizeof(DataType);
-
-      auto kernel = cute_gemm_v02<
-          DataType,
-          CtaTiler,
-          SmemLayoutA,
-          SmemLayoutB,
-          TiledMMA,
-          G2STiledCopyA,
-          G2STiledCopyB,
-          S2RCopyAtomA,
-          S2RCopyAtomB>;
-
-      cudaFuncSetAttribute(
-          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+      // Define TN strides (mixed)
+      auto dA = make_stride(K, Int<1>{});
+      auto dB = make_stride(K, Int<1>{});
+      auto dC = make_stride(N, Int<1>{});
+
+      // Define CTA tile sizes (static)
+      auto bM = Int<128>{};
+      auto bN = Int<128>{};
+      auto bK = Int<64>{};
+      auto cta_tiler = make_shape(bM, bN, bK);
+      auto bP = Int<3>{};
+
+      // Define the smem layouts (static)
+      // Swizzles for LDSM and 128b k-major loads
+      auto swizzle_atom = composition(
+          Swizzle<3, 3, 3>{},
+          Layout<
+              cute::Shape<_8, cute::Shape<_8, _8>>,
+              cute::Stride<_8, cute::Stride<_1, _64>>>{});
+
+      auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
+      auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
+      auto sC = make_layout(make_shape(bM, bN));
+
+      // Define the thread layouts (static)
+
+      TiledCopy copyA = make_tiled_copy(
+          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
+          Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                                // 16x8 k-major
+          Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
+      TiledCopy copyB = make_tiled_copy(
+          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
+          Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
+                                                                // 16x8 k-major
+          Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major
+
+      TiledMMA mmaC = make_tiled_mma(
+          SM80_16x8x16_F32BF16BF16F32_TN{},
+          Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
+          Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
+
+      Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
+      Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
+
+      int smem_size = int(sizeof(SharedStorage<
+                                 cute::bfloat16_t,
+                                 cute::bfloat16_t,
+                                 decltype(sA),
+                                 decltype(sB)>));
+      dim3 dimBlock(size(mmaC));
+      dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
+
+      auto kernel = gemm_device<
+          decltype(prob_shape),
+          decltype(cta_tiler),
+          cute::bfloat16_t,
+          decltype(dA),
+          decltype(sA),
+          decltype(copyA),
+          decltype(s2r_atom_A),
+          cute::bfloat16_t,
+          decltype(dB),
+          decltype(sB),
+          decltype(copyB),
+          decltype(s2r_atom_B),
+          cute::bfloat16_t,
+          decltype(dC),
+          decltype(sC),
+          decltype(mmaC)>;
+
+      configure_matmul(kernel, smem_size);
 
       enc.add_kernel_node(
           kernel,
-          grid,
-          block,
+          dimGrid,
+          dimBlock,
           smem_size,
-          M,
-          N,
-          K,
-          a.data<DataType>(),
-          K,
-          b.data<DataType>(),
-          K,
-          out.data<DataType>(),
-          N);
+          prob_shape,
+          cta_tiler,
+          a.data<cute::bfloat16_t>(),
+          dA,
+          sA,
+          copyA,
+          s2r_atom_A,
+          b.data<cute::bfloat16_t>(),
+          dB,
+          sB,
+          copyB,
+          s2r_atom_B,
+          out.data<cute::bfloat16_t>(),
+          dC,
+          sC,
+          mmaC);
     } else {
       throw std::runtime_error("Only bfloat16 supported");
     }
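For the sizes chosen above (bM = bN = 128, bK = 64, bP = 3 pipeline stages, 2 bytes per bf16 element), the A and B shared-memory buffers come to 128 × 64 × 3 × 2 = 49,152 bytes each, so `smem_size` works out to 98,304 bytes (96 KB), well past the 48 KB default limit; that is why the launch goes through `configure_matmul` rather than a plain launch. Likewise, `dimBlock` is size(mmaC) = 2 × 2 × 32 = 128 threads. A quick back-of-the-envelope check in plain C++ under those stated assumptions (ignoring struct padding between the two arrays, which is zero here):

#include <cstdio>

int main() {
  // Tile parameters as set up above: 128x64 A tile, 128x64 B tile,
  // 3 pipeline stages, 2 bytes per bf16 element.
  constexpr int bM = 128, bN = 128, bK = 64, bP = 3, elem_bytes = 2;
  constexpr int a_bytes = bM * bK * bP * elem_bytes; // 49152
  constexpr int b_bytes = bN * bK * bP * elem_bytes; // 49152
  static_assert(a_bytes + b_bytes == 96 * 1024, "expect 96 KB of smem");
  std::printf("smem_size = %d bytes (%d KB)\n",
              a_bytes + b_bytes, (a_bytes + b_bytes) / 1024);
  return 0;
}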