Improve gemm

This commit is contained in:
Angelos Katharopoulos 2025-08-07 16:13:18 -07:00
parent 97afe40b7b
commit 6fce01593a
2 changed files with 34 additions and 39 deletions

View File

@ -26,15 +26,13 @@ void simple_gemm(
auto kernel = ab_t_aligned<DataType, BM, BN, BK>; auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
cudaFuncSetAttribute( cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304); kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
cudaFuncSetAttribute(
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
dim3 grid(N / BN, M / BM); dim3 grid(N / BN, M / BM);
enc.add_kernel_node( enc.add_kernel_node(
kernel, kernel,
grid, grid,
4 * WARP_SIZE, 8 * WARP_SIZE,
2 * sizeof(DataType) * (BM * BK + BN * BK), 2 * sizeof(DataType) * (BM * BK + BN * BK),
a.data<DataType>(), a.data<DataType>(),
b.data<DataType>(), b.data<DataType>(),

View File

@ -4,6 +4,30 @@
namespace mlx::core::cu { namespace mlx::core::cu {
template <typename T, int BM, int BN, int BK, int WM, int WN>
__device__ inline void gemm_ab_t(
RegisterTile<float, BM / WM, BN / WN>& C,
SharedTile<T, BM, BK>& As,
SharedTile<T, BM, BK>& Bs,
int lane_row_a,
int lane_row_b,
int lane_col) {
RegisterTile<T, BM / WM, 16> A[2];
RegisterTile<T, BN / WN, 16> B[2];
A[0].load(As, As.base_addr(), lane_row_a, lane_col);
B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
MLX_UNROLL
for (int k = 1; k < BK / 16; k++) {
A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
}
mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
}
/** /**
* An example gemm written with the utils. * An example gemm written with the utils.
* *
@ -11,7 +35,7 @@ namespace mlx::core::cu {
*/ */
template <typename T, int BM, int BN, int BK> template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int WARPS_M = 2; constexpr int WARPS_M = 4;
constexpr int WARPS_N = 2; constexpr int WARPS_N = 2;
constexpr int NUM_WARPS = WARPS_M * WARPS_N; constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M; constexpr int WARP_STEP_M = BM / WARPS_M;
@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
const int wn = warpid % WARPS_N; const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M; const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N; const int offset_n = wn * WARP_STEP_N;
const int lane_row_a = offset_m + (laneid & 15);
const int lane_row_b = offset_n + (laneid & 15);
const int lane_col = (laneid >> 4) << 3;
// Allocate shared memory // Allocate shared memory
extern __shared__ char shmem[]; extern __shared__ char shmem[];
@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
// Allocate registers for the MMA // Allocate registers for the MMA
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C; RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
// Move the global pointers to the tile // Move the global pointers to the tile
a += blockIdx.y * BM * K; a += blockIdx.y * BM * K;
@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K); load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
cp_async_commit(); cp_async_commit();
cp_async_wait<1>(); cp_async_wait<1>();
__syncthreads();
MLX_UNROLL gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
for (int k = 0; k < BK / 16; k++) { C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
tic ^= 1; tic ^= 1;
} }
// Empty the pipeline // Empty the pipeline
cp_async_wait_all(); cp_async_wait_all();
__syncthreads(); gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
MLX_UNROLL C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
C.store_global(y, N, offset_m, offset_n); C.store_global(y, N, offset_m, offset_n);
} }