mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 00:36:32 +08:00)

Reset cutlass gemm to working state again

parent cf5eef095d
commit e1303f6160
@@ -10,219 +10,315 @@ namespace mlx::core::cu {

namespace {

template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
}
}
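A minimal usage sketch of the opt-in this helper performs (the kernel name, sizes, and launch_example below are hypothetical, not part of this change): launches that request more than 48 KB of dynamic shared memory fail unless cudaFuncAttributeMaxDynamicSharedMemorySize is raised first, which is what configure_matmul does once per kernel.

#include <cuda_runtime.h>

// Hypothetical kernel that indexes a large dynamic shared-memory buffer.
__global__ void my_kernel(float* out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

void launch_example(float* out, cudaStream_t stream) {
  constexpr int smem_bytes = 96 * 1024;  // above the 48 KB default limit
  cudaFuncSetAttribute(
      my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  my_kernel<<<1, 128, smem_bytes, stream>>>(out);
}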
template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
struct SharedStorage {
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
};

template <
class T,
class ProblemShape,
class CtaTiler,
class SmemLayoutA,
class SmemLayoutB,
class TiledMMA,
class G2STiledCopyA,
class G2STiledCopyB,
class S2RCopyAtomA,
class S2RCopyAtomB>
__global__ void cute_gemm_v02(
unsigned int M,
unsigned int N,
unsigned int K,
const T* A,
size_t lda,
const T* B,
size_t ldb,
T* C,
size_t ldc) {
class TA,
class AStride,
class ASmemLayout,
class TiledCopyA,
class S2RAtomA,
class TB,
class BStride,
class BSmemLayout,
class TiledCopyB,
class S2RAtomB,
class TC,
class CStride,
class CSmemLayout,
class TiledMma>
__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
ProblemShape shape_MNK,
CtaTiler cta_tiler,
TA const* A,
AStride dA,
ASmemLayout sA_layout,
TiledCopyA copy_a,
S2RAtomA s2r_atom_a,
TB const* B,
BStride dB,
BSmemLayout sB_layout,
TiledCopyB copy_b,
S2RAtomB s2r_atom_b,
TC* C,
CStride dC,
CSmemLayout,
TiledMma mma) {
using namespace cute;

// global full tensor
// shape
auto shape_MNK = make_shape(M, N, K);
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
// stride
// cuBLAS convenience for TN gemm
// all matrices are in column major
// A (m,k) --> transpose --> A (k,m) --> cute layout: A (m,k) : (k,1) --> lda = k
// B (k,n) --> cute layout: B (n,k) : (k,1) --> ldb = k
// C (m,n) --> cute layout: C (m,n) : (1,m) --> ldc = m
auto dA = make_stride(lda, _1{});
auto dB = make_stride(ldb, _1{});
auto dC = make_stride(_1{}, ldc);
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads

static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);

CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K

CUTE_STATIC_ASSERT_V(
congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(
congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(
congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN

//
// Full and Tiled Tensors
//

// Represent the full tensors
Tensor mA =
make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // M x K
make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
Tensor mB =
make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // N x K
make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
Tensor mC =
make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // M x N
make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)

// global tile tensor
auto cta_tiler = CtaTiler{};
auto cta_coord = make_coord(blockIdx.y, blockIdx.x, _);
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(
mA,
cta_tiler,
cta_coord,
Step<_1, X, _1>{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_TILES_K
mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(
mB,
cta_tiler,
cta_coord,
Step<X, _1, _1>{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_TILES_K
Tensor gC = local_tile(
mC,
cta_tiler,
cta_coord,
Step<_1, _1, X>{}); // BLOCK_SIZE_M x BLOCK_SIZE_N

// shared memory
// __shared__ T Asmem[cosize_v<SmemLayoutA>];
// __shared__ T Bsmem[cosize_v<SmemLayoutB>];

extern __shared__ T smem[];
T* Asmem = smem;
T* Bsmem = smem + cosize_v<SmemLayoutA>;
mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
Tensor gC =
local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)

// Shared memory buffers
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(
make_smem_ptr(Asmem),
SmemLayoutA{}); // BLOCK_SIZE_M x BLOCK_SIZE_K x NUM_STAGES
make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(
make_smem_ptr(Bsmem),
SmemLayoutB{}); // BLOCK_SIZE_N x BLOCK_SIZE_K x NUM_STAGES
make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)

// MMA
// use TiledMMA --> get one thread work
auto tiled_mma = TiledMMA{};
ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
auto tCgC = thr_mma.partition_C(gC); // MMA x MMA_M x MMA_N
//
// Partition the copying of A and B tiles across the threads
//

// thread private memory for MMA
auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // MMA x MMA_M x MMA_K
auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // MMA x MMA_N x MMA_K
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)

// thread private memory for accumulator for MMA
Tensor tCrC = thr_mma.partition_fragment_C(gC); // MMA x MMA_M x MMA_N
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)

clear(tCrC);
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K

// initiate copy from global memory to shared memory
// use G2S TiledCopy --> get one thread copy work
auto g2s_tiled_copy_a = G2STiledCopyA{};
auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
const auto tAgA =
g2s_thr_copy_a.partition_S(gA); // CPY x CPY_M x CPY_K x NUM_TILES_K
auto tAsA =
g2s_thr_copy_a.partition_D(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
//
// PREFETCH
//

auto g2s_tiled_copy_b = G2STiledCopyB{};
auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(threadIdx.x);
const auto tBgB =
g2s_thr_copy_b.partition_S(gB); // CPY x CPY_N x CPY_K x NUM_TILES_K
auto tBsB =
g2s_thr_copy_b.partition_D(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
auto K_PIPE_MAX = size<3>(tAsA);

// initiate copy from shared memory to thread private memory
// use S2R TiledCopy --> get one thread copy work
auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(threadIdx.x);
const auto tCsA =
s2r_thr_copy_a.partition_S(sA); // CPY x CPY_M x CPY_K x NUM_STAGES
auto tCrA_copy_view = s2r_thr_copy_a.retile_D(tCrA); // CPY x CPY_M x CPY_K

auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(threadIdx.x);
const auto tCsB =
s2r_thr_copy_b.partition_S(sB); // CPY x CPY_N x CPY_K x NUM_STAGES
auto tCrB_copy_view = s2r_thr_copy_b.retile_D(tCrB); // CPY x CPY_N x CPY_K

// pipeline
// counter
int itile_to_read = 0; // read index of the next tile
// 2 pointers of the buffer
int ismem_write = 0;
int ismem_read = 0;
// NUM_STAGES = 5 --> Prefetch NUM_STAGES-1 = 4 tiles first
auto NUM_STAGES = size<3>(tAsA);
// Total count of tiles
int k_tile_count = size<3>(tAgA);
// Current tile index in gmem to read from
int k_tile_next = 0;

// Start async loads for all pipes but the last
CUTE_UNROLL
for (int stage = 0; stage < NUM_STAGES - 1; ++stage) {
// prefetch
// issue copy
copy(g2s_tiled_copy_a, tAgA(_, _, _, itile_to_read), tAsA(_, _, _, stage));
copy(g2s_tiled_copy_b, tBgB(_, _, _, itile_to_read), tBsB(_, _, _, stage));

// commit
for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
cp_async_fence();

ismem_write++;
itile_to_read++;
}

// wait for first tile to be prefetched: G^0 -> S^0
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();

// Having S^0, copy from S^0,0 to R^0
int k = 0;
copy(s2r_tiled_copy_a, tCsA(_, _, k, ismem_read), tCrA_copy_view(_, _, k));
copy(s2r_tiled_copy_b, tCsB(_, _, k, ismem_read), tCrB_copy_view(_, _, k));

// loop over tiles
auto NUM_TILES_K = size<3>(tAgA);

CUTE_NO_UNROLL
for (int tile = 0; tile < NUM_TILES_K; ++tile) {
auto MMA_K = size<2>(tCrA);
// loop over MMAs in direction of K

CUTE_UNROLL
for (int k = 0; k < MMA_K; ++k) {
int k_next = (k + 1) % MMA_K;
// if this is the second-to-last MMA, wait for the next tile to be fetched
if (k == MMA_K - 1) {
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();

ismem_read = (ismem_read + 1) % NUM_STAGES;
}

// load data for the next MMA, from S^tile to registers
copy(
s2r_tiled_copy_a,
tCsA(_, _, k_next, ismem_read),
tCrA_copy_view(_, _, k_next));
copy(
s2r_tiled_copy_b,
tCsB(_, _, k_next, ismem_read),
tCrB_copy_view(_, _, k_next));

if (k == 0) {
// prefetch the next tile
// issue copy
if (itile_to_read < NUM_TILES_K) {
copy(
g2s_tiled_copy_a,
tAgA(_, _, _, itile_to_read),
tAsA(_, _, _, ismem_write));
copy(
g2s_tiled_copy_b,
tBgB(_, _, _, itile_to_read),
tBsB(_, _, _, ismem_write));

itile_to_read++;
ismem_write = (ismem_write + 1) % NUM_STAGES;
}
// commit
cp_async_fence();
}

// mma
gemm(tiled_mma, tCrC, tCrA(_, _, k), tCrB(_, _, k), tCrC);
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}
}

//
// Define A/B partitioning and C accumulators
//

ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

// Allocate registers for pipelining
Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

CUTE_STATIC_ASSERT_V(
(shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N

// Clear the accumulators
clear(tCrC);

//
// Copy Atom retiling
//

TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)

TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)

#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif

#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif

#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");

print("tXsA : "); print(tXsA); print("\n");
print("tXrA : "); print(tXrA); print("\n");
print("tXsB : "); print(tXsB); print("\n");
print("tXrB : "); print(tXrB); print("\n");
}
#endif

#if 1

// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = K_PIPE_MAX - 1;

// Pipe slice
Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);

// Size of the register pipeline
auto K_BLOCK_MAX = size<2>(tCrA);
CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));

// PREFETCH register pipeline
if (K_BLOCK_MAX > 1) {
// Wait until our first prefetched tile is loaded in
cp_async_wait<K_PIPE_MAX - 2>();
__syncthreads();

// Prefetch the first rmem from the first k-tile
copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
}

//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
// cp.async instructions
// and explicit pipelines in shared memory.
// Data is read from global(k_tile_next) to shared(smem_pipe_write).
// Data is read from shared(smem_pipe_read) to registers(k_block_next).
// Data is computed on registers(b_block).
//
// This allows all copies and compute to overlap:
// Copy from gmem->smem can overlap with copies from smem->rmem and
// compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
//

CUTE_NO_UNROLL
while (k_tile_count > -(K_PIPE_MAX - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
if (k_block == K_BLOCK_MAX - 1) {
// Slice the smem_pipe_read smem
tXsA_p = tXsA(_, _, _, smem_pipe_read);
tXsB_p = tXsB(_, _, _, smem_pipe_read);

// Commit the smem for smem_pipe_read
cp_async_wait<K_PIPE_MAX - 2>();
__syncthreads();
}

// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0) {
copy(
copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
copy(
copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
cp_async_fence();

// Advance the gmem tile
--k_tile_count;
if (k_tile_count > 0) {
++k_tile_next;
}

// Advance the smem pipe
smem_pipe_write = smem_pipe_read;
smem_pipe_read =
(smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
}
}

#endif

//
// Epilogue
//

copy(tCrC, tCgC);
// constexpr T alpha{1.0f};
// constexpr T beta{0.0f};
// axpby(alpha, tCrC, beta, tCgC);
}

} // namespace

@@ -243,144 +339,103 @@ void cutlass_gemm(
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;

// block shape and cta tiler
// additional dim: NUM_STAGES --> This is for later pipelining the k-slice
// GEMM
auto BLOCK_SIZE_M = _128{};
auto BLOCK_SIZE_N = _128{};
auto BLOCK_SIZE_K = _32{};
auto NUM_STAGES = _5{};
using CtaTiler =
decltype(make_shape(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K));
// Define shapes (dynamic)
auto prob_shape = make_shape(M, N, K);

// smem layout
// Swizzle parameters need to be chosen right
static constexpr int kShmLoadSwizzleM = 3;
static constexpr int kShmLoadSwizzleS = 3;
static constexpr int kShmLoadSwizzleB = 3;
// Define TN strides (mixed)
auto dA = make_stride(K, Int<1>{});
auto dB = make_stride(K, Int<1>{});
auto dC = make_stride(N, Int<1>{});

using SmemLayoutAtom = decltype(composition(
Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
make_layout(make_shape(_8{}, BLOCK_SIZE_K), LayoutRight{})));
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int<64>{};
auto cta_tiler = make_shape(bM, bN, bK);
auto bP = Int<3>{};

// what does this do?
// with BLOCK_SIZE_K = 32, shape: (8, 32)
// 2^M = 8, 1 new unit = 8 units --> 1 row contains 32/8 = 4 new units
// 2^S = 8, it will treat 1 row = 8 new units --> do 8-unit swizzle
// 2^B = 8, it will reset the swizzle pattern after 8 rows
// print_layout(SmemLayoutAtom{});
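To make the swizzle comment above concrete, here is a small standalone sketch (plain C++, no CuTe dependency; swizzle_apply is an illustrative helper that mirrors how cute::Swizzle<B, M, S> XORs the B bits starting at bit M+S of a linear offset into the B bits starting at bit M) that prints where each 8-element vector of an 8x32 row-major tile lands:

#include <cstdio>

// Mirrors cute::Swizzle<B, M, S>::apply from cute/swizzle.hpp for S >= 0.
constexpr int swizzle_apply(int offset, int B = 3, int M = 3, int S = 3) {
  int mask = ((1 << B) - 1) << (M + S);
  return offset ^ ((offset & mask) >> S);
}

int main() {
  // 8 x 32 row-major tile of 2-byte elements: offset = row * 32 + col.
  // Each row holds four 8-element vectors; print the physical vector slot
  // that each logical vector maps to after Swizzle<3, 3, 3>.
  for (int row = 0; row < 8; ++row) {
    std::printf("row %d:", row);
    for (int vec = 0; vec < 4; ++vec) {
      int offset = row * 32 + vec * 8;
      std::printf(" %d", (swizzle_apply(offset) % 32) / 8);
    }
    std::printf("\n");
  }
  return 0;
}

Each pair of rows gets its vector slots rotated against the previous pair, which is the bank-conflict-avoidance effect the swizzle is chosen for.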
// Define the smem layouts (static)
// Swizzles for LDSM and 128b k-major loads
auto swizzle_atom = composition(
Swizzle<3, 3, 3>{},
Layout<
cute::Shape<_8, cute::Shape<_8, _8>>,
cute::Stride<_8, cute::Stride<_1, _64>>>{});

// tile_to_shape extends the layout in LayoutLeft order
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtom{},
make_shape(BLOCK_SIZE_M, BLOCK_SIZE_K, NUM_STAGES)));
auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN));

using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtom{},
make_shape(BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES)));
// Define the thread layouts (static)

// TiledMMA
using mma_op = SM80_16x8x16_F32BF16BF16F32_TN;
using mma_traits = MMA_Traits<mma_op>;
using mma_atom = MMA_Atom<mma_traits>;
TiledCopy copyA = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
TiledCopy copyB = make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
// 16x8 k-major
Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major

static constexpr int kMmaEURepeatM = 2;
static constexpr int kMmaEURepeatN = 2;
static constexpr int kMmaEURepeatK = 1;
// 32 x 2 x 2 = 128 threads
TiledMMA mmaC = make_tiled_mma(
SM80_16x8x16_F32BF16BF16F32_TN{},
Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM

using mma_atom_shape = mma_traits::Shape_MNK;
static constexpr int MmaVM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
static constexpr int MmaVN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
static constexpr int MmaVK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;

// this is for problem shape (16x2) x (8x2x2) x (16x1) = 32x32x16
using MMA_EU_RepeatT = decltype(make_layout(make_shape(
Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
using MMA_V_T = Tile<Int<MmaVM>, Int<MmaVN>, Int<MmaVK>>;
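As a quick sanity check on the tile arithmetic quoted in the comment above (illustrative only; the numbers come from the SM80_16x8x16 atom and the repeat factors defined here):

// One SM80_16x8x16 atom covers 16x8x16 (M x N x K).
static_assert(1 * 2 * 16 == 32, "MmaVM = 16 (atom M) * 2 (kMmaEURepeatM)");
static_assert(2 * 2 * 8 == 32, "MmaVN = 8 (atom N) * 2 (kMmaEURepeatN) * 2 (value repeat)");
static_assert(1 * 1 * 16 == 16, "MmaVK = 16 (atom K) * 1 (kMmaEURepeatK)");
// One warp per atom in the 2x2 grid -> 4 warps = 128 threads per block.
static_assert(32 * 2 * 2 == 128, "threads per CTA");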
int smem_size = int(sizeof(SharedStorage<
cute::bfloat16_t,
cute::bfloat16_t,
decltype(sA),
decltype(sB)>));
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));

using TiledMMA =
decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_V_T{}));
auto kernel = gemm_device<
decltype(prob_shape),
decltype(cta_tiler),
cute::bfloat16_t,
decltype(dA),
decltype(sA),
decltype(copyA),
decltype(s2r_atom_A),
cute::bfloat16_t,
decltype(dB),
decltype(sB),
decltype(copyB),
decltype(s2r_atom_B),
cute::bfloat16_t,
decltype(dC),
decltype(sC),
decltype(mmaC)>;

// TiledCopy from global memory to shared memory
// uint128_t is 16 bytes = 4 floats = 8 halves
static constexpr int NUM_VECTOR_UNITS =
sizeof(cute::uint128_t) / sizeof(DataType);

using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom = Copy_Atom<g2s_copy_traits, DataType>;

// one block contains 128 threads
// --> find the compatible thread layout
using G2S_Copy_Thread_Layout = decltype(make_layout(
make_shape(_32{}, _4{}), // 32x4 = 128 threads
LayoutRight{} // A is in row-major
));

using G2S_Copy_Value_Layout =
decltype(make_layout(make_shape(_1{}, Int<NUM_VECTOR_UNITS>{})));

// This is for copy shape 32x4 of uint128_t
using G2STiledCopyA = decltype(make_tiled_copy(
g2s_copy_atom{}, G2S_Copy_Thread_Layout{}, G2S_Copy_Value_Layout{}));
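A small check (illustrative; it assumes CuTe's types are in scope, as they are elsewhere in this file) of what a single issue of this tiled copy moves: 32x4 threads, each copying a 1x8 vector of bf16, cover a 32x32 element tile, and each thread's 8 bf16 elements are exactly one 16-byte cp.async transaction.

static_assert(32 * 4 == 128, "the copy uses every thread in the block");
static_assert(4 * 8 == 32, "4 thread columns x 8 values = 32 elements along K");
static_assert(
    8 * sizeof(cute::bfloat16_t) == sizeof(cute::uint128_t),
    "one 128-bit cp.async per thread");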
// Both A and B are in row-major so use the same TiledCopy for B
using G2STiledCopyB = G2STiledCopyA;

// CopyAtom from shared memory to registers
// Why is there no need to tile the atom here? Because we will do it later
// with the information from TiledMMA
using s2r_copy_op = SM75_U32x4_LDSM_N;
using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
using s2r_copy_atom = Copy_Atom<s2r_copy_traits, DataType>;

using S2RCopyAtomA = s2r_copy_atom;
using S2RCopyAtomB = s2r_copy_atom;

// print_latex(
// make_tiled_copy(S2RCopyAtomA{}, make_layout(Shape<_32>{}),
// make_layout(Shape<_1, _8>{}))
// );

// grid, block
dim3 block{size(TiledMMA{}), 1U, 1U};
dim3 grid{
size(ceil_div(static_cast<unsigned int>(N), BLOCK_SIZE_N)),
size(ceil_div(static_cast<unsigned int>(M), BLOCK_SIZE_M)),
1U};

static constexpr int smem_size =
(cosize_v<SmemLayoutA> + cosize_v<SmemLayoutB>)*sizeof(DataType);
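For reference, a worked check (plain C++, illustrative only; the formula assumes the A and B staging buffers are exactly bM*bK and bN*bK bf16 elements per stage) of how much dynamic shared memory the two configurations in this diff request, and why the default 48 KB cap has to be raised before launch:

// bytes = (bM*bK + bN*bK) * stages * sizeof(bf16)
static_assert((128 * 32 + 128 * 32) * 5 * 2 == 80 * 1024, "old config: bK=32, 5 stages -> 80 KB");
static_assert((128 * 64 + 128 * 64) * 3 * 2 == 96 * 1024, "new config: bK=64, 3 stages -> 96 KB");
// Both exceed the 48 KB default, hence the opt-in via
// cudaFuncAttributeMaxDynamicSharedMemorySize in configure_matmul below.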
auto kernel = cute_gemm_v02<
DataType,
CtaTiler,
SmemLayoutA,
SmemLayoutB,
TiledMMA,
G2STiledCopyA,
G2STiledCopyB,
S2RCopyAtomA,
S2RCopyAtomB>;

cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
configure_matmul(kernel, smem_size);

enc.add_kernel_node(
kernel,
grid,
block,
dimGrid,
dimBlock,
smem_size,
M,
N,
K,
a.data<DataType>(),
K,
b.data<DataType>(),
K,
out.data<DataType>(),
N);
prob_shape,
cta_tiler,
a.data<cute::bfloat16_t>(),
dA,
sA,
copyA,
s2r_atom_A,
b.data<cute::bfloat16_t>(),
dB,
sB,
copyB,
s2r_atom_B,
out.data<cute::bfloat16_t>(),
dC,
sC,
mmaC);
} else {
throw std::runtime_error("Only bfloat16 supported");
}