Float mask update (#1152)

* Float mask update

* Update CPU impl
This commit is contained in:
Jagrit Digani
2024-05-23 17:20:44 -07:00
committed by GitHub
parent 50dfb664db
commit eab2685c67
8 changed files with 713 additions and 253 deletions

View File

@@ -3870,48 +3870,60 @@ array block_masked_mm(
int tn = (N + block_size - 1) / block_size;
int tk = (K + block_size - 1) / block_size;
std::vector<array> inputs = {a, b};
// Broadcast and astype mask
auto broadcast_mask = [](array mask,
std::vector<int>& bs_shape,
int y,
int x,
Dtype mask_dtype,
StreamOrDevice s) {
int nd_bsx = bs_shape.size();
bs_shape[nd_bsx - 2] = y;
bs_shape[nd_bsx - 1] = x;
mask = astype(mask, bool_, s);
mask = astype(mask, mask_dtype, s);
return broadcast_to(mask, bs_shape, s);
};
// Out mask
array mask_out_p = mask_out.value_or(array({true}));
if (in_a_ndim == 1 || in_b_ndim == 1) {
std::vector<int> ex_dims;
if (in_a_ndim == 1)
ex_dims.push_back(-2);
if (in_b_ndim == 1)
ex_dims.push_back(-1);
mask_out_p = expand_dims(mask_out_p, ex_dims, s);
}
mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);
if (mask_out.has_value()) {
array mask_out_p = mask_out.value_or(array({true}));
if (in_a_ndim == 1 || in_b_ndim == 1) {
std::vector<int> ex_dims;
if (in_a_ndim == 1)
ex_dims.push_back(-2);
if (in_b_ndim == 1)
ex_dims.push_back(-1);
mask_out_p = expand_dims(mask_out_p, ex_dims, s);
}
auto maskout_dtype = mask_out_p.dtype() == bool_ ? bool_ : out_type;
mask_out_p =
broadcast_mask(mask_out_p, bsx_shape, tm, tn, maskout_dtype, s);
std::vector<array> inputs = {a, b, mask_out_p};
inputs.push_back(mask_out_p);
}
// Operand masks
if (has_operand_mask) {
// LHS mask
// Pull masks
array mask_lhs_p = mask_lhs.value_or(array({true}));
array mask_rhs_p = mask_rhs.value_or(array({true}));
auto mask_dtype =
(mask_lhs_p.dtype() == bool_ && mask_rhs_p.dtype() == bool_) ? bool_
: out_type;
// LHS mask
if (in_a_ndim == 1) {
mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
}
mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);
mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, mask_dtype, s);
// RHS mask
array mask_rhs_p = mask_rhs.value_or(array({true}));
if (in_b_ndim == 1) {
mask_rhs_p = expand_dims(mask_lhs_p, -1, s);
mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
}
mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);
mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, mask_dtype, s);
inputs.push_back(mask_lhs_p);
inputs.push_back(mask_rhs_p);