mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 12:53:37 +08:00
Improve gemm
This commit is contained in:
parent
97afe40b7b
commit
6fce01593a
@ -26,15 +26,13 @@ void simple_gemm(
|
||||
|
||||
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
||||
|
||||
dim3 grid(N / BN, M / BM);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
4 * WARP_SIZE,
|
||||
8 * WARP_SIZE,
|
||||
2 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||
a.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
|
@ -4,6 +4,30 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T, int BM, int BN, int BK, int WM, int WN>
|
||||
__device__ inline void gemm_ab_t(
|
||||
RegisterTile<float, BM / WM, BN / WN>& C,
|
||||
SharedTile<T, BM, BK>& As,
|
||||
SharedTile<T, BM, BK>& Bs,
|
||||
int lane_row_a,
|
||||
int lane_row_b,
|
||||
int lane_col) {
|
||||
RegisterTile<T, BM / WM, 16> A[2];
|
||||
RegisterTile<T, BN / WN, 16> B[2];
|
||||
|
||||
A[0].load(As, As.base_addr(), lane_row_a, lane_col);
|
||||
B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 1; k < BK / 16; k++) {
|
||||
A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
|
||||
B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
|
||||
|
||||
mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
|
||||
}
|
||||
mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
|
||||
}
|
||||
|
||||
/**
|
||||
* An example gemm written with the utils.
|
||||
*
|
||||
@ -11,7 +35,7 @@ namespace mlx::core::cu {
|
||||
*/
|
||||
template <typename T, int BM, int BN, int BK>
|
||||
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
constexpr int WARPS_M = 2;
|
||||
constexpr int WARPS_M = 4;
|
||||
constexpr int WARPS_N = 2;
|
||||
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||
@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
const int wn = warpid % WARPS_N;
|
||||
const int offset_m = wm * WARP_STEP_M;
|
||||
const int offset_n = wn * WARP_STEP_N;
|
||||
const int lane_row_a = offset_m + (laneid & 15);
|
||||
const int lane_row_b = offset_n + (laneid & 15);
|
||||
const int lane_col = (laneid >> 4) << 3;
|
||||
|
||||
// Allocate shared memory
|
||||
extern __shared__ char shmem[];
|
||||
@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
|
||||
// Allocate registers for the MMA
|
||||
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||
|
||||
// Move the global pointers to the tile
|
||||
a += blockIdx.y * BM * K;
|
||||
@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
||||
cp_async_commit();
|
||||
cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
as[tic],
|
||||
as[tic].base_addr(),
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
bs[tic],
|
||||
bs[tic].base_addr(),
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
||||
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
||||
|
||||
tic ^= 1;
|
||||
}
|
||||
|
||||
// Empty the pipeline
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
as[tic],
|
||||
as[tic].base_addr(),
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
bs[tic],
|
||||
bs[tic].base_addr(),
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
||||
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
||||
|
||||
C.store_global(y, N, offset_m, offset_n);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user