diff --git a/mlx/backend/cuda/gemms/simple_gemm.cu b/mlx/backend/cuda/gemms/simple_gemm.cu
index 51dc9569f..7598b3795 100644
--- a/mlx/backend/cuda/gemms/simple_gemm.cu
+++ b/mlx/backend/cuda/gemms/simple_gemm.cu
@@ -22,7 +22,7 @@ void simple_gemm(
   using DataType = cuda_type_t;
   constexpr int BM = 128;
   constexpr int BN = 128;
-  constexpr int BK = 64;
+  constexpr int BK = 32;
 
   auto kernel = ab_t_aligned;
   cudaFuncSetAttribute(
@@ -33,7 +33,7 @@ void simple_gemm(
       kernel,
       grid,
       8 * WARP_SIZE,
-      2 * sizeof(DataType) * (BM * BK + BN * BK),
+      4 * sizeof(DataType) * (BM * BK + BN * BK),
       a.data(),
       b.data(),
       out.data(),
diff --git a/mlx/backend/cuda/steel/gemm.cuh b/mlx/backend/cuda/steel/gemm.cuh
index 0c48d4a50..31ba00fcf 100644
--- a/mlx/backend/cuda/steel/gemm.cuh
+++ b/mlx/backend/cuda/steel/gemm.cuh
@@ -40,6 +40,7 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   constexpr int NUM_WARPS = WARPS_M * WARPS_N;
   constexpr int WARP_STEP_M = BM / WARPS_M;
   constexpr int WARP_STEP_N = BN / WARPS_N;
+  constexpr int PIPE = 4;
 
   // Precompute some offsets for each thread
   const int warpid = threadIdx.x / 32;
@@ -54,43 +55,50 @@
 
   // Allocate shared memory
   extern __shared__ char shmem[];
-  SharedTile(&as)[2] = *(SharedTile(*)[2])(&shmem[0]);
-  SharedTile(&bs)[2] =
-      *(SharedTile(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
-
-  // Allocate registers for the MMA
-  RegisterTile C;
+  SharedTile(&as)[PIPE] =
+      *(SharedTile(*)[PIPE])(&shmem[0]);
+  SharedTile(&bs)[PIPE] =
+      *(SharedTile(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
 
   // Move the global pointers to the tile
   a += blockIdx.y * BM * K;
   b += blockIdx.x * BN * K;
   y += blockIdx.y * BM * N + blockIdx.x * BN;
 
-  // Zero the accumulators
-  C.fill(0);
-
   // Start the SM pipeline
-  load_async(as[0], as[0].base_addr(), a, K);
-  load_async(bs[0], bs[0].base_addr(), b, K);
-  cp_async_commit();
-
-  int tic = 0;
-  for (int k_block = BK; k_block < K; k_block += BK) {
-    load_async(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
-    load_async(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
+  MLX_UNROLL
+  for (int i = 0; i < PIPE - 1; i++) {
+    load_async(as[i], as[i].base_addr(), a + i * BK, K);
+    load_async(bs[i], bs[i].base_addr(), b + i * BK, K);
     cp_async_commit();
-    cp_async_wait<1>();
-
-    gemm_ab_t(
-        C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
-
-    tic ^= 1;
   }
 
-  // Empty the pipeline
-  cp_async_wait_all();
-  gemm_ab_t(
-      C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
+  // Allocate and zero the MMA accumulator
+  RegisterTile C;
+  C.fill(0);
+
+  // Matmul loop
+  int num_blocks = K / BK;
+  int k_block = (PIPE - 1) * BK;
+  int sread = 0;
+  int swrite = PIPE - 1;
+  for (int i = 0; i < num_blocks; i++) {
+    cp_async_wait<PIPE - 2>();
+
+    if (k_block < K) {
+      load_async(as[swrite], as[swrite].base_addr(), a + k_block, K);
+      load_async(bs[swrite], bs[swrite].base_addr(), b + k_block, K);
+      k_block += BK;
+    }
+
+    gemm_ab_t(
+        C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col);
+
+    cp_async_commit();
+
+    swrite = sread;
+    sread = (sread + 1) % PIPE;
+  }
 
   C.store_global(y, N, offset_m, offset_n);
 }
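As a sanity check on the new indexing, here is a small host-side sketch (my own illustration, not MLX code; `stage` and `in_flight` are invented stand-ins for the shared tiles and cp.async groups). It mimics the PIPE-stage ring buffer above and checks two things: the tile consumed in iteration i always holds the i-th k-block, and at most PIPE - 1 copy groups are ever outstanding, which is what makes a wait depth of PIPE - 2 sufficient and also why the prologue prefetches exactly PIPE - 1 blocks.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  constexpr int PIPE = 4;
  constexpr int BK = 32;
  constexpr int K = 512; // assume K is a multiple of BK, as the kernel does
  const int num_blocks = K / BK;

  std::vector<int> stage(PIPE, -1); // k-offset currently held by each stage
  int in_flight = 0;                // committed, not yet waited-on copy groups

  // Prologue: prefetch the first PIPE - 1 blocks, one commit per block.
  for (int i = 0; i < PIPE - 1; i++) {
    stage[i] = i * BK;
    in_flight++;
  }

  int k_block = (PIPE - 1) * BK;
  int sread = 0;
  int swrite = PIPE - 1;
  for (int i = 0; i < num_blocks; i++) {
    // cp_async_wait<PIPE - 2>(): drain until at most PIPE - 2 groups remain,
    // which guarantees the group that filled stage `sread` has landed.
    if (in_flight > PIPE - 2) {
      in_flight = PIPE - 2;
    }

    if (k_block < K) {
      stage[swrite] = k_block; // load_async into the free stage
      k_block += BK;
    }

    // The tile consumed in iteration i must hold the i-th k-block.
    assert(stage[sread] == i * BK);

    in_flight++; // cp_async_commit()
    assert(in_flight <= PIPE - 1);

    swrite = sread;
    sread = (sread + 1) % PIPE;
  }
  std::printf("ok: %d blocks consumed in order\n", num_blocks);
  return 0;
}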