Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-27 08:46:41 +08:00)

Commit 4987e7615a (parent e1303f6160): Improve the cutlass gemm
@@ -5,11 +5,21 @@
 #include "mlx/dtype_utils.h"
 
 #include <cute/tensor.hpp>
+#include <cutlass/arch/arch.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm.h>
+#include <cutlass/layout/matrix.h>
+#include <cutlass/numeric_types.h>
+
+#include <iostream>
 
 namespace mlx::core::cu {
 
 namespace {
 
+using namespace cute;
+using bf16 = cute::bfloat16_t;
+
 template <typename Kernel>
 void configure_matmul(Kernel kernel, int smem_size) {
   static bool initialized = false;
@@ -17,308 +27,278 @@ void configure_matmul(Kernel kernel, int smem_size) {
     initialized = true;
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
   }
 }
 
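By default a CUDA kernel may use at most 48 KB of dynamic shared memory; configure_matmul opts the kernel into a larger limit once via cudaFuncSetAttribute (and this commit drops the separate carveout hint). A minimal standalone sketch of that opt-in; demo_kernel and the 96 KB request are illustrative values, not part of the commit:

// Minimal sketch (not from the commit): opting a kernel into more than 48 KB
// of dynamic shared memory, which is what configure_matmul does above.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float* out) {
  extern __shared__ float tile[]; // dynamic shared memory
  tile[threadIdx.x] = threadIdx.x;
  __syncthreads();
  out[threadIdx.x] = tile[threadIdx.x];
}

int main() {
  constexpr int smem_size = 96 * 1024; // above the default 48 KB limit
  cudaFuncSetAttribute(
      demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  float* out;
  cudaMalloc(&out, 256 * sizeof(float));
  demo_kernel<<<1, 256, smem_size>>>(out);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}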
-template <class ElementA, class ElementB, class SmemLayoutA, class SmemLayoutB>
-struct SharedStorage {
-  cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
-  cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
-};
+template <bool transpose, typename Tiler>
+constexpr int get_feature_size(Tiler smem) {
+  int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
+  return (feature_size >= 64) ? 64 : feature_size;
+}
 
+constexpr int constexpr_log2(int x) {
+  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
+}
+
+template <int feature_size, int itemsize, int copy_bits>
+constexpr int get_swizzle_bits() {
+  constexpr int swizzle_bits =
+      constexpr_log2(feature_size * itemsize / copy_bits);
+  return (swizzle_bits > 3) ? 3 : swizzle_bits;
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled =
+      make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
+
+template <int itemsize, bool transpose, int copy_bits, typename Tiler>
+constexpr auto make_result_smem_layout(Tiler smem) {
+  constexpr int feature_size = get_feature_size<transpose>(smem);
+  constexpr int swizzle_bits =
+      get_swizzle_bits<feature_size, itemsize, copy_bits>();
+
+  using F = Int<feature_size>;
+  using BaseLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
+      Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
+
+  auto swizzled = make_composed_layout(
+      Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
+
+  return tile_to_shape(swizzled, smem);
+}
 
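The new make_smem_layout derives its Swizzle parameter from the tile's contiguous extent, the element width, and the copy width. A plain C++ check of that arithmetic for bf16 tiles loaded with 128-bit vectors (values computed here for illustration, not taken from the commit):

// Plain C++ sketch: evaluates the swizzle-bit formula used by
// make_smem_layout for a bf16 tile (itemsize = 16 bits) and 128-bit copies.
#include <cstdio>

constexpr int constexpr_log2(int x) {
  return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}

constexpr int get_swizzle_bits(int feature_size, int itemsize, int copy_bits) {
  int bits = constexpr_log2(feature_size * itemsize / copy_bits);
  return (bits > 3) ? 3 : bits;
}

int main() {
  // A 128x64 bf16 tile is k-major, so the contiguous (feature) size is 64:
  // 64 * 16 / 128 = 8 vectors per row -> log2(8) = 3 -> Swizzle<3, 3, 3>.
  printf("feature 64: Swizzle<%d, 3, 3>\n", get_swizzle_bits(64, 16, 128));
  // A narrower 32-element feature dimension would give Swizzle<2, 3, 3>.
  printf("feature 32: Swizzle<%d, 3, 3>\n", get_swizzle_bits(32, 16, 128));
  return 0;
}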
-template <
-    class ProblemShape,
-    class CtaTiler,
-    class TA,
-    class AStride,
-    class ASmemLayout,
-    class TiledCopyA,
-    class S2RAtomA,
-    class TB,
-    class BStride,
-    class BSmemLayout,
-    class TiledCopyB,
-    class S2RAtomB,
-    class TC,
-    class CStride,
-    class CSmemLayout,
-    class TiledMma>
-__global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(
-    ProblemShape shape_MNK,
-    CtaTiler cta_tiler,
-    TA const* A,
-    AStride dA,
-    ASmemLayout sA_layout,
-    TiledCopyA copy_a,
-    S2RAtomA s2r_atom_a,
-    TB const* B,
-    BStride dB,
-    BSmemLayout sB_layout,
-    TiledCopyB copy_b,
-    S2RAtomB s2r_atom_b,
-    TC* C,
-    CStride dC,
-    CSmemLayout,
-    TiledMma mma) {
-  using namespace cute;
-
-  // Preconditions
-  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
-  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
-
-  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
-  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
-
-  static_assert(is_static<ASmemLayout>::value);
-  static_assert(is_static<BSmemLayout>::value);
-  static_assert(is_static<CSmemLayout>::value);
-
-  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
-  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
-  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
-  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
+template <
+    int num_threads,
+    int itemsize,
+    bool transpose,
+    int copy_bits,
+    typename Copier,
+    typename Tiler>
+constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
+  constexpr int num_elements = copy_bits / itemsize;
+  constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
+  constexpr int copies_per_feature = feature_size / num_elements;
+
+  using E = Int<num_elements>;
+  using C = Int<copies_per_feature>;
+  using R = Int<num_threads / copies_per_feature>;
+
+  using ThreadLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
+      Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
+  using ValueLayout = std::conditional_t<
+      transpose,
+      Layout<cute::Shape<E, _1>>,
+      Layout<cute::Shape<_1, E>>>;
+
+  return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
+}
+
+template <int rasterization_factor>
+__device__ inline int2 raster_tile(int x, int y) {
+  return {
+      x / rasterization_factor,
+      (x % rasterization_factor) + y * rasterization_factor};
+}
 
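raster_tile regroups the launch grid so that consecutive block indices stay on the same output tile row and sweep a band of GM tile columns before moving on, which improves L2 reuse of the A and B tiles. A host-side sketch of the same mapping; the grid sizes are made up, and the factor 8 matches the GM defined later in this diff:

// Plain C++ sketch (illustrative): how raster_tile remaps a
// (blockIdx.x, blockIdx.y) pair into an output tile coordinate.
#include <cstdio>

struct Tile { int m, n; };

constexpr int GM = 8; // rasterization factor, matches GM below

Tile raster_tile(int x, int y) {
  return {x / GM, (x % GM) + y * GM};
}

int main() {
  // Example only: a 16 x 2 launch grid, i.e. blocks_m = 2 tile rows and
  // two bands of GM tile columns.
  for (int y = 0; y < 2; ++y) {
    for (int x = 0; x < 16; ++x) {
      Tile t = raster_tile(x, y);
      printf("block (%2d,%d) -> tile (m=%d, n=%2d)\n", x, y, t.m, t.n);
    }
  }
  return 0;
}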
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 2>(shape_MNK), dA)); // dA strides for shape MK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<1, 2>(shape_MNK), dB)); // dB strides for shape NK
-  CUTE_STATIC_ASSERT_V(
-      congruent(select<0, 1>(shape_MNK), dC)); // dC strides for shape MN
-
-  //
-  // Full and Tiled Tensors
-  //
-
-  // Represent the full tensors
-  Tensor mA =
-      make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
-  Tensor mB =
-      make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
-  Tensor mC =
-      make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)
-
-  // Get the appropriate blocks for this thread block
-  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
-  Tensor gA = local_tile(
-      mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
-  Tensor gB = local_tile(
-      mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
-  Tensor gC =
-      local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
-
-  // Shared memory buffers
+template <
+    typename T,
+    typename SLayoutA,
+    typename SLayoutB,
+    typename SLayoutC,
+    typename CopyA,
+    typename CopyB,
+    typename CopyC,
+    typename MMA,
+    int rasterization_factor>
+__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
+    const T* __restrict__ A,
+    const T* __restrict__ B,
+    T* __restrict__ C,
+    SLayoutA SA,
+    SLayoutB SB,
+    SLayoutC SC,
+    CopyA copy_a,
+    CopyB copy_b,
+    CopyC copy_c,
+    MMA mma,
+    int M,
+    int N,
+    int K) {
+  constexpr auto BM = size<0>(SA);
+  constexpr auto BN = size<0>(SB);
+  constexpr auto BK = size<1>(SA);
+  constexpr auto PIPE = size<2>(SA);
+
+  const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
+  const int blocks_m = ceil_div(M, BM);
+  const int blocks_n = ceil_div(N, BN);
+
+  // Exit early if the tile is OOB
+  if (tile.x >= blocks_m || tile.y >= blocks_n) {
+    return;
+  }
+
+  // Make the full tensors
+  Tensor full_A =
+      make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
+  Tensor full_B =
+      make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
+  Tensor full_C =
+      make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
+
+  // Partition the tensors into tiles and select the ones for this threadblock
+  Tensor local_A =
+      local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
+  Tensor local_B =
+      local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
+  Tensor local_C =
+      local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
+
+  // Make shared memory tensors
   extern __shared__ char shared_memory[];
-  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
-  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
-  Tensor sA = make_tensor(
-      make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(
-      make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
+  T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
+  T* shared_B_ptr =
+      reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
+  T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
+  Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
+  Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
+  Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
 
-  //
-  // Partition the copying of A and B tiles across the threads
-  //
-
-  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
-  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
-  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
-
-  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
-  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
-  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
-
-  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
-  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
-  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
-  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
-
-  //
-  // PREFETCH
-  //
-
-  auto K_PIPE_MAX = size<3>(tAsA);
-
-  // Total count of tiles
-  int k_tile_count = size<3>(tAgA);
-  // Current tile index in gmem to read from
+  // Get the copies that correspond to this thread
+  auto thread_copy_a = copy_a.get_slice(threadIdx.x);
+  Tensor local_A_src = thread_copy_a.partition_S(local_A);
+  Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
+  auto thread_copy_b = copy_b.get_slice(threadIdx.x);
+  Tensor local_B_src = thread_copy_a.partition_S(local_B);
+  Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
+  auto thread_copy_c = copy_c.get_slice(threadIdx.x);
+  Tensor local_C_src = thread_copy_c.partition_S(shared_C);
+  Tensor local_C_dst = thread_copy_c.partition_D(local_C);
+
+  // Start fetches
+  int k_tile_count = size<2>(local_A);
   int k_tile_next = 0;
 
-  // Start async loads for all pipes but the last
   CUTE_UNROLL
-  for (int k_pipe = 0; k_pipe < K_PIPE_MAX - 1; ++k_pipe) {
-    copy(copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
-    copy(copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
+  for (int k = 0; k < PIPE - 1; k++) {
+    copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
+    copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
     cp_async_fence();
-    --k_tile_count;
-    if (k_tile_count > 0) {
-      ++k_tile_next;
-    }
+    k_tile_count--;
+    k_tile_next += (k_tile_count > 0);
   }
 
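The prologue above fills PIPE - 1 shared-memory stages with cp.async copies before any math; the main loop further down then retires one stage and refills another per k-tile. A plain C++ simulation of that index bookkeeping, purely illustrative (the tile count is made up, PIPE = 3 matches the configuration later in this diff):

// Plain C++ sketch: the stage bookkeeping of the pipelined loop, mirroring
// smem_read / smem_write / k_tile_next from the kernel above.
#include <cstdio>

int main() {
  constexpr int PIPE = 3;
  int k_tile_count = 8; // e.g. K = 512 with BK = 64 -> 8 k-tiles (example)
  int k_tile_next = 0;

  // Prologue: fill PIPE - 1 stages.
  for (int k = 0; k < PIPE - 1; k++) {
    printf("prefetch k-tile %d into stage %d\n", k_tile_next, k);
    k_tile_count--;
    k_tile_next += (k_tile_count > 0);
  }

  int smem_read = 0;
  int smem_write = PIPE - 1;

  // Main loop: consume one stage per iteration while refilling another.
  while (k_tile_count > -(PIPE - 1)) {
    printf("compute on stage %d", smem_read);
    if (k_tile_count > 0) {
      printf(", prefetch k-tile %d into stage %d", k_tile_next, smem_write);
    }
    printf("\n");
    k_tile_count--;
    k_tile_next += (k_tile_count > 0);
    smem_write = smem_read;
    smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
  }
  return 0;
}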
-  //
-  // Define A/B partitioning and C accumulators
-  //
-
-  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
-  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
-
-  // Allocate registers for pipelining
-  Tensor tCrA = thr_mma.partition_fragment_A(sA(_, _, 0)); // (MMA,MMA_M,MMA_K)
-  Tensor tCrB = thr_mma.partition_fragment_B(sB(_, _, 0)); // (MMA,MMA_N,MMA_K)
-  // Allocate the accumulators -- same size as the projected data
-  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
-
-  CUTE_STATIC_ASSERT_V(
-      (shape(tCrC) == take<0, 3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
-  CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
-  CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
-
-  // Clear the accumulators
-  clear(tCrC);
-
-  //
-  // Copy Atom retiling
-  //
-
-  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
-  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
-  Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
-  Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
-
-  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
-  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
-  Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
-  Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
-
-#if 0
-  if(thread0()) {
-    print(" mA : "); print( mA); print("\n");
-    print(" gA : "); print( gA); print("\n");
-    print(" sA : "); print( sA); print("\n");
-    print("tAgA : "); print(tAgA); print("\n");
-    print("tAsA : "); print(tAsA); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print(" mB : "); print( mB); print("\n");
-    print(" gB : "); print( gB); print("\n");
-    print(" sB : "); print( sB); print("\n");
-    print("tBgB : "); print(tBgB); print("\n");
-    print("tBsB : "); print(tBsB); print("\n");
-  }
-#endif
-
-#if 0
-  if(thread0()) {
-    print(" mC : "); print( mC); print("\n");
-    print(" gC : "); print( gC); print("\n");
-    print("tCgC : "); print(tCgC); print("\n");
-    print("tCrA : "); print(tCrA); print("\n");
-    print("tCrB : "); print(tCrB); print("\n");
-    print("tCrC : "); print(tCrC); print("\n");
-
-    print("tXsA : "); print(tXsA); print("\n");
-    print("tXrA : "); print(tXrA); print("\n");
-    print("tXsB : "); print(tXsB); print("\n");
-    print("tXrB : "); print(tXrB); print("\n");
-  }
-#endif
-
-#if 1
-
-  // Current pipe index in smem to read from
-  int smem_pipe_read = 0;
-  // Current pipe index in smem to write to
-  int smem_pipe_write = K_PIPE_MAX - 1;
-
-  // Pipe slice
-  Tensor tXsA_p = tXsA(_, _, _, smem_pipe_read);
-  Tensor tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-  // Size of the register pipeline
-  auto K_BLOCK_MAX = size<2>(tCrA);
-  CUTE_STATIC_ASSERT_V(K_BLOCK_MAX == size<2>(tXrA));
-
-  // PREFETCH register pipeline
-  if (K_BLOCK_MAX > 1) {
-    // Wait until our first prefetched tile is loaded in
-    cp_async_wait<K_PIPE_MAX - 2>();
+  // Get the MMA that corresponds to this thread and allocate registers
+  auto thread_mma = mma.get_slice(threadIdx.x);
+  Tensor mma_shared_A = thread_mma.partition_A(shared_A);
+  Tensor mma_shared_B = thread_mma.partition_B(shared_B);
+  Tensor mma_shared_C = thread_mma.partition_C(shared_C);
+  Tensor mma_global_C = thread_mma.partition_C(local_C);
+  Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
+  Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
+  Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
+  clear(mma_frag_C);
+
+  // Make shared to register copies
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
+  Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
+  auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
+  auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
+  auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
+  auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
+  Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
+  Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
+  Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
+  Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
+
+  constexpr auto RPIPE = size<2>(mma_shared_A);
+  int smem_read = 0;
+  int smem_write = PIPE - 1;
+  Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
+  Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
+
+  // Start the register pipeline
+  if constexpr (RPIPE > 1) {
+    cp_async_wait<PIPE - 2>();
     __syncthreads();
-
-    // Prefetch the first rmem from the first k-tile
-    copy(s2r_atom_a, tXsA_p(_, _, Int<0>{}), tXrA(_, _, Int<0>{}));
-    copy(s2r_atom_b, tXsB_p(_, _, Int<0>{}), tXrB(_, _, Int<0>{}));
+    copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
+    copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
   }
 
-  //
-  // PIPELINED MAIN LOOP
-  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's
-  // cp.async instructions
-  // and explicit pipelines in shared memory.
-  // Data is read from global(k_tile_next) to shared(smem_pipe_write).
-  // Data is read from shared(smem_pipe_read) to registers(k_block_next).
-  // Data is computed on registers(b_block).
-  //
-  // This allows all copies and compute to overlap:
-  // Copy from gmem->smem can overlap with copies from smem->rmem and
-  // compute on rmem. Copy from smem->rmem can overlap with compute on rmem.
-  //
-
   CUTE_NO_UNROLL
-  while (k_tile_count > -(K_PIPE_MAX - 1)) {
+  while (k_tile_count > -(PIPE - 1)) {
     CUTE_UNROLL
-    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
-      if (k_block == K_BLOCK_MAX - 1) {
-        // Slice the smem_pipe_read smem
-        tXsA_p = tXsA(_, _, _, smem_pipe_read);
-        tXsB_p = tXsB(_, _, _, smem_pipe_read);
-
-        // Commit the smem for smem_pipe_read
-        cp_async_wait<K_PIPE_MAX - 2>();
+    for (int k_block = 0; k_block < RPIPE; k_block++) {
+      if (k_block == RPIPE - 1) {
+        mma_A_src_p = mma_A_src(_, _, _, smem_read);
+        mma_B_src_p = mma_B_src(_, _, _, smem_read);
+        cp_async_wait<PIPE - 2>();
         __syncthreads();
       }
 
-      // Load A, B shmem->regs for k_block+1
-      auto k_block_next = (k_block + 1) % K_BLOCK_MAX; // static
-      copy(s2r_atom_a, tXsA_p(_, _, k_block_next), tXrA(_, _, k_block_next));
-      copy(s2r_atom_b, tXsB_p(_, _, k_block_next), tXrB(_, _, k_block_next));
-
-      // Copy gmem to smem before computing gemm on each k-pipe
+      // Load the next register tile
+      auto k_block_next = (k_block + 1) % RPIPE;
+      copy(
+          s2r_copy_a,
+          mma_A_src_p(_, _, k_block_next),
+          mma_A_dst(_, _, k_block_next));
+      copy(
+          s2r_copy_b,
+          mma_B_src_p(_, _, k_block_next),
+          mma_B_dst(_, _, k_block_next));
+
       if (k_block == 0) {
         copy(
-            copy_a, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
+            copy_a,
+            local_A_src(_, _, _, k_tile_next),
+            local_A_dst(_, _, _, smem_write));
         copy(
-            copy_b, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
+            copy_b,
+            local_B_src(_, _, _, k_tile_next),
+            local_B_dst(_, _, _, smem_write));
         cp_async_fence();
 
-        // Advance the gmem tile
-        --k_tile_count;
-        if (k_tile_count > 0) {
-          ++k_tile_next;
-        }
-
-        // Advance the smem pipe
-        smem_pipe_write = smem_pipe_read;
-        smem_pipe_read =
-            (smem_pipe_read == K_PIPE_MAX - 1) ? 0 : smem_pipe_read + 1;
+        k_tile_count--;
+        k_tile_next += (k_tile_count > 0);
+        smem_write = smem_read;
+        smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
       }
-      // Thread-level register gemm for k_block
-      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+      gemm(
+          mma,
+          mma_frag_A(_, _, k_block),
+          mma_frag_B(_, _, k_block),
+          mma_frag_C);
     }
   }
 
-#endif
-
-  //
-  // Epilogue
-  //
-
-  copy(tCrC, tCgC);
+  copy(mma_frag_C, mma_shared_C);
+  __syncthreads();
+  copy(copy_c, local_C_src, local_C_dst);
+
+  // if (threadIdx.x == 0) {
+  // print("fC: "); print(mma_frag_C); print("\n");
+  // print("sC: "); print(mma_shared_C); print("\n");
+  // print("dC: "); print(local_C_dst); print("\n");
+  //
+  // print(s2r_atom_a); print("\n");
+  // }
 }
 
 } // namespace
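The new epilogue no longer writes each thread's MMA fragment straight to global memory (the old copy(tCrC, tCgC)); it first copies the accumulators into shared memory and then streams the whole tile out with the 128-bit tiled_copy_c, so the global stores are contiguous. A standalone CUDA sketch of that staging idea; it does not use CuTe, and the tile size and fragment layout here are made-up stand-ins:

// Illustrative only: stage scattered per-thread accumulators in shared
// memory so the final write to global memory is contiguous and vectorized.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int TILE = 32; // 32 x 32 output tile, 256 threads, 4 values each

__global__ void epilogue_demo(float* __restrict__ C, int ldc) {
  __shared__ float tile[TILE][TILE];

  // Each thread owns 4 accumulators scattered over the tile (a stand-in for
  // an MMA fragment layout): rows r, r+8, r+16, r+24 of one column.
  int col = threadIdx.x % TILE;
  int row = threadIdx.x / TILE; // 0..7
  for (int i = 0; i < 4; i++) {
    float acc = static_cast<float>((row + 8 * i) * TILE + col);
    tile[row + 8 * i][col] = acc; // registers -> shared
  }
  __syncthreads();

  // Write back contiguously: each thread stores one float4.
  int out_row = threadIdx.x / (TILE / 4);       // 0..31
  int out_col = (threadIdx.x % (TILE / 4)) * 4; // 0, 4, ..., 28
  float4 v = *reinterpret_cast<float4*>(&tile[out_row][out_col]);
  *reinterpret_cast<float4*>(&C[out_row * ldc + out_col]) = v;
}

int main() {
  float* C;
  cudaMalloc(&C, TILE * TILE * sizeof(float));
  epilogue_demo<<<1, 256>>>(C, TILE);
  cudaDeviceSynchronize();
  printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(C);
  return 0;
}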
@@ -339,103 +319,74 @@ void cutlass_gemm(
   if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
     using namespace cute;
 
-    // Define shapes (dynamic)
-    auto prob_shape = make_shape(M, N, K);
-
-    // Define TN strides (mixed)
-    auto dA = make_stride(K, Int<1>{});
-    auto dB = make_stride(K, Int<1>{});
-    auto dC = make_stride(N, Int<1>{});
-
-    // Define CTA tile sizes (static)
-    auto bM = Int<128>{};
-    auto bN = Int<128>{};
-    auto bK = Int<64>{};
-    auto cta_tiler = make_shape(bM, bN, bK);
-    auto bP = Int<3>{};
-
-    // Define the smem layouts (static)
-    // Swizzles for LDSM and 128b k-major loads
-    auto swizzle_atom = composition(
-        Swizzle<3, 3, 3>{},
-        Layout<
-            cute::Shape<_8, cute::Shape<_8, _8>>,
-            cute::Stride<_8, cute::Stride<_1, _64>>>{});
-
-    auto sA = tile_to_shape(swizzle_atom, make_shape(bM, bK, bP));
-    auto sB = tile_to_shape(swizzle_atom, make_shape(bN, bK, bP));
-    auto sC = make_layout(make_shape(bM, bN));
-
-    // Define the thread layouts (static)
+    // Tile definitions
+    auto BM = Int<128>{};
+    auto BN = Int<128>{};
+    auto BK = Int<64>{};
+    auto BP = Int<3>{};
+    auto GM = Int<8>{};
+
+    // Thread definitions
+    using TM = Int<2>;
+    using TN = Int<2>;
+    using TK = Int<1>;
+    constexpr int num_threads = TM::value * TN::value * 32;
+
+    auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
+    auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
+    auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
+
+    constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
+
+    auto async_copy_op =
+        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
+    auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BM, BK));
+    auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
+        async_copy_op, make_shape(BN, BK));
+
+    auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
+    auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
+        sync_copy_op, make_shape(BM, BN));
 
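The previous version spelled the gmem-to-smem copy layouts out by hand (a 16x8 thread layout with 1x8 values per thread); the new make_tiled_copy helper derives the same numbers from the thread count, element width, and copy width. A plain C++ check of that arithmetic for the A/B copies above (128 threads, bf16, 128-bit vectors, BK = 64 contiguous elements); the helper names here are stand-ins:

// Plain C++ check (illustrative, not from the commit).
#include <cstdio>

int main() {
  constexpr int num_threads = 128;  // TM * TN * 32 = 2 * 2 * 32
  constexpr int itemsize = 16;      // bf16 bits
  constexpr int copy_bits = 128;    // one cp.async / vectorized copy
  constexpr int feature_size = 64;  // contiguous k extent of the A/B tiles

  constexpr int num_elements = copy_bits / itemsize;              // 8 values
  constexpr int copies_per_feature = feature_size / num_elements; // 8 copies
  constexpr int rows_per_pass = num_threads / copies_per_feature; // 16 rows

  static_assert(num_elements == 8 && copies_per_feature == 8, "");
  printf("ThreadLayout: %d x %d, ValueLayout: 1 x %d\n",
         rows_per_pass, copies_per_feature, num_elements);
  return 0;
}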
-    TiledCopy copyA = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 k-major
-    TiledCopy copyB = make_tiled_copy(
-        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::bfloat16_t>{},
-        Layout<cute::Shape<_16, _8>, cute::Stride<_8, _1>>{}, // Thr layout
-                                                              // 16x8 k-major
-        Layout<cute::Shape<_1, _8>>{}); // Val layout 1x8 n-major
-
-    TiledMMA mmaC = make_tiled_mma(
-        SM80_16x8x16_F32BF16BF16F32_TN{},
-        Layout<cute::Shape<_2, _2>>{}, // 2x2x1 MMA Atoms
-        Tile<_32, _32, _16>{}); // 32x32x16 Tiled MMA for LDSM
-
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_A;
-    Copy_Atom<SM75_U32x4_LDSM_N, cute::bfloat16_t> s2r_atom_B;
-
-    int smem_size = int(sizeof(SharedStorage<
-        cute::bfloat16_t,
-        cute::bfloat16_t,
-        decltype(sA),
-        decltype(sB)>));
-    dim3 dimBlock(size(mmaC));
-    dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN)));
-
-    auto kernel = gemm_device<
-        decltype(prob_shape),
-        decltype(cta_tiler),
-        cute::bfloat16_t,
-        decltype(dA),
-        decltype(sA),
-        decltype(copyA),
-        decltype(s2r_atom_A),
-        cute::bfloat16_t,
-        decltype(dB),
-        decltype(sB),
-        decltype(copyB),
-        decltype(s2r_atom_B),
-        cute::bfloat16_t,
-        decltype(dC),
-        decltype(sC),
-        decltype(mmaC)>;
+    auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
+    auto tiled_mma = make_tiled_mma(
+        mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
+
+    auto kernel = matmul_kernel<
+        bf16,
+        decltype(SA),
+        decltype(SB),
+        decltype(SC),
+        decltype(tiled_copy_a),
+        decltype(tiled_copy_b),
+        decltype(tiled_copy_c),
+        decltype(tiled_mma),
+        GM.value>;
     configure_matmul(kernel, smem_size);
 
+    dim3 block(size(tiled_mma));
+    dim3 grid(
+        size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
+
     enc.add_kernel_node(
         kernel,
-        dimGrid,
-        dimBlock,
+        grid,
+        block,
         smem_size,
-        prob_shape,
-        cta_tiler,
-        a.data<cute::bfloat16_t>(),
-        dA,
-        sA,
-        copyA,
-        s2r_atom_A,
-        b.data<cute::bfloat16_t>(),
-        dB,
-        sB,
-        copyB,
-        s2r_atom_B,
-        out.data<cute::bfloat16_t>(),
-        dC,
-        sC,
-        mmaC);
+        a.data<bf16>(),
+        b.data<bf16>(),
+        out.data<bf16>(),
+        SA,
+        SB,
+        SC,
+        tiled_copy_a,
+        tiled_copy_b,
+        tiled_copy_c,
+        tiled_mma,
+        M,
+        N,
+        K);
   } else {
     throw std::runtime_error("Only bfloat16 supported");
   }
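For reference, the launch configuration this host code produces, worked through for an example problem size (M = N = K = 4096 is an assumption; the tile constants mirror the diff): 128 threads per block, a rasterized grid, and 96 KB of dynamic shared memory for the three-stage A/B buffers, which is why configure_matmul has to raise the shared-memory limit. A plain C++ sketch:

// Plain C++ sketch (illustrative, not from the commit).
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int M = 4096, N = 4096, K = 4096; // example problem only
  constexpr int BM = 128, BN = 128, BK = 64, PIPE = 3, GM = 8;
  constexpr int num_threads = 2 * 2 * 32; // TM * TN warps of 32

  // A and B staging buffers, PIPE stages each, bf16 = 2 bytes.
  constexpr int smem_size = (BM * BK * PIPE + BN * BK * PIPE) * 2;

  int grid_x = ceil_div(M, BM) * GM;          // tile rows spread over GM lanes
  int grid_y = ceil_div(ceil_div(N, BN), GM); // bands of GM tile columns

  printf("block: %d threads\n", num_threads);
  printf("grid:  (%d, %d) for %d x %d output tiles\n",
         grid_x, grid_y, ceil_div(M, BM), ceil_div(N, BN));
  printf("dynamic smem: %d bytes (%d KB)\n", smem_size, smem_size / 1024);
  printf("k-tiles per block: %d, pipeline depth: %d\n", ceil_div(K, BK), PIPE);
  return 0;
}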