mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 02:46:40 +08:00

* Enable collapsing batch dims in gemm * Update gemm to only make copies when neither of the last 2 axes are contiguous * Update addmm to support gemv shapes * Update addmm to support irregular batch strides * Update tests
293 lines
8.2 KiB
C++
293 lines
8.2 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
|
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
|
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
|
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
|
|
using namespace metal;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// GEMM kernel class
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace mlx {
|
|
namespace steel {
|
|
|
|
/// Empty compile-time tag type that carries three alignment flags
/// (M, N and K) as template arguments. It holds no data; it only exists so a
/// function can accept a value whose type spells out those flags.
template <bool AlignedM, bool AlignedN, bool AlignedK>
struct LoopAlignment {
};
|
|
|
|
/// Tiled GEMM kernel.
///
/// Each threadgroup computes one BM x BN tile of the output D. The K
/// dimension is walked in BK-wide steps: every step stages a slice of A and a
/// slice of B in threadgroup memory (As / Bs) and then multiply-accumulates
/// them through `mma_t`.
///
/// Template parameters:
///   T            element type of A, B and the threadgroup staging tiles
///   U            element type of the output D
///   BM, BN, BK   tile extents along M, N and K
///   WM, WN       simdgroup layout; the threadgroup has WM * WN * 32 threads
///   transpose_a, transpose_b
///                select the memory layout the loaders and `mma_t` assume
///   MN_aligned   when true, `run` treats every tile as fully in bounds in M
///                and N and uses unchecked loads and stores
///   K_aligned    when false, a final partial K step of
///                `K - gemm_k_iterations_aligned * BK` elements is processed
///   AccumType    accumulator type forwarded to `BlockMMA`
///   Epilogue     transform forwarded to `BlockMMA` (presumably applied to the
///                accumulators when results are stored — confirm in mma.h)
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  // Each staged tile row is padded by 16 bytes' worth of elements
  // (presumably to reduce threadgroup-memory bank conflicts — confirm).
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);

  // Staging-tile sizes in elements. A is BK x (BM + pad) when transposed and
  // BM x (BK + pad) otherwise; B is BN x (BK + pad) when transposed and
  // BK x (BN + pad) otherwise.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Threads per threadgroup: WM x WN simdgroups of 32 threads each.
  STEEL_CONST short tgp_size = WM * WN * 32;

  // Copies one BK-wide slice of A into As. The leading dimension includes
  // the padding above.
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;

  // Copies one BK-wide slice of B into Bs.
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;

  // Multiplies the staged tiles and keeps the running result per thread
  // until `store_result` / `store_result_safe` writes it to D.
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /// K loop for a tile that may be only partially inside the output.
  ///
  /// M_aligned / N_aligned choose unchecked (`load_unsafe`) or bounds-checked
  /// (`load_safe`) loads of A / B for each full BK step. When K_aligned_ is
  /// false, one extra bounds-checked step of `lbk` elements along K follows.
  ///
  /// @param As, Bs             threadgroup staging tiles for A and B
  /// @param gemm_k_iterations  number of full BK steps to run
  /// @param loader_a, loader_b loaders positioned at this tile's first K slice
  /// @param mma_op             multiply-accumulate operator for this tile
  /// @param tgp_bm, tgp_bn     rows / columns of this tile that lie in D
  /// @param lbk                length of the leftover K step
  /// @param l                  unused tag; callers pass template arguments
  ///                           explicitly
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

    // Extents handed to load_safe, laid out to match each loader's
    // (transposed or not) tile shape.
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // Everyone must be done reading the previous tiles before they are
      // overwritten.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Make the freshly loaded tiles visible to the whole threadgroup.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Partial final K step.
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      const short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      const short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /// Main kernel function: computes this threadgroup's tile of D.
  ///
  /// @param A, B           input matrices (leading dimensions params->lda/ldb)
  /// @param D              output matrix (leading dimension params->ldd)
  /// @param params         problem sizes, strides, tile counts and swizzle
  /// @param As, Bs         threadgroup staging tiles for A and B
  /// @param simd_lane_id   thread index within its simdgroup
  /// @param simd_group_id  simdgroup index within the threadgroup
  /// @param tid            threadgroup position in the grid
  /// @param lid            thread position in the threadgroup (unused)
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Remap the grid position to a tile: the low `swizzle_log` bits of tid.x
    // select the row offset within a group of 2^swizzle_log tile rows
    // (presumably for cache locality — confirm with the host-side dispatch).
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    // Threadgroups mapped past the tile grid have no work. `tid` is the
    // threadgroup position, so the whole threadgroup returns together.
    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;

    A += transpose_a ? c_row : c_row * params->lda;
    B += transpose_b ? c_col * params->ldb : c_col;
    D += c_row * params->ldd + c_col;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    const int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Loop tail
      if (!K_aligned) {
        // As / Bs were read by every thread in the last full step and are
        // about to be overwritten. Fence threadgroup memory, not just
        // execution, matching the tail in gemm_loop.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        const short2 tile_dims_A =
            transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        const short2 tile_dims_B =
            transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;
    }

    ///////////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else { // Loop over K - unaligned case
      // Rows / columns of this tile that actually lie inside D.
      const short tgp_bm = min(BM, params->M - c_row);
      const short tgp_bn = min(BN, params->N - c_col);
      const short leftover_bk =
          params->K - params->gemm_k_iterations_aligned * BK;

      // Pick the cheapest loop instantiation for this tile's shape. Only a
      // fully interior tile may store without bounds checks.
      if (tgp_bm == BM && tgp_bn == BN) {
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};
|
|
|
|
} // namespace steel
|
|
} // namespace mlx
|