Masked mm (#978)

* Add block masked matmul op and primitive
This commit is contained in:
Jagrit Digani
2024-04-16 14:45:39 -07:00
committed by GitHub
parent 107ba2891a
commit b18468bf81
15 changed files with 1137 additions and 2 deletions

View File

@@ -0,0 +1,323 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Block-masked GEMM.
//
// Each threadgroup produces one BM x BN tile of D = A * B, located at
// (c_row, c_col) = (tid_y * BM, tid_x * BN) after threadgroup swizzling.
//
//  * out_mask selects which output tiles are computed; a tile whose mask entry
//    is false is filled with zeros and the threadgroup returns early.
//  * When has_operand_mask is true, lhs_mask / rhs_mask are consulted once per
//    BK-wide step along K; if either entry is false the step's load and
//    multiply-accumulate are skipped (the loaders are still advanced).
//
// Template parameters:
//  T                 element type of A, B and D
//  BM, BN, BK        tile extents along M, N and the per-iteration K step
//  WM, WN            forwarded to GEMMKernel; the threadgroup holds
//                    WM * WN * 32 threads
//  transpose_a/_b    select how the tile origin in A / B is computed from
//                    lda / ldb
//  MN_aligned        when true, full tiles are assumed and the unchecked
//                    load/store paths are used
//  K_aligned         when false, a bounds-checked tail step handles the last
//                    K - gemm_k_iterations_aligned * BK columns
//  has_operand_mask  enables the lhs_mask / rhs_mask checks
//
// Buffer layout as read by this kernel:
//  batch_strides  [A | B | out_mask | lhs_mask | rhs_mask], each segment
//                 holding params->batch_ndim entries (the two operand-mask
//                 segments are only read when has_operand_mask is true)
//  mask_strides   [0],[1] out_mask column-/row-block stride,
//                 [2],[3] lhs_mask K-/row-block stride,
//                 [4],[5] rhs_mask column-/K-block stride
//
// NOTE(review): the K block index into the operand masks is computed as
// (k * BK) / BM, i.e. it presumes mask blocks are BM wide along K — confirm
// against the host, which dispatches with bm == bn.
template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned,
          bool has_operand_mask=false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device bool *out_mask [[buffer(10)]],
    const device bool *lhs_mask [[buffer(11)]],
    const device bool *rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  // Undo the threadgroup swizzle to recover the tile coordinates
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Threadgroups launched past the tile grid have nothing to do
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Move the mask pointers to the masks of this batch element
  if(params->batch_ndim > 1) {
    const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
    out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if(has_operand_mask) {
      const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
      const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else if(params->batch_ndim == 1 && tid.z > 0) {
    // Single batch dimension: A and B are advanced with a plain stride below,
    // so the masks must be advanced the same way; otherwise every batch
    // element would reuse the masks of batch 0.
    // The tid.z > 0 guard keeps the unbatched case from reading stride
    // entries the host may not have uploaded.
    const constant size_t* mask_batch_strides = batch_strides + 2;
    out_mask += tid.z * mask_batch_strides[0];

    if(has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[1];
      rhs_mask += tid.z * mask_batch_strides[2];
    }
  }

  // Adjust for batch
  if(params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;

  A += transpose_a ? c_row : c_row * params->lda;
  B += transpose_b ? c_col * params->ldb : c_col;
  D += c_row * params->ldd + c_col;

  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return
  if(!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      // Full tile: every thread clears vec_size columns in each of its rows
      for (short ti = 0; ti < BM; ti += TM) {
        STEEL_PRAGMA_UNROLL
        for(short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      // Edge tile: clamp rows to tgp_bm and columns to tgp_bn
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for(short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // The mask test depends only on tid_y, tid_x and k, so all threads of
      // the threadgroup take the same branch and reach the barrier inside it.
      if(!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      if(!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;
  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if(!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if(!has_operand_mask ||
          (lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if(M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// Explicitly instantiates block_masked_gemm for one combination of template
// arguments and gives it the host-visible name
//   steel_block_masked_gemm_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>
//     _wm<wm>_wn<wn>_MN_<aname>_K_<kname>_op_mask_<omname>
// The name is assembled by stringification, so the tokens passed here must
// match, character for character, the name the host code looks up.
// Note: oname / otype only contribute to the name; the kernel is instantiated
// with itype for A, B and D.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool *out_mask [[buffer(10)]], \
const device bool *lhs_mask [[buffer(11)]], \
const device bool *rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// Expands to both operand-mask variants: "N" (has_operand_mask = false) and
// "T" (has_operand_mask = true).
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
// Expands to the four (MN_aligned, K_aligned) combinations; "taligned" names
// the aligned variant and "naligned" the unaligned one.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// Expands to the four (transpose_a, transpose_b) combinations: nn, nt, tn, tt.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// Expands to the two supported tile shapes (bm, bn, bk, wm, wn):
// (32, 32, 16, 2, 2) and (64, 64, 16, 2, 2).
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
// Instantiate every variant for half, bfloat16 and float.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -1064,4 +1064,181 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
int lda = a_cols;
int ldb = b_cols;
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = block_size_, bn = block_size_, bk = 16;
int wm = 2, wn = 2;
// Prepare kernel name
std::ostringstream kname;
kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
<< (inputs.size() > 3 ? "T" : "N");
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<size_t> batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
std::vector<int> mask_strides;
auto& out_mask = inputs[2];
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
out_mask.strides().begin(),
out_mask.strides().end() - 2);
if (inputs.size() > 3) {
auto& lhs_mask = inputs[3];
mask_strides.push_back(*(lhs_mask.strides().end() - 1));
mask_strides.push_back(*(lhs_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
lhs_mask.strides().begin(),
lhs_mask.strides().end() - 2);
compute_encoder.set_input_array(lhs_mask, 11);
auto& rhs_mask = inputs[4];
mask_strides.push_back(*(rhs_mask.strides().end() - 1));
mask_strides.push_back(*(rhs_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
rhs_mask.strides().begin(),
rhs_mask.strides().end() - 2);
compute_encoder.set_input_array(rhs_mask, 12);
}
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
} // namespace mlx::core