mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

mlx/ops.cpp
@@ -3572,6 +3572,172 @@ array addmm(
  return out;
}

/** Compute matrix product with tile-level masking */
array block_masked_mm(
    array a,
    array b,
    int block_size,
    std::optional<array> mask_out /* = std::nullopt */,
    std::optional<array> mask_lhs /* = std::nullopt */,
    std::optional<array> mask_rhs /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  // If no masks, just perform regular matmul
  if (!mask_out && !mask_lhs && !mask_rhs) {
    return matmul(a, b, s);
  }

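  // Illustrative call (hypothetical shapes, not from this change): masks
  // are specified per tile, so with block_size = 32 a (64, 64) x (64, 64)
  // matmul takes a (2, 2) output mask:
  //
  //   array a = ones({64, 64});
  //   array b = ones({64, 64});
  //   array m = array({true, false, false, true}, {2, 2});
  //   array c = block_masked_mm(a, b, /* block_size = */ 32, m);
  //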
  bool has_out_mask = mask_out.has_value();
  bool has_operand_mask = mask_lhs.has_value() || mask_rhs.has_value();

  // Check valid tile sizes
  // TODO: Add support for 16x16 tile
  if (block_size != 32 && block_size != 64) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Only block_sizes 32, 64 are supported."
        << " Got block size " << block_size << ".";
    throw std::invalid_argument(msg.str());
  }

  // Do shape checks for operands
  int in_a_ndim = a.ndim();
  int in_b_ndim = b.ndim();

  if (a.ndim() == 0 || b.ndim() == 0) {
    throw std::invalid_argument(
        "[block_masked_mm] Got 0 dimension input. Inputs must "
        "have at least one dimension.");
  }

  if (a.ndim() == 1) {
    // Insert a singleton dim at the beginning
    a = reshape(a, {1, -1}, s);
  }
  if (b.ndim() == 1) {
    // Insert a singleton dim at the end
    b = reshape(b, {-1, 1}, s);
  }

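  // E.g. a 1-D a of shape (K,) is now treated as (1, K) and a 1-D b of
  // shape (K,) as (K, 1); the inserted singleton dims are stripped again
  // at the end of the function.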
  if (a.shape(-1) != b.shape(-2)) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Last dimension of first input with shape "
        << a.shape() << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  // Type promotion
  auto out_type = result_type(a, b);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[block_masked_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
    throw std::invalid_argument(msg.str());
  }

  a = astype(a, out_type, s);
  b = astype(b, out_type, s);

  // Handle broadcasting
  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);

  auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);

  bsx_shape.push_back(1);
  bsx_shape.push_back(1);
  int nd = bsx_shape.size();

  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  // Prepare A
  bsx_shape[nd - 2] = M;
  bsx_shape[nd - 1] = K;
  a = broadcast_to(a, bsx_shape, s);

  // Prepare B
  bsx_shape[nd - 2] = K;
  bsx_shape[nd - 1] = N;
  b = broadcast_to(b, bsx_shape, s);

  // Get output shape
  auto out_shape = bsx_shape;
  out_shape[nd - 2] = M;
  out_shape[nd - 1] = N;

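  // E.g. for a of shape (3, 1, M, K) and b of shape (5, K, N), the batch
  // dims broadcast to (3, 5), so the operands become (3, 5, M, K) and
  // (3, 5, K, N), and out_shape is (3, 5, M, N).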
  // Determine mask shape requirements
  int tm = (M + block_size - 1) / block_size;
  int tn = (N + block_size - 1) / block_size;
  int tk = (K + block_size - 1) / block_size;

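  // Ceil division: e.g. M = 100 with block_size = 32 gives tm = 4, since
  // the final, partially filled 32x32 tile still needs its own mask entry.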
  // Broadcast and astype mask
  auto broadcast_mask = [](array mask,
                           std::vector<int>& bs_shape,
                           int y,
                           int x,
                           StreamOrDevice s) {
    int nd_bsx = bs_shape.size();
    bs_shape[nd_bsx - 2] = y;
    bs_shape[nd_bsx - 1] = x;
    mask = astype(mask, bool_, s);
    return broadcast_to(mask, bs_shape, s);
  };

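  // E.g. with the (3, 5) batch example above, a (tm, tn) output mask is
  // cast to bool and broadcast to shape (3, 5, tm, tn).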
  // Out mask
  array mask_out_p = mask_out.value_or(array({true}));
  if (in_a_ndim == 1 || in_b_ndim == 1) {
    std::vector<int> ex_dims;
    if (in_a_ndim == 1)
      ex_dims.push_back(-2);
    if (in_b_ndim == 1)
      ex_dims.push_back(-1);
    mask_out_p = expand_dims(mask_out_p, ex_dims, s);
  }
  mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);

  std::vector<array> inputs = {a, b, mask_out_p};

  // Operand masks
  if (has_operand_mask) {
    // LHS mask
    array mask_lhs_p = mask_lhs.value_or(array({true}));
    if (in_a_ndim == 1) {
      mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
    }
    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);

    // RHS mask
    array mask_rhs_p = mask_rhs.value_or(array({true}));
    if (in_b_ndim == 1) {
      mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
    }
    mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);

    inputs.push_back(mask_lhs_p);
    inputs.push_back(mask_rhs_p);
  }

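  // The primitive therefore sees either {a, b, mask_out} or
  // {a, b, mask_out, mask_lhs, mask_rhs}, depending on whether operand
  // masks were provided.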
  // Construct the output array
  auto out = array(
      out_shape,
      out_type,
      std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
      std::move(inputs));

  // Remove the possibly inserted singleton dimensions
  if (in_a_ndim == 1 || in_b_ndim == 1) {
    out_shape.erase(
        out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
        out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
    out = reshape(out, out_shape, s);
  }

  return out;
}

array diagonal(
    const array& a,
    int offset /* = 0 */,