Float mask update (#1152)

* Float mask update

* Update CPU impl
This commit is contained in:
Jagrit Digani
2024-05-23 17:20:44 -07:00
committed by GitHub
parent 50dfb664db
commit eab2685c67
8 changed files with 713 additions and 253 deletions

View File

@@ -1307,7 +1307,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Check and collapse batch dimensions
bool has_op_mask = inputs.size() > 3;
auto& out_mask = inputs[2];
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
std::vector<int> batch_shape{1};
size_t A_batch_str = 0;
@@ -1350,14 +1350,17 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int wm = 2, wn = 2;
// Prepare kernel name
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
std::ostringstream kname;
kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
<< (has_op_mask ? "T" : "N");
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1397,17 +1400,23 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<int> mask_strides;
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
if (has_out_mask) {
auto& out_mask = inputs[2];
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
compute_encoder.set_input_array(out_mask, 10);
}
if (has_op_mask) {
auto& lhs_mask = inputs[3];
auto& lhs_mask = inputs[2 + has_out_mask];
mask_strides.push_back(*(lhs_mask.strides().end() - 1));
mask_strides.push_back(*(lhs_mask.strides().end() - 2));
compute_encoder.set_input_array(lhs_mask, 11);
auto& rhs_mask = inputs[4];
auto& rhs_mask = inputs[3 + has_out_mask];
mask_strides.push_back(*(rhs_mask.strides().end() - 1));
mask_strides.push_back(*(rhs_mask.strides().end() - 2));
@@ -1424,7 +1433,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);