Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

Improve the cutlass gemm
@@ -5,11 +5,21 @@
 #include "mlx/dtype_utils.h"
 
 #include <cute/tensor.hpp>
+#include <cutlass/arch/arch.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+
+#include <iostream>
 
 namespace mlx::core::cu {
 
 namespace {
 
+using namespace cute;
+using bf16 = cute::bfloat16_t;
+
 template <typename Kernel>
 void configure_matmul(Kernel kernel, int smem_size) {
   static bool initialized = false;
@@ -17,308 +27,278 @@ void configure_matmul(Kernel kernel, int smem_size) {
     initialized = true;
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
   }
 }
 
-template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
-struct SharedStorage {
-  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
-  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
-};
+template <bool transpose, typename Tiler>
+constexpr int get_feature_size(Tiler smem) {
+  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
+  return (feature_size >= 64) ? 64 : feature_size;
+}
 
+constexpr int constexpr_log2(int x) {
+  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
+}
+
+template <int feature_size, int itemsize, int copy_bits>
+constexpr int get_swizzle_bits() {
+  constexpr int swizzle_bits =
+      constexpr_log2(feature_size * itemsize / copy_bits);
+  return (swizzle_bits > 3) ? 3 : swizzle_bits;
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled =
+      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_result_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled = make_composed_layout(
+      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
 template <
-    class ProblemShape,
-    class CtaTiler,
-    class TA,
-    class AStride,
-    class ASmemLayout,
-    class TiledCopyA,
-    class S2RAtomA,
-    class TB,
-    class BStride,
-    class BSmemLayout,
-    class TiledCopyB,
-    class S2RAtomB,
-    class TC,
-    class CStride,
-    class CSmemLayout,
-    class TiledMma>
-__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
-    ProblemShape shape_MNK,
-    CtaTiler cta_tiler,
-    TA const* A,
-    AStride dA,
-    ASmemLayout sA_layout,
-    TiledCopyA copy_a,
-    S2RAtomA s2r_atom_a,
-    TB const* B,
-    BStride dB,
-    BSmemLayout sB_layout,
-    TiledCopyB copy_b,
-    S2RAtomB s2r_atom_b,
-    TC* C,
-    CStride dC,
-    CSmemLayout,
-    TiledMma mma) {
-  using namespace cute;
+    int num_threads,
+    int itemsize,
+    bool transpose,
+    int copy_bits,
+    typename Copier,
+    typename Tiler>
+constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
+  constexpr int num_elements = copy_bits / itemsize;
+  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
+  constexpr int copies_per_feature = feature_size / num_elements;
 
-  // Preconditions
-  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
-  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
+  using E = Int<num_elements>;
+  using C = Int<copies_per_feature>;
+  using R = Int<num_threads / copies_per_feature>;
 
-  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
-  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
+  using ThreadLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
+      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
+  using ValueLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<E, _1>>,
+      Layout<cute::Shape<_1, E>>>;
 
-  static_assert(is_static<ASmemLayout>::value);
-  static_assert(is_static<BSmemLayout>::value);
-  static_assert(is_static<CSmemLayout>::value);
+  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
+}
 
-  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
-  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+template <int rasterization_factor>
+__device__ inline int2 raster_tile(int x, int y) {
+  return {
+      x / rasterization_factor,
+      (x % rasterization_factor) + y * rasterization_factor};
+}
 
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
+template <
+    typename T,
+    typename SLayoutA,
+    typename SLayoutB,
+    typename SLayoutC,
+    typename CopyA,
+    typename CopyB,
+    typename CopyC,
+    typename MMA,
+    int rasterization_factor>
+__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
+    const T* __restrict__ A,
+    const T* __restrict__ B,
+    T* __restrict__ C,
+    SLayoutA SA,
+    SLayoutB SB,
+    SLayoutC SC,
+    CopyA copy_a,
+    CopyB copy_b,
+    CopyC copy_c,
+    MMA mma,
+    int M,
+    int N,
+    int K) {
+  constexpr auto BM = size<0>(SA);
+  constexpr auto BN = size<0>(SB);
+  constexpr auto BK = size<1>(SA);
+  constexpr auto PIPE = size<2>(SA);
 
-  //
-  // Full and Tiled Tensors
-  //
+  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
+  const int blocks_m = ceil_div(M, BM);
+  const int blocks_n = ceil_div(N, BN);
 
-  // Represent the full tensors
-  Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
-  Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
-  Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
+  // Exit early if the tile is OOB
+  if (tile.x >= blocks_m || tile.y >= blocks_n) {
+    return;
+  }
 
-  // Get the appropriate blocks for this thread block
-  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
-  Tensor gA = local_tile(
-      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
-  Tensor gB = local_tile(
-      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
-  Tensor gC =
-      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+  // Make the full tensors
+  Tensor full_A =
+      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
+  Tensor full_B =
+      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
+  Tensor full_C =
+      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
 
-  // Shared memory buffers
+  // Partition the tensors into tiles and select the ones for this threadblock
+  Tensor local_A =
+      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
+  Tensor local_B =
+      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
+  Tensor local_C =
+      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
+
+  // Make shared memory tensors
   extern __shared__ char shared_memory[];
-  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
-  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
-  Tensor sA = make_tensor(
-      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(
-      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
+  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
+  T* shared_B_ptr =
+      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
+  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
+  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
+  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
+  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
 
-  //
-  // Partition the copying of A and B tiles across the threads
-  //
+  // Get the copies that correspond to this thread
+  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor local_A_src = thread_copy_a.partition_S(local_A);
+  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
+  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor local_B_src = thread_copy_a.partition_S(local_B);
+  Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
+  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
+  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
+  Tensor local_C_dst = thread_copy_c.partition_D(local_C);
 
-  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
-  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
-  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
-
-  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
-  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
-  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
-
-  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
-  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
-  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
-  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
-
-  //
-  // PREFETCH
-  //
-
-  auto K_PIPE_MAX = size<3>(tAsA);
-
-  // Total count of tiles
-  int k_tile_count = size<3>(tAgA);
-  // Current tile index in gmem to read from
+  // Start fetches
+  int k_tile_count = size<2>(local_A);
   int k_tile_next = 0;
 
-  // Start async loads for all pipes but the last
   CUTE_UNROLL
-  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
-    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
-    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
+  for (int k = 0; k < PIPE - 1; k++) {
+    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
+    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
     cp_async_fence();
-    --k_tile_count;
-    if (k_tile_count > 0) {
-      ++k_tile_next;
-    }
+    k_tile_count--;
+    k_tile_next += (k_tile_count > 0);
   }
 
-  //
-  // Define A/B partitioning and C accumulators
-  //
+  // Get the MMA that corresponds to this thread and allocate registers
+  auto thread_mma = mma.get_slice(threadIdx.x);
+  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
+  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
+  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
+  Tensor mma_global_C = thread_mma.partition_C(local_C);
+  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
+  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
+  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
+  clear(mma_frag_C);
 
-  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
-  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
+  // Make shared to register copies
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
+  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
+  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
+  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
+  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
 
-  // Allocate registers for pipelining
-  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
-  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
-  // Allocate the accumulators -- same size as the projected data
-  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
+  constexpr auto RPIPE = size<2>(mma_shared_A);
+  int smem_read = 0;
+  int smem_write = PIPE - 1;
+  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
+  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
 
-  CUTE_STATIC_ASSERT_V(
-      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
-  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
-  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
-
-  // Clear the accumulators
-  clear(tCrC);
-
-  //
-  // Copy Atom retiling
-  //
-
-  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
-  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
-  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
-  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
-
-  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
-  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
-  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
-  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
-
-#if 0
-  if(thread0()) {
-    print("  mA : "); print(  mA); print("\n");
-    print("  gA : "); print(  gA); print("\n");
-    print("  sA : "); print(  sA); print("\n");
-    print("tAgA : "); print(tAgA); print("\n");
-    print("tAsA : "); print(tAsA); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print("  mB : "); print(  mB); print("\n");
-    print("  gB : "); print(  gB); print("\n");
-    print("  sB : "); print(  sB); print("\n");
-    print("tBgB : "); print(tBgB); print("\n");
-    print("tBsB : "); print(tBsB); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print("  mC : "); print(  mC); print("\n");
-    print("  gC : "); print(  gC); print("\n");
-    print("tCgC : "); print(tCgC); print("\n");
-    print("tCrA : "); print(tCrA); print("\n");
-    print("tCrB : "); print(tCrB); print("\n");
-    print("tCrC : "); print(tCrC); print("\n");
-
-    print("tXsA : "); print(tXsA); print("\n");
-    print("tXrA : "); print(tXrA); print("\n");
-    print("tXsB : "); print(tXsB); print("\n");
-    print("tXrB : "); print(tXrB); print("\n");
-  }
-#endif
-
-#if 1
-
-  // Current pipe index in smem to read from
-  int smem_pipe_read = 0;
-  // Current pipe index in smem to write to
-  int smem_pipe_write = K_PIPE_MAX - 1;
-
-  // Pipe slice
-  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
-  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-  // Size of the register pipeline
-  auto K_BLOCK_MAX = size<2>(tCrA);
-  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
-
-  // PREFETCH register pipeline
-  if (K_BLOCK_MAX > 1) {
-    // Wait until our first prefetched tile is loaded in
-    cp_async_wait<K_PIPE_MAX - 2>();
+  // Start the register pipeline
+  if constexpr (RPIPE > 1) {
+    cp_async_wait<PIPE - 2>();
     __syncthreads();
-
-    // Prefetch the first rmem from the first k-tile
-    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
-    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
+    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
   }
 
-  //
-  // PIPELINED MAIN LOOP
-  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
-  // cp.async instructions
-  //           and explicit pipelines in shared memory.
-  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
-  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
-  //   Data is computed on registers(b_block).
-  //
-  //   This allows all copies and compute to overlap:
-  //     Copy from gmem->smem can overlap with copies from smem->rmem and
-  //     compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
-  //
-
   CUTE_NO_UNROLL
-  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+  while (k_tile_count > -(PIPE - 1)) {
     CUTE_UNROLL
-    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
-      if (k_block == K_BLOCK_MAX - 1) {
-        // Slice the smem_pipe_read smem
-        tXsA_p = tXsA(_, _, _, smem_pipe_read);
-        tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-        // Commit the smem for smem_pipe_read
-        cp_async_wait<K_PIPE_MAX - 2>();
+    for (int k_block = 0; k_block < RPIPE; k_block++) {
+      if (k_block == RPIPE - 1) {
+        mma_A_src_p = mma_A_src(_, _, _, smem_read);
+        mma_B_src_p = mma_B_src(_, _, _, smem_read);
+        cp_async_wait<PIPE - 2>();
         __syncthreads();
       }
 
-      // Load A, B shmem->regs for k_block+1
-      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
-      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
-      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
-      // Copy gmem to smem before computing gemm on each k-pipe
+      // Load the next register tile
+      auto k_block_next = (k_block + 1) % RPIPE;
+      copy(
+          s2r_copy_a,
+          mma_A_src_p(_, _, k_block_next),
+          mma_A_dst(_, _, k_block_next));
+      copy(
+          s2r_copy_b,
+          mma_B_src_p(_, _, k_block_next),
+          mma_B_dst(_, _, k_block_next));
+
       if (k_block == 0) {
         copy(
-            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+            copy_a,
+            local_A_src(_, _, _, k_tile_next),
+            local_A_dst(_, _, _, smem_write));
         copy(
-            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+            copy_b,
+            local_B_src(_, _, _, k_tile_next),
+            local_B_dst(_, _, _, smem_write));
         cp_async_fence();
-
-        // Advance the gmem tile
-        --k_tile_count;
-        if (k_tile_count > 0) {
-          ++k_tile_next;
-        }
-
-        // Advance the smem pipe
-        smem_pipe_write = smem_pipe_read;
-        smem_pipe_read =
-            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+        k_tile_count--;
+        k_tile_next += (k_tile_count > 0);
+        smem_write = smem_read;
+        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
       }
-      // Thread-level register gemm for k_block
-      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+      gemm(
+          mma,
+          mma_frag_A(_, _, k_block),
+          mma_frag_B(_, _, k_block),
+          mma_frag_C);
     }
   }
 
-#endif
-
+  copy(mma_frag_C, mma_shared_C);
+  __syncthreads();
+  copy(copy_c, local_C_src, local_C_dst);
+
+  // if (threadIdx.x == 0) {
+  //   print("fC: "); print(mma_frag_C); print("\n");
+  //   print("sC: "); print(mma_shared_C); print("\n");
+  //   print("dC: "); print(local_C_dst); print("\n");
   //
-  // Epilogue
-  //
-
-  copy(tCrC, tCgC);
+  //   print(s2r_atom_a); print("\n");
+  // }
 }
 
 } // namespace
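Note (not part of the commit): the new make_smem_layout / make_result_smem_layout helpers above derive the swizzle width from the tile's contiguous ("feature") dimension, the element size in bits, and the width of each vectorized copy, where the old host code hard-coded Swizzle<3, 3, 3>. The sketch below reproduces that arithmetic as plain constexpr C++ so it can be checked on the host; the inputs (64-wide K tiles, 16-bit bf16 elements, 128-bit cp.async copies) mirror the configuration set up in the host code further down, and the helper names here are illustrative, not code from the commit.

// Standalone illustration of the swizzle-bit selection (assumed inputs).
constexpr int constexpr_log2(int x) {
  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}

constexpr int swizzle_bits(int feature_size, int itemsize, int copy_bits) {
  // get_feature_size caps the contiguous dimension at 64 elements.
  feature_size = (feature_size >= 64) ? 64 : feature_size;
  // log2 of the number of copy_bits-wide vectors per row of the base tile,
  // capped at 3.
  int bits = constexpr_log2(feature_size * itemsize / copy_bits);
  return (bits > 3) ? 3 : bits;
}

// bf16 (16-bit) tiles with BK = 64 and 128-bit copies recover the old
// hard-coded atom: Swizzle<3, 3, 3>.
static_assert(swizzle_bits(64, 16, 128) == 3, "matches the previous swizzle");
// A narrower K tile would get a smaller swizzle.
static_assert(swizzle_bits(32, 16, 128) == 2, "narrower tile, fewer bits");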
@@ -339,103 +319,74 @@ void cutlass_gemm(
     if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
       using namespace cute;
 
-      // Define shapes (dynamic)
-      auto prob_shape = make_shape(M, N, K);
+      // Tile definitions
+      auto BM = Int<128>{};
+      auto BN = Int<128>{};
+      auto BK = Int<64>{};
+      auto BP = Int<3>{};
+      auto GM = Int<8>{};
 
-      // Define TN strides (mixed)
-      auto dA = make_stride(K, Int<1>{});
-      auto dB = make_stride(K, Int<1>{});
-      auto dC = make_stride(N, Int<1>{});
+      // Thread definitions
+      using TM = Int<2>;
+      using TN = Int<2>;
+      using TK = Int<1>;
+      constexpr int num_threads = TM::value * TN::value * 32;
 
-      // Define CTA tile sizes (static)
-      auto bM = Int<128>{};
-      auto bN = Int<128>{};
-      auto bK = Int<64>{};
-      auto cta_tiler = make_shape(bM, bN, bK);
-      auto bP = Int<3>{};
+      auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
+      auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
+      auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
 
-      // Define the smem layouts (static)
-      // Swizzles for LDSM and 128b k-major loads
-      auto swizzle_atom = composition(
-          Swizzle<3, 3, 3>{},
-          Layout<
-              cute::Shape<_8, cute::Shape<_8, _8>>,
-              cute::Stride<_8, cute::Stride<_1, _64>>>{});
+      constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
 
-      auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
-      auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
-      auto sC = make_layout(make_shape(bM, bN));
+      auto async_copy_op =
+          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
+      auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
+          async_copy_op, make_shape(BM, BK));
+      auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
+          async_copy_op, make_shape(BN, BK));
 
-      // Define the thread layouts (static)
+      auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
+      auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
+          sync_copy_op, make_shape(BM, BN));
 
-      TiledCopy copyA = make_tiled_copy(
-          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-          Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                                // 16x8 k-major
-          Layout<cute::Shape<_1, _8>>{}); // Val layout  1x8 k-major
-      TiledCopy copyB = make_tiled_copy(
-          Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-          Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                                // 16x8 k-major
-          Layout<cute::Shape<_1, _8>>{}); // Val layout  1x8 n-major
-
-      TiledMMA mmaC = make_tiled_mma(
-          SM80_16x8x16_F32BF16BF16F32_TN{},
-          Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
-          Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
-
-      Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
-      Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
-
-      int smem_size = int(sizeof(SharedStorage<
-                                 cute::bfloat16_t,
-                                 cute::bfloat16_t,
-                                 decltype(sA),
-                                 decltype(sB)>));
-      dim3 dimBlock(size(mmaC));
-      dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
+      auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
+      auto tiled_mma = make_tiled_mma(
+          mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
 
-      auto kernel = gemm_device<
-          decltype(prob_shape),
-          decltype(cta_tiler),
-          cute::bfloat16_t,
-          decltype(dA),
-          decltype(sA),
-          decltype(copyA),
-          decltype(s2r_atom_A),
-          cute::bfloat16_t,
-          decltype(dB),
-          decltype(sB),
-          decltype(copyB),
-          decltype(s2r_atom_B),
-          cute::bfloat16_t,
-          decltype(dC),
-          decltype(sC),
-          decltype(mmaC)>;
-
+      auto kernel = matmul_kernel<
+          bf16,
+          decltype(SA),
+          decltype(SB),
+          decltype(SC),
+          decltype(tiled_copy_a),
+          decltype(tiled_copy_b),
+          decltype(tiled_copy_c),
+          decltype(tiled_mma),
+          GM.value>;
       configure_matmul(kernel, smem_size);
 
+      dim3 block(size(tiled_mma));
+      dim3 grid(
+          size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
+
       enc.add_kernel_node(
           kernel,
-          dimGrid,
-          dimBlock,
+          grid,
+          block,
           smem_size,
-          prob_shape,
-          cta_tiler,
-          a.data<cute::bfloat16_t>(),
-          dA,
-          sA,
-          copyA,
-          s2r_atom_A,
-          b.data<cute::bfloat16_t>(),
-          dB,
-          sB,
-          copyB,
-          s2r_atom_B,
-          out.data<cute::bfloat16_t>(),
-          dC,
-          sC,
-          mmaC);
+          a.data<bf16>(),
+          b.data<bf16>(),
+          out.data<bf16>(),
+          SA,
+          SB,
+          SC,
+          tiled_copy_a,
+          tiled_copy_b,
+          tiled_copy_c,
+          tiled_mma,
+          M,
+          N,
+          K);
     } else {
       throw std::runtime_error("Only bfloat16 supported");
     }
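Note (not part of the commit): the launch above spreads each row of output tiles across GM = 8 consecutive blockIdx.x values (grid x is ceil_div(M, BM) * GM, grid y is ceil_div(ceil_div(N, BN), GM)), and raster_tile folds those indices back into a (row tile, column tile) pair. Groups of eight consecutively scheduled blocks therefore share the same A row panel while sweeping eight adjacent B column panels, presumably to improve cache reuse; because grid y is rounded up, some blocks can map past the last column tile, which is what the kernel's early OOB return handles. A small host-side sketch of the mapping, with hypothetical problem sizes:

#include <cstdio>

// Mirror of raster_tile<8> from the kernel, written as ordinary host code.
struct Tile { int x, y; };

constexpr int GM = 8;

Tile raster_tile(int bx, int by) {
  return {bx / GM, (bx % GM) + by * GM};
}

int main() {
  // Example: M = N = 1024 with 128x128 tiles gives an 8x8 tile grid,
  // launched as (8 * GM) x 1 = 64 x 1 blocks.
  for (int bx = 0; bx < 16; ++bx) {
    Tile t = raster_tile(bx, 0);
    std::printf("blockIdx.x = %2d -> tile (%d, %d)\n", bx, t.x, t.y);
  }
  // Blocks 0..7 all compute row tile 0 (column tiles 0..7); blocks 8..15
  // compute row tile 1, and so on.
  return 0;
}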