mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Improve gemm
This commit is contained in:
		| @@ -26,15 +26,13 @@ void simple_gemm( | ||||
|  | ||||
|     auto kernel = ab_t_aligned<DataType, BM, BN, BK>; | ||||
|     cudaFuncSetAttribute( | ||||
|         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304); | ||||
|     cudaFuncSetAttribute( | ||||
|         kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); | ||||
|         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); | ||||
|  | ||||
|     dim3 grid(N / BN, M / BM); | ||||
|     enc.add_kernel_node( | ||||
|         kernel, | ||||
|         grid, | ||||
|         4 * WARP_SIZE, | ||||
|         8 * WARP_SIZE, | ||||
|         2 * sizeof(DataType) * (BM * BK + BN * BK), | ||||
|         a.data<DataType>(), | ||||
|         b.data<DataType>(), | ||||
|   | ||||
| @@ -4,6 +4,30 @@ | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| template <typename T, int BM, int BN, int BK, int WM, int WN> | ||||
| __device__ inline void gemm_ab_t( | ||||
|     RegisterTile<float, BM / WM, BN / WN>& C, | ||||
|     SharedTile<T, BM, BK>& As, | ||||
|     SharedTile<T, BM, BK>& Bs, | ||||
|     int lane_row_a, | ||||
|     int lane_row_b, | ||||
|     int lane_col) { | ||||
|   RegisterTile<T, BM / WM, 16> A[2]; | ||||
|   RegisterTile<T, BN / WN, 16> B[2]; | ||||
|  | ||||
|   A[0].load(As, As.base_addr(), lane_row_a, lane_col); | ||||
|   B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col); | ||||
|  | ||||
|   MLX_UNROLL | ||||
|   for (int k = 1; k < BK / 16; k++) { | ||||
|     A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16); | ||||
|     B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16); | ||||
|  | ||||
|     mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]); | ||||
|   } | ||||
|   mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]); | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * An example gemm written with the utils. | ||||
|  * | ||||
| @@ -11,7 +35,7 @@ namespace mlx::core::cu { | ||||
|  */ | ||||
| template <typename T, int BM, int BN, int BK> | ||||
| __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   constexpr int WARPS_M = 2; | ||||
|   constexpr int WARPS_M = 4; | ||||
|   constexpr int WARPS_N = 2; | ||||
|   constexpr int NUM_WARPS = WARPS_M * WARPS_N; | ||||
|   constexpr int WARP_STEP_M = BM / WARPS_M; | ||||
| @@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   const int wn = warpid % WARPS_N; | ||||
|   const int offset_m = wm * WARP_STEP_M; | ||||
|   const int offset_n = wn * WARP_STEP_N; | ||||
|   const int lane_row_a = offset_m + (laneid & 15); | ||||
|   const int lane_row_b = offset_n + (laneid & 15); | ||||
|   const int lane_col = (laneid >> 4) << 3; | ||||
|  | ||||
|   // Allocate shared memory | ||||
|   extern __shared__ char shmem[]; | ||||
| @@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|  | ||||
|   // Allocate registers for the MMA | ||||
|   RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C; | ||||
|   RegisterTile<T, BM / WARPS_M, 16> A; | ||||
|   RegisterTile<T, BN / WARPS_N, 16> B; | ||||
|  | ||||
|   // Move the global pointers to the tile | ||||
|   a += blockIdx.y * BM * K; | ||||
| @@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|     load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K); | ||||
|     cp_async_commit(); | ||||
|     cp_async_wait<1>(); | ||||
|     __syncthreads(); | ||||
|  | ||||
|     MLX_UNROLL | ||||
|     for (int k = 0; k < BK / 16; k++) { | ||||
|       A.load( | ||||
|           as[tic], | ||||
|           as[tic].base_addr(), | ||||
|           offset_m + laneid % 16, | ||||
|           k * 16 + laneid / 16 * 8); | ||||
|       B.load( | ||||
|           bs[tic], | ||||
|           bs[tic].base_addr(), | ||||
|           offset_n + laneid % 16, | ||||
|           k * 16 + laneid / 16 * 8); | ||||
|  | ||||
|       mma_t(C, A, B); | ||||
|     } | ||||
|     gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|         C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col); | ||||
|  | ||||
|     tic ^= 1; | ||||
|   } | ||||
|  | ||||
|   // Empty the pipeline | ||||
|   cp_async_wait_all(); | ||||
|   __syncthreads(); | ||||
|   MLX_UNROLL | ||||
|   for (int k = 0; k < BK / 16; k++) { | ||||
|     A.load( | ||||
|         as[tic], | ||||
|         as[tic].base_addr(), | ||||
|         offset_m + laneid % 16, | ||||
|         k * 16 + laneid / 16 * 8); | ||||
|     B.load( | ||||
|         bs[tic], | ||||
|         bs[tic].base_addr(), | ||||
|         offset_n + laneid % 16, | ||||
|         k * 16 + laneid / 16 * 8); | ||||
|  | ||||
|     mma_t(C, A, B); | ||||
|   } | ||||
|   gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|       C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col); | ||||
|  | ||||
|   C.store_global(y, N, offset_m, offset_n); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos