mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -1307,7 +1307,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
bool has_op_mask = inputs.size() > 3;
|
||||
auto& out_mask = inputs[2];
|
||||
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
|
||||
|
||||
std::vector<int> batch_shape{1};
|
||||
size_t A_batch_str = 0;
|
||||
@@ -1350,14 +1350,17 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
// Prepare kernel name
|
||||
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
|
||||
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
|
||||
kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
|
||||
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
|
||||
<< (has_op_mask ? "T" : "N");
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -1397,17 +1400,23 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
std::vector<int> mask_strides;
|
||||
mask_strides.push_back(*(out_mask.strides().end() - 1));
|
||||
mask_strides.push_back(*(out_mask.strides().end() - 2));
|
||||
|
||||
if (has_out_mask) {
|
||||
auto& out_mask = inputs[2];
|
||||
mask_strides.push_back(*(out_mask.strides().end() - 1));
|
||||
mask_strides.push_back(*(out_mask.strides().end() - 2));
|
||||
|
||||
compute_encoder.set_input_array(out_mask, 10);
|
||||
}
|
||||
|
||||
if (has_op_mask) {
|
||||
auto& lhs_mask = inputs[3];
|
||||
auto& lhs_mask = inputs[2 + has_out_mask];
|
||||
mask_strides.push_back(*(lhs_mask.strides().end() - 1));
|
||||
mask_strides.push_back(*(lhs_mask.strides().end() - 2));
|
||||
|
||||
compute_encoder.set_input_array(lhs_mask, 11);
|
||||
|
||||
auto& rhs_mask = inputs[4];
|
||||
auto& rhs_mask = inputs[3 + has_out_mask];
|
||||
mask_strides.push_back(*(rhs_mask.strides().end() - 1));
|
||||
mask_strides.push_back(*(rhs_mask.strides().end() - 2));
|
||||
|
||||
@@ -1424,7 +1433,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.set_input_array(out_mask, 10);
|
||||
set_vector_bytes(compute_encoder, mask_strides, 13);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
Reference in New Issue
Block a user