mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	More pipelining for the sm_80 gemm
This commit is contained in:
		| @@ -22,7 +22,7 @@ void simple_gemm( | ||||
|     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||
|     constexpr int BM = 128; | ||||
|     constexpr int BN = 128; | ||||
|     constexpr int BK = 64; | ||||
|     constexpr int BK = 32; | ||||
|  | ||||
|     auto kernel = ab_t_aligned<DataType, BM, BN, BK>; | ||||
|     cudaFuncSetAttribute( | ||||
| @@ -33,7 +33,7 @@ void simple_gemm( | ||||
|         kernel, | ||||
|         grid, | ||||
|         8 * WARP_SIZE, | ||||
|         2 * sizeof(DataType) * (BM * BK + BN * BK), | ||||
|         4 * sizeof(DataType) * (BM * BK + BN * BK), | ||||
|         a.data<DataType>(), | ||||
|         b.data<DataType>(), | ||||
|         out.data<DataType>(), | ||||
|   | ||||
| @@ -40,6 +40,7 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   constexpr int NUM_WARPS = WARPS_M * WARPS_N; | ||||
|   constexpr int WARP_STEP_M = BM / WARPS_M; | ||||
|   constexpr int WARP_STEP_N = BN / WARPS_N; | ||||
|   constexpr int PIPE = 4; | ||||
|  | ||||
|   // Precompute some offsets for each thread | ||||
|   const int warpid = threadIdx.x / 32; | ||||
| @@ -54,43 +55,49 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|  | ||||
|   // Allocate shared memory | ||||
|   extern __shared__ char shmem[]; | ||||
|   SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]); | ||||
|   SharedTile<T, BN, BK>(&bs)[2] = | ||||
|       *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]); | ||||
|  | ||||
|   // Allocate registers for the MMA | ||||
|   RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C; | ||||
|   SharedTile<T, BM, BK>(&as)[PIPE] = | ||||
|       *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]); | ||||
|   SharedTile<T, BN, BK>(&bs)[PIPE] = | ||||
|       *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]); | ||||
|  | ||||
|   // Move the global pointers to the tile | ||||
|   a += blockIdx.y * BM * K; | ||||
|   b += blockIdx.x * BN * K; | ||||
|   y += blockIdx.y * BM * N + blockIdx.x * BN; | ||||
|  | ||||
|   // Zero the accumulators | ||||
|   C.fill(0); | ||||
|  | ||||
|   // Start the SM pipeline | ||||
|   load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K); | ||||
|   load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K); | ||||
|   cp_async_commit(); | ||||
|  | ||||
|   int tic = 0; | ||||
|   for (int k_block = BK; k_block < K; k_block += BK) { | ||||
|     load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K); | ||||
|     load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K); | ||||
|   MLX_UNROLL | ||||
|   for (int i = 0; i < PIPE - 1; i++) { | ||||
|     load_async<NUM_WARPS>(as[i], as[i].base_addr(), a + i * BK, K); | ||||
|     load_async<NUM_WARPS>(bs[i], bs[i].base_addr(), b + i * BK, K); | ||||
|     cp_async_commit(); | ||||
|     cp_async_wait<1>(); | ||||
|  | ||||
|     gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|         C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col); | ||||
|  | ||||
|     tic ^= 1; | ||||
|   } | ||||
|  | ||||
|   // Empty the pipeline | ||||
|   cp_async_wait_all(); | ||||
|   gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|       C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col); | ||||
|   // Allocate and zero the MMA accumulator | ||||
|   RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C; | ||||
|   C.fill(0); | ||||
|  | ||||
|   // Matmul loop | ||||
|   int num_blocks = K / BK; | ||||
|   int k_block = (PIPE - 1) * BK; | ||||
|   int sread = 0; | ||||
|   int swrite = PIPE - 1; | ||||
|   for (int i = 0; i < num_blocks; i++) { | ||||
|     cp_async_wait<PIPE - 2>(); | ||||
|  | ||||
|     if (k_block < K) { | ||||
|       load_async<NUM_WARPS>(as[swrite], as[swrite].base_addr(), a + k_block, K); | ||||
|       load_async<NUM_WARPS>(bs[swrite], bs[swrite].base_addr(), b + k_block, K); | ||||
|     } | ||||
|  | ||||
|     gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|         C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col); | ||||
|  | ||||
|     cp_async_commit(); | ||||
|  | ||||
|     swrite = sread; | ||||
|     sread = (sread + 1) % PIPE; | ||||
|   } | ||||
|  | ||||
|   C.store_global(y, N, offset_m, offset_n); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos