mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 12:32:30 +08:00
Improve gemm
This commit is contained in:
parent
97afe40b7b
commit
6fce01593a
@ -26,15 +26,13 @@ void simple_gemm(
|
|||||||
|
|
||||||
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
||||||
cudaFuncSetAttribute(
|
cudaFuncSetAttribute(
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
||||||
cudaFuncSetAttribute(
|
|
||||||
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
|
||||||
|
|
||||||
dim3 grid(N / BN, M / BM);
|
dim3 grid(N / BN, M / BM);
|
||||||
enc.add_kernel_node(
|
enc.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
4 * WARP_SIZE,
|
8 * WARP_SIZE,
|
||||||
2 * sizeof(DataType) * (BM * BK + BN * BK),
|
2 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||||
a.data<DataType>(),
|
a.data<DataType>(),
|
||||||
b.data<DataType>(),
|
b.data<DataType>(),
|
||||||
|
@ -4,6 +4,30 @@
|
|||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
/**
 * Accumulate one BK-wide block of the shared tiles As and Bs into the
 * register accumulator C.
 *
 * The block is consumed in slices of 16 columns. Each operand has two
 * register tiles that are used alternately (slot k & 1): the loads for
 * slice k are issued before mma_t is called on slice k - 1, and the last
 * slice is consumed after the loop.
 *
 * Template parameters:
 *   T       element type of the shared and register operand tiles
 *   BM, BN  row counts of the A and B shared tiles
 *   BK      column count of both shared tiles; a positive multiple of 16
 *   WM, WN  divisors applied to BM and BN to size the register tiles
 *
 * Arguments:
 *   C           accumulator tile, updated in place by mma_t
 *   As, Bs      shared tiles holding the A and B operands
 *   lane_row_a  row index passed to every load from As
 *   lane_row_b  row index passed to every load from Bs
 *   lane_col    column index of the first slice; slice k is loaded from
 *               lane_col + k * 16
 *
 * NOTE(review): mma_t presumably performs C += A * B^T on the register
 * tiles -- confirm against its definition.
 * NOTE(review): the caller is assumed to have made the shared tiles visible
 * to this thread (async copies completed and synchronized) -- confirm.
 */
template <typename T, int BM, int BN, int BK, int WM, int WN>
__device__ inline void gemm_ab_t(
    RegisterTile<float, BM / WM, BN / WN>& C,
    SharedTile<T, BM, BK>& As,
    SharedTile<T, BN, BK>& Bs,
    int lane_row_a,
    int lane_row_b,
    int lane_col) {
  // With BK < 16 the trailing mma_t below would read a register tile that
  // was never loaded; with BK not a multiple of 16 the tail columns would be
  // silently skipped.
  static_assert(
      BK >= 16 && BK % 16 == 0, "BK must be a positive multiple of 16");

  // Number of 16-column slices in one block.
  constexpr int K_STEPS = BK / 16;

  // Two register tiles per operand, indexed by the parity of the slice.
  RegisterTile<T, BM / WM, 16> A[2];
  RegisterTile<T, BN / WN, 16> B[2];

  // Load the first slice before entering the loop.
  A[0].load(As, As.base_addr(), lane_row_a, lane_col);
  B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);

  MLX_UNROLL
  for (int k = 1; k < K_STEPS; k++) {
    // Issue the loads for slice k into the slot not used by slice k - 1 ...
    A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
    B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);

    // ... then accumulate slice k - 1.
    mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
  }

  // Accumulate the last slice, which no loop iteration consumed.
  mma_t(C, A[(K_STEPS - 1) & 1], B[(K_STEPS - 1) & 1]);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An example gemm written with the utils.
|
* An example gemm written with the utils.
|
||||||
*
|
*
|
||||||
@ -11,7 +35,7 @@ namespace mlx::core::cu {
|
|||||||
*/
|
*/
|
||||||
template <typename T, int BM, int BN, int BK>
|
template <typename T, int BM, int BN, int BK>
|
||||||
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||||
constexpr int WARPS_M = 2;
|
constexpr int WARPS_M = 4;
|
||||||
constexpr int WARPS_N = 2;
|
constexpr int WARPS_N = 2;
|
||||||
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||||
constexpr int WARP_STEP_M = BM / WARPS_M;
|
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||||
@ -24,6 +48,9 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|||||||
const int wn = warpid % WARPS_N;
|
const int wn = warpid % WARPS_N;
|
||||||
const int offset_m = wm * WARP_STEP_M;
|
const int offset_m = wm * WARP_STEP_M;
|
||||||
const int offset_n = wn * WARP_STEP_N;
|
const int offset_n = wn * WARP_STEP_N;
|
||||||
|
const int lane_row_a = offset_m + (laneid & 15);
|
||||||
|
const int lane_row_b = offset_n + (laneid & 15);
|
||||||
|
const int lane_col = (laneid >> 4) << 3;
|
||||||
|
|
||||||
// Allocate shared memory
|
// Allocate shared memory
|
||||||
extern __shared__ char shmem[];
|
extern __shared__ char shmem[];
|
||||||
@ -33,8 +60,6 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|||||||
|
|
||||||
// Allocate registers for the MMA
|
// Allocate registers for the MMA
|
||||||
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||||
RegisterTile<T, BM / WARPS_M, 16> A;
|
|
||||||
RegisterTile<T, BN / WARPS_N, 16> B;
|
|
||||||
|
|
||||||
// Move the global pointers to the tile
|
// Move the global pointers to the tile
|
||||||
a += blockIdx.y * BM * K;
|
a += blockIdx.y * BM * K;
|
||||||
@ -55,45 +80,17 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|||||||
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
||||||
cp_async_commit();
|
cp_async_commit();
|
||||||
cp_async_wait<1>();
|
cp_async_wait<1>();
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
MLX_UNROLL
|
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
||||||
for (int k = 0; k < BK / 16; k++) {
|
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
||||||
A.load(
|
|
||||||
as[tic],
|
|
||||||
as[tic].base_addr(),
|
|
||||||
offset_m + laneid % 16,
|
|
||||||
k * 16 + laneid / 16 * 8);
|
|
||||||
B.load(
|
|
||||||
bs[tic],
|
|
||||||
bs[tic].base_addr(),
|
|
||||||
offset_n + laneid % 16,
|
|
||||||
k * 16 + laneid / 16 * 8);
|
|
||||||
|
|
||||||
mma_t(C, A, B);
|
|
||||||
}
|
|
||||||
|
|
||||||
tic ^= 1;
|
tic ^= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty the pipeline
|
// Empty the pipeline
|
||||||
cp_async_wait_all();
|
cp_async_wait_all();
|
||||||
__syncthreads();
|
gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>(
|
||||||
MLX_UNROLL
|
C, as[tic], bs[tic], lane_row_a, lane_row_b, lane_col);
|
||||||
for (int k = 0; k < BK / 16; k++) {
|
|
||||||
A.load(
|
|
||||||
as[tic],
|
|
||||||
as[tic].base_addr(),
|
|
||||||
offset_m + laneid % 16,
|
|
||||||
k * 16 + laneid / 16 * 8);
|
|
||||||
B.load(
|
|
||||||
bs[tic],
|
|
||||||
bs[tic].base_addr(),
|
|
||||||
offset_n + laneid % 16,
|
|
||||||
k * 16 + laneid / 16 * 8);
|
|
||||||
|
|
||||||
mma_t(C, A, B);
|
|
||||||
}
|
|
||||||
|
|
||||||
C.store_global(y, N, offset_m, offset_n);
|
C.store_global(y, N, offset_m, offset_n);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user