mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
More pipelining for the sm_80 gemm
This commit is contained in:
parent
6fce01593a
commit
05583bcd10
@ -22,7 +22,7 @@ void simple_gemm(
|
|||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
constexpr int BM = 128;
|
constexpr int BM = 128;
|
||||||
constexpr int BN = 128;
|
constexpr int BN = 128;
|
||||||
constexpr int BK = 64;
|
constexpr int BK = 32;
|
||||||
|
|
||||||
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
||||||
cudaFuncSetAttribute(
|
cudaFuncSetAttribute(
|
||||||
@ -33,7 +33,7 @@ void simple_gemm(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
8 * WARP_SIZE,
|
8 * WARP_SIZE,
|
||||||
2 * sizeof(DataType) * (BM * BK + BN * BK),
|
4 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||||
a.data<DataType>(),
|
a.data<DataType>(),
|
||||||
b.data<DataType>(),
|
b.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
|
@ -40,6 +40,7 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|||||||
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||||
constexpr int WARP_STEP_M = BM / WARPS_M;
|
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||||
constexpr int WARP_STEP_N = BN / WARPS_N;
|
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||||
|
constexpr int PIPE = 4;
|
||||||
|
|
||||||
// Precompute some offsets for each thread
|
// Precompute some offsets for each thread
|
||||||
const int warpid = threadIdx.x / 32;
|
const int warpid = threadIdx.x / 32;
|
||||||
@ -54,43 +55,49 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|||||||
|
|
||||||
// Allocate shared memory
|
// Allocate shared memory
|
||||||
extern __shared__ char shmem[];
|
extern __shared__ char shmem[];
|
||||||
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
|
SharedTile<T, BM, BK>(&as)[PIPE] =
|
||||||
SharedTile<T, BN, BK>(&bs)[2] =
|
*(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
|
||||||
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
|
SharedTile<T, BN, BK>(&bs)[PIPE] =
|
||||||
|
*(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
|
||||||
// Allocate registers for the MMA
|
|
||||||
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
|
||||||
|
|
||||||
// Move the global pointers to the tile
|
// Move the global pointers to the tile
|
||||||
a += blockIdx.y * BM * K;
|
a += blockIdx.y * BM * K;
|
||||||
b += blockIdx.x * BN * K;
|
b += blockIdx.x * BN * K;
|
||||||
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||||
|
|
||||||
// Zero the accumulators
|
|
||||||
C.fill(0);
|
|
||||||
|
|
||||||
// Start the SM pipeline
|
// Start the SM pipeline
|
||||||
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
|
MLX_UNROLL
|
||||||
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
|
for (int i = 0; i < PIPE - 1; i++) {
|
||||||
cp_async_commit();
|
load_async<NUM_WARPS>(as[i], as[i].base_addr(), a + i * BK, K);
|
||||||
|
load_async<NUM_WARPS>(bs[i], bs[i].base_addr(), b + i * BK, K);
|
||||||
int tic = 0;
|
|
||||||
for (int k_block = BK; k_block < K; k_block += BK) {
|
|
||||||
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
|
|
||||||
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
|
||||||
cp_async_commit();
|
cp_async_commit();
|
||||||
cp_async_wait<1>();
|
|
||||||
|
|
||||||
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
|
||||||
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
|
||||||
|
|
||||||
tic ^= 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty the pipeline
|
// Allocate and zero the MMA accumulator
|
||||||
cp_async_wait_all();
|
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||||
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
C.fill(0);
|
||||||
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
|
||||||
|
// Matmul loop
|
||||||
|
int num_blocks = K / BK;
|
||||||
|
int k_block = (PIPE - 1) * BK;
|
||||||
|
int sread = 0;
|
||||||
|
int swrite = PIPE - 1;
|
||||||
|
for (int i = 0; i < num_blocks; i++) {
|
||||||
|
cp_async_wait<PIPE - 2>();
|
||||||
|
|
||||||
|
if (k_block < K) {
|
||||||
|
load_async<NUM_WARPS>(as[swrite], as[swrite].base_addr(), a + k_block, K);
|
||||||
|
load_async<NUM_WARPS>(bs[swrite], bs[swrite].base_addr(), b + k_block, K);
|
||||||
|
}
|
||||||
|
|
||||||
|
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
||||||
|
C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col);
|
||||||
|
|
||||||
|
cp_async_commit();
|
||||||
|
|
||||||
|
swrite = sread;
|
||||||
|
sread = (sread + 1) % PIPE;
|
||||||
|
}
|
||||||
|
|
||||||
C.store_global(y, N, offset_m, offset_n);
|
C.store_global(y, N, offset_m, offset_n);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user