Improve gemm

Author: Angelos Katharopoulos, 2025-08-07 16:13:18 -07:00
parent 97afe40b7b
commit 6fce01593a
2 changed files with 34 additions and 39 deletions


@@ -26,15 +26,13 @@ void simple_gemm(
   auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
   cudaFuncSetAttribute(
-      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
-  cudaFuncSetAttribute(
-      kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
   dim3 grid(N / BN, M / BM);
   enc.add_kernel_node(
       kernel,
       grid,
-      4 * WARP_SIZE,
+      8 * WARP_SIZE,
       2 * sizeof(DataType) * (BM * BK + BN * BK),
       a.data<DataType>(),
       b.data<DataType>(),
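
A note on the launch changes above: the block now launches 8 * WARP_SIZE threads, matching the WARPS_M * WARPS_N = 4 * 2 warps used by the kernel in the second file, and the dynamic shared-memory opt-in drops from 96 KB to 64 KB. A minimal sketch of the arithmetic, assuming hypothetical tile sizes (the real BM/BN/BK come from the simple_gemm instantiation, which this hunk does not show):

#include <cstddef>
#include <cstdint>

constexpr int WARP_SIZE = 32;
constexpr int WARPS_M = 4, WARPS_N = 2;     // matches the kernel change below
constexpr int BM = 128, BN = 128, BK = 32;  // assumed tile sizes, illustration only

// 8 * WARP_SIZE = 256 threads per block.
constexpr int block_threads = WARPS_M * WARPS_N * WARP_SIZE;

// Double-buffered A (BM x BK) and B (BN x BK) tiles; for 16-bit data this is
// 2 * 2 * (128*32 + 128*32) = 32768 bytes.
constexpr size_t smem_bytes = 2 * sizeof(uint16_t) * (BM * BK + BN * BK);
static_assert(block_threads == 8 * WARP_SIZE, "");
static_assert(smem_bytes <= 65536, "fits the requested dynamic smem limit");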


@@ -4,6 +4,30 @@
 namespace mlx::core::cu {
 
+template <typename T, int BM, int BN, int BK, int WM, int WN>
+__device__ inline void gemm_ab_t(
+    RegisterTile<float, BM / WM, BN / WN>& C,
+    SharedTile<T, BM, BK>& As,
+    SharedTile<T, BN, BK>& Bs,
+    int lane_row_a,
+    int lane_row_b,
+    int lane_col) {
+  RegisterTile<T, BM / WM, 16> A[2];
+  RegisterTile<T, BN / WN, 16> B[2];
+
+  A[0].load(As, As.base_addr(), lane_row_a, lane_col);
+  B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
+
+  MLX_UNROLL
+  for (int k = 1; k < BK / 16; k++) {
+    A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
+    B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
+
+    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
+  }
+  mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
+}
+
 /**
  * An example gemm written with the utils.
  *
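
The new gemm_ab_t double-buffers the register fragments: the loads for step k are issued while the MMA consumes step k - 1, then one trailing MMA drains the pipeline. A minimal host-side sketch of the same indexing pattern (load_a, load_b, and K_STEPS are hypothetical stand-ins for the RegisterTile loads and BK / 16, not commit code):

#include <cstdio>

static float load_a(int k) { return 1.0f + k; }
static float load_b(int k) { return 2.0f; }

int main() {
  constexpr int K_STEPS = 4;  // plays the role of BK / 16
  float a[2], b[2];
  float acc = 0.0f;
  a[0] = load_a(0);  // prologue: prefetch step 0
  b[0] = load_b(0);
  for (int k = 1; k < K_STEPS; k++) {
    a[k & 1] = load_a(k);  // issue loads for step k...
    b[k & 1] = load_b(k);
    acc += a[(k - 1) & 1] * b[(k - 1) & 1];  // ...while consuming step k - 1
  }
  // Drain: the last prefetched step still needs its multiply.
  acc += a[(K_STEPS - 1) & 1] * b[(K_STEPS - 1) & 1];
  printf("acc = %f\n", acc);  // (1 + 2 + 3 + 4) * 2 = 20
  return 0;
}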
@@ -11,7 +35,7 @@ namespace mlx::core::cu {
  */
 template <typename T, int BM, int BN, int BK>
 __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
-  constexpr int WARPS_M = 2;
+  constexpr int WARPS_M = 4;
   constexpr int WARPS_N = 2;
   constexpr int NUM_WARPS = WARPS_M * WARPS_N;
   constexpr int WARP_STEP_M = BM / WARPS_M;
@@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   const int wn = warpid % WARPS_N;
   const int offset_m = wm * WARP_STEP_M;
   const int offset_n = wn * WARP_STEP_N;
+  const int lane_row_a = offset_m + (laneid & 15);
+  const int lane_row_b = offset_n + (laneid & 15);
+  const int lane_col = (laneid >> 4) << 3;
 
   // Allocate shared memory
   extern __shared__ char shmem[];
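
The hoisted lane indices are exactly the values the old inner loops recomputed every iteration: laneid & 15 equals laneid % 16, and (laneid >> 4) << 3 equals laneid / 16 * 8. A quick host-side check of the mapping (sketch, not commit code):

#include <cstdio>

int main() {
  // Each warp's 32 lanes cover 16 rows in two 8-column halves.
  for (int laneid = 0; laneid < 32; laneid++) {
    int row = laneid & 15;         // same as laneid % 16
    int col = (laneid >> 4) << 3;  // same as laneid / 16 * 8, i.e. 0 or 8
    printf("lane %2d -> row %2d, col %d\n", laneid, row, col);
  }
  return 0;
}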
@@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   // Allocate registers for the MMA
   RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
-  RegisterTile<T, BM / WARPS_M, 16> A;
-  RegisterTile<T, BN / WARPS_N, 16> B;
 
   // Move the global pointers to the tile
   a += blockIdx.y * BM * K;
@@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
     load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
     cp_async_commit();
     cp_async_wait<1>();
     __syncthreads();
 
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          as[tic],
-          as[tic].base_addr(),
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          bs[tic],
-          bs[tic].base_addr(),
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      mma_t(C, A, B);
-    }
+    gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
+        C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
 
     tic ^= 1;
   }
 
   // Empty the pipeline
   cp_async_wait_all();
   __syncthreads();
 
-  MLX_UNROLL
-  for (int k = 0; k < BK / 16; k++) {
-    A.load(
-        as[tic],
-        as[tic].base_addr(),
-        offset_m + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-    B.load(
-        bs[tic],
-        bs[tic].base_addr(),
-        offset_n + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-    mma_t(C, A, B);
-  }
+  gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
+      C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
 
   C.store_global(y, N, offset_m, offset_n);
 }
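
For reference, the restructured main loop keeps one cp.async stage in flight: while gemm_ab_t consumes the tile in buffer tic, the copy for the next k_block lands in tic ^ 1, and cp_async_wait<1> is what lets the two overlap. A host-side sketch of that schedule, with hypothetical helpers standing in for the real ones (the prologue load sits outside the shown hunk and is assumed):

#include <cstdio>

// start_copy models the load_async / cp_async_commit pair; math_on models
// gemm_ab_t on a buffer whose copy has completed. Sketch only, not commit code.
static void start_copy(int buf, int k_block) {
  printf("prefetch k_block %3d into buffer %d\n", k_block, buf);
}
static void math_on(int buf) {
  printf("compute on buffer %d\n", buf);
}

int main() {
  const int K = 128, BK = 32;
  int tic = 0;
  start_copy(tic, 0);  // assumed prologue: first tile in flight before the loop
  for (int k_block = BK; k_block < K; k_block += BK) {
    start_copy(tic ^ 1, k_block);  // prefetch the next tile into the spare buffer
    // cp_async_wait<1>() analogue: wait until at most one copy is pending,
    // which guarantees the tile in buffer `tic` has landed.
    math_on(tic);
    tic ^= 1;
  }
  // cp_async_wait_all() analogue, then the final tile with no prefetch behind it.
  math_on(tic);
  return 0;
}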