mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
46
mlx/ops.cpp
46
mlx/ops.cpp
@@ -3870,48 +3870,60 @@ array block_masked_mm(
|
||||
int tn = (N + block_size - 1) / block_size;
|
||||
int tk = (K + block_size - 1) / block_size;
|
||||
|
||||
std::vector<array> inputs = {a, b};
|
||||
|
||||
// Broadcast and astype mask
|
||||
auto broadcast_mask = [](array mask,
|
||||
std::vector<int>& bs_shape,
|
||||
int y,
|
||||
int x,
|
||||
Dtype mask_dtype,
|
||||
StreamOrDevice s) {
|
||||
int nd_bsx = bs_shape.size();
|
||||
bs_shape[nd_bsx - 2] = y;
|
||||
bs_shape[nd_bsx - 1] = x;
|
||||
mask = astype(mask, bool_, s);
|
||||
mask = astype(mask, mask_dtype, s);
|
||||
return broadcast_to(mask, bs_shape, s);
|
||||
};
|
||||
|
||||
// Out mask
|
||||
array mask_out_p = mask_out.value_or(array({true}));
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
std::vector<int> ex_dims;
|
||||
if (in_a_ndim == 1)
|
||||
ex_dims.push_back(-2);
|
||||
if (in_b_ndim == 1)
|
||||
ex_dims.push_back(-1);
|
||||
mask_out_p = expand_dims(mask_out_p, ex_dims, s);
|
||||
}
|
||||
mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);
|
||||
if (mask_out.has_value()) {
|
||||
array mask_out_p = mask_out.value_or(array({true}));
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
std::vector<int> ex_dims;
|
||||
if (in_a_ndim == 1)
|
||||
ex_dims.push_back(-2);
|
||||
if (in_b_ndim == 1)
|
||||
ex_dims.push_back(-1);
|
||||
mask_out_p = expand_dims(mask_out_p, ex_dims, s);
|
||||
}
|
||||
auto maskout_dtype = mask_out_p.dtype() == bool_ ? bool_ : out_type;
|
||||
mask_out_p =
|
||||
broadcast_mask(mask_out_p, bsx_shape, tm, tn, maskout_dtype, s);
|
||||
|
||||
std::vector<array> inputs = {a, b, mask_out_p};
|
||||
inputs.push_back(mask_out_p);
|
||||
}
|
||||
|
||||
// Operand masks
|
||||
if (has_operand_mask) {
|
||||
// LHS mask
|
||||
// Pull masks
|
||||
array mask_lhs_p = mask_lhs.value_or(array({true}));
|
||||
array mask_rhs_p = mask_rhs.value_or(array({true}));
|
||||
auto mask_dtype =
|
||||
(mask_lhs_p.dtype() == bool_ && mask_rhs_p.dtype() == bool_) ? bool_
|
||||
: out_type;
|
||||
|
||||
// LHS mask
|
||||
if (in_a_ndim == 1) {
|
||||
mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
|
||||
}
|
||||
mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);
|
||||
mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, mask_dtype, s);
|
||||
|
||||
// RHS mask
|
||||
array mask_rhs_p = mask_rhs.value_or(array({true}));
|
||||
if (in_b_ndim == 1) {
|
||||
mask_rhs_p = expand_dims(mask_lhs_p, -1, s);
|
||||
mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
|
||||
}
|
||||
mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);
|
||||
mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, mask_dtype, s);
|
||||
|
||||
inputs.push_back(mask_lhs_p);
|
||||
inputs.push_back(mask_rhs_p);
|
||||
|
||||
Reference in New Issue
Block a user