Masked mm (#978)

* Add block masked matmul op and primitive
This commit is contained in:
Jagrit Digani
2024-04-16 14:45:39 -07:00
committed by GitHub
parent 107ba2891a
commit b18468bf81
15 changed files with 1137 additions and 2 deletions

View File

@@ -3572,6 +3572,172 @@ array addmm(
return out;
}
/** Compute matrix product with tile-level masking */
array block_masked_mm(
array a,
array b,
int block_size,
std::optional<array> mask_out /* = std::nullopt */,
std::optional<array> mask_lhs /* = std::nullopt */,
std::optional<array> mask_rhs /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
// If no masks, just perform regular matmul
if (!mask_out && !mask_lhs && !mask_rhs) {
return matmul(a, b, s);
}
bool has_out_mask = mask_out.has_value();
bool has_operand_mask = mask_lhs.has_value() || mask_rhs.has_value();
// Check valid tile sizes
// TODO: Add support for 16x16 tile
if (block_size != 32 && block_size != 64) {
std::ostringstream msg;
msg << "[block_masked_mm] Only block_sizes 32, 64 are supported."
<< "Got block size " << block_size << ".";
throw std::invalid_argument(msg.str());
}
// Do shape checks for operands
int in_a_ndim = a.ndim();
int in_b_ndim = b.ndim();
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
"[block_masked_mm] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[block_masked_mm] Last dimension of first input with shape "
<< a.shape() << " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[block_masked_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
// Handle broadcasting
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
bsx_shape.push_back(1);
bsx_shape.push_back(1);
int nd = bsx_shape.size();
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
// Prepare A
bsx_shape[nd - 2] = M;
bsx_shape[nd - 1] = K;
a = broadcast_to(a, bsx_shape, s);
// Prepare B
bsx_shape[nd - 2] = K;
bsx_shape[nd - 1] = N;
b = broadcast_to(b, bsx_shape, s);
// Get output shape
auto out_shape = bsx_shape;
out_shape[nd - 2] = M;
out_shape[nd - 1] = N;
// Determine mask shape requirments
int tm = (M + block_size - 1) / block_size;
int tn = (N + block_size - 1) / block_size;
int tk = (K + block_size - 1) / block_size;
// Broadcast and astype mask
auto broadcast_mask = [](array mask,
std::vector<int>& bs_shape,
int y,
int x,
StreamOrDevice s) {
int nd_bsx = bs_shape.size();
bs_shape[nd_bsx - 2] = y;
bs_shape[nd_bsx - 1] = x;
mask = astype(mask, bool_, s);
return broadcast_to(mask, bs_shape, s);
};
// Out mask
array mask_out_p = mask_out.value_or(array({true}));
if (in_a_ndim == 1 || in_b_ndim == 1) {
std::vector<int> ex_dims;
if (in_a_ndim == 1)
ex_dims.push_back(-2);
if (in_b_ndim == 1)
ex_dims.push_back(-1);
mask_out_p = expand_dims(mask_out_p, ex_dims, s);
}
mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);
std::vector<array> inputs = {a, b, mask_out_p};
// Operand masks
if (has_operand_mask) {
// LHS mask
array mask_lhs_p = mask_lhs.value_or(array({true}));
if (in_a_ndim == 1) {
mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
}
mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);
// RHS mask
array mask_rhs_p = mask_rhs.value_or(array({true}));
if (in_b_ndim == 1) {
mask_rhs_p = expand_dims(mask_lhs_p, -1, s);
}
mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);
inputs.push_back(mask_lhs_p);
inputs.push_back(mask_rhs_p);
}
// Caculate array
auto out = array(
out_shape,
out_type,
std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
std::move(inputs));
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
}
return out;
}
array diagonal(
const array& a,
int offset /* = 0 */,