Float mask update (#1152)

* Float mask update

* Update CPU impl
Jagrit Digani
2024-05-23 17:20:44 -07:00
committed by GitHub
parent 50dfb664db
commit eab2685c67
8 changed files with 713 additions and 253 deletions
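
For context, a minimal sketch of how a float block mask could be passed through the C++ API this commit updates. The shapes, values, and main() scaffolding are illustrative assumptions, not code from this commit; block_masked_mm accepts a block_size of 32 or 64.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // 64x64 operands with 32x32 blocks -> 2x2 block masks.
  auto a = ones({64, 64});
  auto b = ones({64, 64});
  // A float (rather than boolean) mask scales each output block;
  // float masks are what make the mask gradients in vjp well defined.
  auto out_mask = array({1.0f, 0.5f, 0.0f, 1.0f}, {2, 2});
  auto c = block_masked_mm(a, b, /* block_size = */ 32, out_mask);
  eval({c});
  return 0;
}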


@@ -3487,42 +3487,251 @@ std::vector<array> BlockMaskedMM::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
/////////////////////////////////////////////////////////////////////////////
// The operation that is done w/o intermediates by the primitive is
// - tm = (M + block_size - 1) // block_size; MP = tm * block_size;
// - tn = (N + block_size - 1) // block_size; NP = tn * block_size;
// - tk = (K + block_size - 1) // block_size; KP = tk * block_size;
// - mask_b <- mask broadcasted to block sizes
// - A_m = A [..., M, K] * mask_b_lhs [..., MP, KP]
// - B_m = B [..., K, N] * mask_b_rhs [..., KP, NP]
// - C = A_m [..., M, K] @ B_m [..., K, N]
// - C_m = C [..., M, N] * mask_b_out [..., MP, NP]
//
// The grads are therefore
// - dC_m = cotan [..., M, N]
// - dmask_b_out = cotan [..., M, N] * C [..., M, N]
// - dC = cotan [..., M, N] * mask_b_out [..., MP, NP]
// - dA_m = dC [..., M, N] @ B_m.T [..., N, K]
// - dB_m = A_m.T [..., K, M] @ dC [..., M, N]
// - dA = dA_m * mask_b_lhs [..., MP, KP]
// - dB = dB_m * mask_b_rhs [..., KP, NP]
// - dmask_b_lhs = dA_m [..., M, K] * A [..., M, K] // need [..., MP, KP]
// - dmask_b_rhs = dB_m [..., K, N] * B [..., K, N] // need [..., KP, NP]
//
// Observations:
// * If dmask_b_lhs is not needed, then dA can be calculated in one go as a
// block_masked_mm with mask_b_lhs as the out_mask, without needing to
// materialize the intermediate dA_m. Similarly for dB.
// * If dmask_b_lhs is needed, we need to materialize dA_m directly and then
// point-wise multiply with A. But the output needs to be padded to the
// block-aligned sizes before reducing.
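// As a concrete, purely illustrative example: with block_size = 32 and
// M = N = K = 60, the block masks are 2 x 2, the padded sizes are
// MP = NP = KP = 64, and each mask entry scales one 32 x 32 block, so a
// mask gradient must be summed back over 32 x 32 blocks into a 2 x 2 array.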
std::vector<array> vjps;
auto& cotan = cotangents[0];
std::vector<int> reorder(cotan.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
std::iter_swap(reorder.end() - 1, reorder.end() - 2);
bool has_op_mask = primals.size() > 3;
bool has_out_mask = primals.size() == 3 || primals.size() == 5;
const int op_mask_idx = has_out_mask ? 3 : 2;
bool needs_lhs_mask_vjp = false;
bool needs_rhs_mask_vjp = false;
bool needs_lhs_vjp = false;
bool needs_rhs_vjp = false;
for (auto arg : argnums) {
needs_lhs_vjp = needs_lhs_vjp || (arg == 0);
needs_rhs_vjp = needs_rhs_vjp || (arg == 1);
needs_lhs_mask_vjp = needs_lhs_mask_vjp || (arg == op_mask_idx);
needs_rhs_mask_vjp = needs_rhs_mask_vjp || (arg == op_mask_idx + 1);
}
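// Illustrative example: with both mask kinds present (op_mask_idx == 3),
// argnums = {0, 3} sets needs_lhs_vjp and needs_lhs_mask_vjp, so the lhs
// gradient below is first materialized unmasked (to serve the mask vjp)
// and only then multiplied by the expanded lhs mask.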
if ((needs_lhs_mask_vjp && primals[op_mask_idx].dtype() == bool_) ||
(needs_rhs_mask_vjp && primals[op_mask_idx + 1].dtype() == bool_)) {
throw std::invalid_argument(
"[BlockMaskedMM] Cannot calculate VJP with respect to boolean masks.");
}
auto expand_mask = [&](array mask, int Y, int X) {
// Expand mask
auto mask_reshape = mask.shape();
mask = expand_dims(mask, {-3, -1}, stream());
auto mask_shape = mask.shape();
int mask_ndim = mask_shape.size();
// Broadcast mask
mask_shape[mask_ndim - 1] = block_size_;
mask_shape[mask_ndim - 3] = block_size_;
mask = broadcast_to(mask, mask_shape, stream());
// Reshape mask to squeeze in broadcast dims
mask_ndim = mask_reshape.size();
mask_reshape[mask_ndim - 2] *= block_size_;
mask_reshape[mask_ndim - 1] *= block_size_;
mask = reshape(mask, mask_reshape, stream());
// Slice mask
mask_reshape[mask_ndim - 2] = Y;
mask_reshape[mask_ndim - 1] = X;
mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
return mask;
};
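// Illustrative example: with block_size_ = 32, a [2, 2] mask is expanded to
// [2, 1, 2, 1], broadcast to [2, 32, 2, 32], reshaped to [64, 64], and
// finally sliced to [Y, X] (e.g. [60, 60]) to match the operand shape.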
array zero = array(0, cotan.dtype());
auto multiply_pad_reduce = [&](array p, array q, int align_Y, int align_X) {
// Multiply with cotan
auto r = multiply(p, q, stream());
// Pad if needed
if ((align_Y != 0) || (align_X != 0)) {
r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
}
// Reshape
std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
r_reshape.push_back(r.shape(-2) / block_size_);
r_reshape.push_back(block_size_);
r_reshape.push_back(r.shape(-1) / block_size_);
r_reshape.push_back(block_size_);
r = reshape(r, r_reshape, stream());
// Reduce
return sum(r, {-3, -1}, false, stream());
};
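// Illustrative example: with block_size_ = 32 and p, q of shape [60, 60]
// (so align_Y = align_X = 4), r is padded to [64, 64], reshaped to
// [2, 32, 2, 32], and summed over axes {-3, -1} into the [2, 2] block shape.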
// Prepare for padding if needed
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = primals[0].shape(-1);
int align_M = (block_size_ - M % block_size_) % block_size_;
int align_N = (block_size_ - N % block_size_) % block_size_;
int align_K = (block_size_ - K % block_size_) % block_size_;
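// E.g. M = 60 with block_size_ = 32 gives align_M = 4: gradients are padded
// up to MP = 64 so they can be reduced block-wise against the 2 x 2 masks.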
// Potential intermediates
auto unmasked_lhs_grad = primals[0];
auto unmasked_rhs_grad = primals[1];
bool unmasked_lhs_grad_calculated = false;
bool unmasked_rhs_grad_calculated = false;
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto b_t = transpose(primals[1], reorder, stream());
auto out_mask =
has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
auto lhs_mask = has_op_mask && !needs_lhs_mask_vjp
? std::make_optional<array>(primals[op_mask_idx])
: std::nullopt;
auto rhs_mask_t = has_op_mask
? std::make_optional<array>(
transpose(primals[op_mask_idx + 1], reorder, stream()))
: std::nullopt;
auto grad = block_masked_mm(
cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream());
if (needs_lhs_mask_vjp) {
unmasked_lhs_grad = grad;
unmasked_lhs_grad_calculated = true;
auto exp_mask = expand_mask(primals[op_mask_idx], M, K);
grad = multiply(grad, exp_mask, stream());
}
vjps.push_back(grad);
} else if (arg == 1) {
// (M X K).T * M X N -> K X N
auto a_t = transpose(primals[0], reorder, stream());
auto out_mask =
has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
auto lhs_mask_t = has_op_mask
? std::make_optional<array>(
transpose(primals[op_mask_idx], reorder, stream()))
: std::nullopt;
auto rhs_mask = has_op_mask && !needs_rhs_mask_vjp
? std::make_optional<array>(primals[op_mask_idx + 1])
: std::nullopt;
auto grad = block_masked_mm(
a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream());
if (needs_rhs_mask_vjp) {
unmasked_rhs_grad = grad;
unmasked_rhs_grad_calculated = true;
auto exp_mask = expand_mask(primals[op_mask_idx + 1], K, N);
grad = multiply(grad, exp_mask, stream());
}
vjps.push_back(grad);
} else if (arg == 2 && has_out_mask) {
// Produce the forward result
auto lhs_mask = has_op_mask
? std::make_optional<array>(primals[op_mask_idx])
: std::nullopt;
auto rhs_mask = has_op_mask
? std::make_optional<array>(primals[op_mask_idx + 1])
: std::nullopt;
auto C = block_masked_mm(
primals[0],
primals[1],
block_size_,
std::nullopt,
lhs_mask,
rhs_mask,
stream());
// Multiply, Pad and Reduce if needed
auto grad = multiply_pad_reduce(cotan, C, align_M, align_N);
vjps.push_back(grad);
} else if (arg == op_mask_idx && has_op_mask) {
if (!unmasked_lhs_grad_calculated) {
// M X N * (K X N).T -> M X K
auto b_t = transpose(primals[1], reorder, stream());
auto out_mask =
has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
auto rhs_mask_t =
transpose(primals[op_mask_idx + 1], reorder, stream());
unmasked_lhs_grad = block_masked_mm(
cotan,
b_t,
block_size_,
std::nullopt,
out_mask,
rhs_mask_t,
stream());
unmasked_lhs_grad_calculated = true;
}
// Multiply, Pad and Reduce if needed
auto grad =
multiply_pad_reduce(primals[0], unmasked_lhs_grad, align_M, align_K);
vjps.push_back(grad);
} else if (arg == op_mask_idx + 1 && has_op_mask) {
if (!unmasked_rhs_grad_calculated) {
// (M X K).T * M X N -> K X N
auto a_t = transpose(primals[0], reorder, stream());
auto out_mask =
has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
auto lhs_mask_t = transpose(primals[op_mask_idx], reorder, stream());
unmasked_rhs_grad = block_masked_mm(
a_t,
cotan,
block_size_,
std::nullopt,
lhs_mask_t,
out_mask,
stream());
unmasked_rhs_grad_calculated = true;
}
// Multiply, Pad and Reduce if needed
auto grad =
multiply_pad_reduce(primals[1], unmasked_rhs_grad, align_K, align_N);
vjps.push_back(grad);
} else {
throw std::invalid_argument(
"[BlockMaskedMM] Cannot calculate VJP with respect to masks.");