#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"

namespace mlx::core::cu {

/**
 * An example gemm written with the steel tile utilities.
 *
 * Computes y = a @ b.T where
 *   - a is an M x K row-major matrix (row stride K),
 *   - b is an N x K row-major matrix (row stride K),
 *   - y is an M x N row-major matrix (row stride N).
 *
 * Preconditions (the kernel performs no bounds checks):
 *   - M % BM == 0, N % BN == 0, K % BK == 0 and K >= BK.
 *   - BK is a multiple of 16 (the inner loop walks BK in 16-wide slices).
 *
 * Launch configuration:
 *   - grid  = (N / BN, M / BM): blockIdx.x selects the BN-wide column tile of
 *     y, blockIdx.y selects the BM-tall row tile.
 *   - block = WARPS_M * WARPS_N * 32 = 128 threads along x (each warp maps to
 *     one (wm, wn) sub-tile; extra warps would index past the tile).
 *   - dynamic shared memory >= 2 * (BM + BN) * BK * sizeof(T) bytes: two
 *     stages of the a-tile followed by two stages of the b-tile.
 *     NOTE(review): the b-tile offset assumes SharedTile<T, R, C> occupies
 *     exactly R * C * sizeof(T) bytes (no padding) -- confirm in tiles.cuh.
 *
 * NOTE(review): the load_async / cp_async_* helpers look like cp.async
 * wrappers, which would require SM80+ -- confirm against tiles.cuh.
 */
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
  constexpr int WARPS_M = 2;
  constexpr int WARPS_N = 2;
  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
  constexpr int WARP_STEP_M = BM / WARPS_M;
  constexpr int WARP_STEP_N = BN / WARPS_N;

  // Each warp owns a WARP_STEP_M x WARP_STEP_N sub-tile of the output tile.
  const int warpid = threadIdx.x / 32;
  const int laneid = threadIdx.x % 32;
  const int wm = warpid / WARPS_N;
  const int wn = warpid % WARPS_N;
  const int offset_m = wm * WARP_STEP_M;
  const int offset_n = wn * WARP_STEP_N;

  // Two shared-memory stages per operand: the a-tiles first, then the b-tiles.
  extern __shared__ char shmem[];
  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[2] =
      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);

  // Register tiles: a float accumulator for the warp's output sub-tile, plus
  // one 16-wide k-slice of each operand.
  RegisterTile<float, WARP_STEP_M, WARP_STEP_N> C;
  RegisterTile<T, WARP_STEP_M, 16> A;
  RegisterTile<T, WARP_STEP_N, 16> B;

  // Move the global pointers to this block's tiles. The offsets are formed in
  // 64-bit arithmetic: blockIdx is a 32-bit unsigned value, so e.g.
  // blockIdx.y * BM * K would silently wrap once M * K reaches 2^32 elements.
  a += static_cast<size_t>(blockIdx.y) * BM * K;
  b += static_cast<size_t>(blockIdx.x) * BN * K;
  y += static_cast<size_t>(blockIdx.y) * BM * N +
      static_cast<size_t>(blockIdx.x) * BN;

  // Accumulate every 16-wide k-slice of shared-memory stage `stage` into C.
  // Each lane passes row (laneid % 16) of the warp's sub-tile and column half
  // (laneid / 16) * 8 of the slice to the fragment loads.
  auto mma_stage = [&](int stage) {
    MLX_UNROLL
    for (int k = 0; k < BK / 16; k++) {
      A.load(
          as[stage],
          as[stage].base_addr(),
          offset_m + laneid % 16,
          k * 16 + laneid / 16 * 8);
      B.load(
          bs[stage],
          bs[stage].base_addr(),
          offset_n + laneid % 16,
          k * 16 + laneid / 16 * 8);
      mma_t(C, A, B);
    }
  };

  // Zero the accumulators
  C.fill(0);

  // Prologue: start loading the first k-block into stage 0.
  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
  cp_async_commit();

  int tic = 0;
  for (int k_block = BK; k_block < K; k_block += BK) {
    // Prefetch the next k-block into the other stage...
    load_async<NUM_WARPS>(
        as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
    load_async<NUM_WARPS>(
        bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
    cp_async_commit();

    // ...then wait until only that newest group may still be in flight
    // (presumably wait_group semantics), so stage `tic` has landed, and make
    // every thread's copies visible block-wide.
    cp_async_wait<1>();
    __syncthreads();

    mma_stage(tic);

    // Every warp must finish reading stage `tic` before the next iteration
    // issues async copies into it (after the flip below it is `tic ^ 1`).
    // Without this barrier a fast warp can overwrite shared data that a slower
    // warp is still loading from (write-after-read race).
    __syncthreads();

    tic ^= 1;
  }

  // Epilogue: drain the pipeline and consume the last stage.
  cp_async_wait_all();
  __syncthreads();
  mma_stage(tic);

  C.store_global(y, N, offset_m, offset_n);
}

} // namespace mlx::core::cu