Float mask update (#1152)

* Float mask update

* Update CPU impl
Jagrit Digani
2024-05-23 17:20:44 -07:00
committed by GitHub
parent 50dfb664db
commit eab2685c67
8 changed files with 713 additions and 253 deletions
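As a rough reading of the change below: block_masked_gemm now takes the output and operand masks as template-typed buffers, where nomask_t means the mask is absent, bool masks skip or keep whole blocks, and scalar (e.g. float) masks multiply the corresponding blocks. A minimal host-side reference of that semantics as I read the diff follows; the function name, row-major layout, and float accumulation here are illustrative, not part of the change.

#include <cstddef>
#include <vector>

// Illustrative reference only. Masks are given per block of size BS x BS
// (BS == BM == BN in the kernel); bool masks zero/skip a block, scalar
// masks scale it.
template <typename T, typename MaskT>
void block_masked_mm_ref(
    const std::vector<T>& A, // M x K, row-major
    const std::vector<T>& B, // K x N, row-major
    std::vector<T>& D, // M x N, row-major
    const std::vector<MaskT>& out_mask, // ceil(M/BS) x ceil(N/BS)
    const std::vector<MaskT>& lhs_mask, // ceil(M/BS) x ceil(K/BS)
    const std::vector<MaskT>& rhs_mask, // ceil(K/BS) x ceil(N/BS)
    int M, int N, int K, int BS) {
  const int KB = (K + BS - 1) / BS;
  const int NB = (N + BS - 1) / BS;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        const float lm = float(lhs_mask[(i / BS) * KB + (k / BS)]);
        const float rm = float(rhs_mask[(k / BS) * NB + (j / BS)]);
        acc += (lm * float(A[i * K + k])) * (rm * float(B[k * N + j]));
      }
      const float om = float(out_mask[(i / BS) * NB + (j / BS)]);
      D[i * N + j] = static_cast<T>(om * acc);
    }
  }
}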


@@ -11,8 +11,38 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
typedef struct _NoMask nomask_t;
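For context, the pattern these helpers enable (compile-time selection of the mask behavior from the mask pointer's element type) can be sketched in plain C++ roughly as below; the names and the simplified nomask analogue are illustrative, not part of the kernel.

#include <type_traits>

struct nomask_t_sketch {
  char x;
  constexpr operator bool() const { return true; }
};

template <typename OutT, typename InT = OutT>
struct ScaleOpSketch {
  OutT scale;
  OutT apply(InT x) const { return static_cast<OutT>(x) * scale; }
};

// The mask pointer's element type decides at compile time whether a block
// is unmasked, gated by a bool, or scaled by a value.
template <typename T, typename mask_t>
T masked_value(T value, const mask_t* mask) {
  constexpr bool has_mask = !std::is_same_v<mask_t, nomask_t_sketch>;
  constexpr bool mul_mask = has_mask && !std::is_same_v<mask_t, bool>;
  if constexpr (has_mask) {
    if (!bool(mask[0])) {
      return T(0); // falsy mask: the block is skipped / zeroed
    }
    if constexpr (mul_mask) {
      return ScaleOpSketch<T>{static_cast<T>(mask[0])}.apply(value); // scale by mask
    }
  }
  return value; // no mask bound, or a true bool mask
}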
template <
typename T,
typename out_mask_t,
typename op_mask_t,
int BM,
int BN,
int BK,
@@ -21,8 +51,7 @@ template <
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
@@ -31,9 +60,9 @@ block_masked_gemm(
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -42,6 +71,21 @@ block_masked_gemm(
// Appease the compiler
(void)lid;
static_assert(
BM == BN,
"block_masked_gemm must have the same block M and block N size");
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
constexpr bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
constexpr bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
constexpr short k_mask_factor = short(BM / BK);
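A small worked example of the factor above, with illustrative sizes BM = BN = 64 and BK = 16: k_mask_factor is 4, so one lhs/rhs mask element covers four consecutive BK-wide K iterations before the offsets step to the next mask block. The k_factor_cnt countdown in the loop bodies further down does exactly this; a rough standalone sketch (not a literal excerpt):

void walk_mask_offsets_sketch(int gemm_k_iterations, int lhs_step, int rhs_step) {
  const int k_mask_factor = 64 / 16; // BM / BK = 4 in this example
  int lhs_off = 0, rhs_off = 0;
  int cnt = k_mask_factor;
  for (int k = 0; k < gemm_k_iterations; ++k) {
    // ... lhs_mask[lhs_off] and rhs_mask[rhs_off] gate/scale this BK slice ...
    if (--cnt == 0) {
      lhs_off += lhs_step; // advance to the next BM-wide mask block along K
      rhs_off += rhs_step;
      cnt = k_mask_factor;
    }
  }
}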
using gemm_kernel = GEMMKernel<
T,
T,
@@ -63,15 +107,19 @@ block_masked_gemm(
return;
}
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_output_mask) {
out_mask += elem_to_loc(
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
@@ -86,10 +134,14 @@ block_masked_gemm(
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
lhs_mask += tid.z * mask_batch_strides[0];
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
}
}
@@ -121,44 +173,69 @@ block_masked_gemm(
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
const constant int* out_mask_strides = mask_strides;
const constant int* lhs_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* rhs_mask_strides =
lhs_mask_strides + (has_operand_mask ? 2 : 0);
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
const int out_mask_offset = !has_output_mask
? 0
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
short k_factor_cnt = k_mask_factor;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
ScaleOp<float> out_mask_op;
ScaleOp<T> lhs_mask_op;
ScaleOp<T> rhs_mask_op;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
if (has_mul_output_mask) {
out_mask_op.scale = float(mask_out);
}
return;
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
}
threadgroup_barrier(mem_flags::mem_none);
@@ -166,8 +243,6 @@ block_masked_gemm(
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
@@ -177,21 +252,88 @@ block_masked_gemm(
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm =
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
const short tgp_bn =
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// Do unaligned K iterations first
if (!K_aligned) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int mask_idx_last = k_last / BM;
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale =
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
rhs_mask_op.scale =
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
}
// Move loader source ahead to end
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
}
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
@@ -201,29 +343,15 @@ block_masked_gemm(
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
// Store results to device memory
@@ -233,24 +361,25 @@ block_masked_gemm(
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
else {
const bool M_aligned = (tgp_bm == BM);
const bool N_aligned = (tgp_bn == BN);
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
@@ -264,6 +393,11 @@ block_masked_gemm(
loader_b.load_safe(tile_dims_B);
}
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
@@ -273,29 +407,15 @@ block_masked_gemm(
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
if (M_aligned && N_aligned) {
@@ -311,6 +431,10 @@ block_masked_gemm(
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
@@ -326,15 +450,15 @@ block_masked_gemm(
aname, \
mn_aligned, \
kname, \
k_aligned, \
omname, \
op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
k_aligned) \
template [[host_name("steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname \
"_op_mask_" #omname)]] [[kernel]] void \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
@@ -343,17 +467,16 @@ block_masked_gemm(
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
op_mask>( \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool* out_mask [[buffer(10)]], \
const device bool* lhs_mask [[buffer(11)]], \
const device bool* rhs_mask [[buffer(12)]], \
const device outmasktype* out_mask [[buffer(10)]], \
const device opmasktype* lhs_mask [[buffer(11)]], \
const device opmasktype* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
@@ -361,9 +484,15 @@ block_masked_gemm(
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \


@@ -58,6 +58,18 @@ struct BlockLoader {
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL


@@ -198,6 +198,24 @@ struct BlockMMA {
}
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0]);
accum[1] = epilogue_op.apply(accum[1]);
}
}
}
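The contract the new overload expects is just a unary apply() on each accumulated element; the masked GEMM above drives it with ScaleOp<float> when the output mask is a float type. A standalone sketch of the same element-wise idea (names here are illustrative):

struct ScaleSketch {
  float s;
  float apply(float x) const { return x * s; }
};

// Same element-wise contract as apply_epilogue, on a flat accumulator tile.
void apply_epilogue_sketch(float* accum, int n, const ScaleSketch& op) {
  for (int i = 0; i < n; ++i) {
    accum[i] = op.apply(accum[i]);
  }
}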
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(


@@ -1307,7 +1307,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Check and collapse batch dimensions
bool has_op_mask = inputs.size() > 3;
auto& out_mask = inputs[2];
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
std::vector<int> batch_shape{1};
size_t A_batch_str = 0;
@@ -1350,14 +1350,17 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int wm = 2, wn = 2;
// Prepare kernel name
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
std::ostringstream kname;
kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
<< (has_op_mask ? "T" : "N");
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
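For illustration (these concrete values are assumptions, not from the diff): with float32 inputs and output, float32 output and operand masks, no transposes, e.g. bm = bn = 32, bk = 16, wm = wn = 2, and both MN and K aligned, the name stream above would produce something along the lines of

steel_gemm_block_outmask_float32_opmask_float32_nn_float32_float32_bm32_bn32_bk16_wm2_wn2_MN_taligned_K_taligned

which matches the host_name pattern in the updated instantiation macro.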
@@ -1397,17 +1400,23 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<int> mask_strides;
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
if (has_out_mask) {
auto& out_mask = inputs[2];
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
compute_encoder.set_input_array(out_mask, 10);
}
if (has_op_mask) {
auto& lhs_mask = inputs[3];
auto& lhs_mask = inputs[2 + has_out_mask];
mask_strides.push_back(*(lhs_mask.strides().end() - 1));
mask_strides.push_back(*(lhs_mask.strides().end() - 2));
compute_encoder.set_input_array(lhs_mask, 11);
auto& rhs_mask = inputs[4];
auto& rhs_mask = inputs[3 + has_out_mask];
mask_strides.push_back(*(rhs_mask.strides().end() - 1));
mask_strides.push_back(*(rhs_mask.strides().end() - 2));
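As a sketch of how this packed vector lines up with the kernel-side out_mask_strides / lhs_mask_strides / rhs_mask_strides pointers (assuming the same presence flags on both sides; helper names here are illustrative):

// mask_strides holds {innermost, next} stride pairs in the order
// [out?][lhs][rhs], with absent masks contributing no entries.
const int* out_strides_sketch(const int* ms, bool has_out) {
  return has_out ? ms : nullptr;
}
const int* lhs_strides_sketch(const int* ms, bool has_out, bool has_op) {
  return has_op ? ms + (has_out ? 2 : 0) : nullptr;
}
const int* rhs_strides_sketch(const int* ms, bool has_out, bool has_op) {
  return has_op ? lhs_strides_sketch(ms, has_out, has_op) + 2 : nullptr;
}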
@@ -1424,7 +1433,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);