diff --git a/mlx/backend/cuda/gemms/simple_gemm.cu b/mlx/backend/cuda/gemms/simple_gemm.cu
index 12ceda068..51dc9569f 100644
--- a/mlx/backend/cuda/gemms/simple_gemm.cu
+++ b/mlx/backend/cuda/gemms/simple_gemm.cu
@@ -26,15 +26,13 @@ void simple_gemm(
 
   auto kernel = ab_t_aligned;
   cudaFuncSetAttribute(
-      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
-  cudaFuncSetAttribute(
-      kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
 
   dim3 grid(N / BN, M / BM);
   enc.add_kernel_node(
       kernel,
       grid,
-      4 * WARP_SIZE,
+      8 * WARP_SIZE,
       2 * sizeof(DataType) * (BM * BK + BN * BK),
       a.data(),
       b.data(),
diff --git a/mlx/backend/cuda/steel/gemm.cuh b/mlx/backend/cuda/steel/gemm.cuh
index 99580d2de..0c48d4a50 100644
--- a/mlx/backend/cuda/steel/gemm.cuh
+++ b/mlx/backend/cuda/steel/gemm.cuh
@@ -4,6 +4,30 @@
 
 namespace mlx::core::cu {
 
+template
+__device__ inline void gemm_ab_t(
+    RegisterTile& C,
+    SharedTile& As,
+    SharedTile& Bs,
+    int lane_row_a,
+    int lane_row_b,
+    int lane_col) {
+  RegisterTile A[2];
+  RegisterTile B[2];
+
+  A[0].load(As, As.base_addr(), lane_row_a, lane_col);
+  B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
+
+  MLX_UNROLL
+  for (int k = 1; k < BK / 16; k++) {
+    A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
+    B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
+
+    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
+  }
+  mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
+}
+
 /**
  * An example gemm written with the utils.
  *
@@ -11,7 +35,7 @@ namespace mlx::core::cu {
  */
 template
 __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
-  constexpr int WARPS_M = 2;
+  constexpr int WARPS_M = 4;
   constexpr int WARPS_N = 2;
   constexpr int NUM_WARPS = WARPS_M * WARPS_N;
   constexpr int WARP_STEP_M = BM / WARPS_M;
@@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   const int wn = warpid % WARPS_N;
   const int offset_m = wm * WARP_STEP_M;
   const int offset_n = wn * WARP_STEP_N;
+  const int lane_row_a = offset_m + (laneid & 15);
+  const int lane_row_b = offset_n + (laneid & 15);
+  const int lane_col = (laneid >> 4) << 3;
 
   // Allocate shared memory
   extern __shared__ char shmem[];
@@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
 
   // Allocate registers for the MMA
   RegisterTile C;
-  RegisterTile A;
-  RegisterTile B;
 
   // Move the global pointers to the tile
   a += blockIdx.y * BM * K;
@@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
     load_async(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
     cp_async_commit();
 
     cp_async_wait<1>();
-    __syncthreads();
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          as[tic],
-          as[tic].base_addr(),
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          bs[tic],
-          bs[tic].base_addr(),
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-
-      mma_t(C, A, B);
-    }
+    gemm_ab_t(
+        C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
 
     tic ^= 1;
   }
 
   // Empty the pipeline
   cp_async_wait_all();
-  __syncthreads();
-  MLX_UNROLL
-  for (int k = 0; k < BK / 16; k++) {
-    A.load(
-        as[tic],
-        as[tic].base_addr(),
-        offset_m + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-    B.load(
-        bs[tic],
-        bs[tic].base_addr(),
-        offset_n + laneid % 16,
-        k * 16 + laneid / 16 * 8);
-
-    mma_t(C, A, B);
-  }
+  gemm_ab_t(
+      C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
 
   C.store_global(y, N, offset_m, offset_n);
 }
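The new gemm_ab_t helper above restructures the inner loop as a two-deep register pipeline: it prefetches the A/B fragments for step k while issuing the MMA for step k - 1, then drains the last fragment after the loop. Below is a minimal host-side C++ sketch of that control flow only; Frag, load_frag, mma, and pipelined_dot are hypothetical stand-ins for the RegisterTile loads and mma_t calls in the patch, not MLX APIs.

#include <cstdio>

// Hypothetical stand-in for a register fragment.
struct Frag {
  float v;
};

// Hypothetical loader: fetch fragment k of a K-fragment operand.
Frag load_frag(const float* src, int k) {
  return Frag{src[k]};
}

// Hypothetical accumulate step standing in for mma_t(C, A, B).
void mma(float& acc, const Frag& a, const Frag& b) {
  acc += a.v * b.v;
}

// Mirrors gemm_ab_t's structure: double-buffer the fragments so the load for
// step k overlaps the multiply for step k - 1, then drain the final fragment.
float pipelined_dot(const float* a, const float* b, int K) {
  Frag A[2];
  Frag B[2];
  float acc = 0.0f;

  A[0] = load_frag(a, 0);
  B[0] = load_frag(b, 0);

  for (int k = 1; k < K; k++) {
    A[k & 1] = load_frag(a, k);
    B[k & 1] = load_frag(b, k);
    mma(acc, A[(k - 1) & 1], B[(k - 1) & 1]);
  }
  mma(acc, A[(K - 1) & 1], B[(K - 1) & 1]);
  return acc;
}

int main() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {5, 6, 7, 8};
  std::printf("%f\n", pipelined_dot(a, b, 4)); // 1*5 + 2*6 + 3*7 + 4*8 = 70
  return 0;
}

The point of the pattern is that each fragment is loaded one iteration before it is consumed, so the loads from shared memory can overlap the tensor-core math instead of serializing with it.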