@@ -3487,42 +3487,251 @@ std::vector<array> BlockMaskedMM::vjp(
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  /////////////////////////////////////////////////////////////////////////////
  // The operation that is done w/o intermediates by the primitive is
  // - tm = (M + block_size - 1) // block_size; MP = tm * block_size;
  // - tn = (N + block_size - 1) // block_size; NP = tn * block_size;
  // - tk = (K + block_size - 1) // block_size; KP = tk * block_size;
  // - mask_b <- mask broadcasted to block sizes
  // - A_m = A [..., M, K] * mask_b_lhs [..., MP, KP]
  // - B_m = B [..., K, N] * mask_b_rhs [..., KP, NP]
  // - C = A_m [..., M, K] @ B_m [..., K, N]
  // - C_m = C [..., M, N] * mask_b_out [..., MP, NP]
  //
  // The grads are therefore
  // - dC_m = cotan [..., M, N]
  // - dmask_b_out = cotan [..., M, N] * C [..., M, N]
  // - dC = cotan [..., M, N] * mask_b_out [..., MP, NP]
  // - dA_m = dC [..., M, N] @ B_m.T [..., N, K]
  // - dB_m = A_m.T [..., K, M] @ dC [..., M, N]
  // - dA = dA_m * mask_b_lhs [..., MP, KP]
  // - dB = dB_m * mask_b_rhs [..., KP, NP]
  // - dmask_b_lhs = dA_m [..., M, K] * A [..., M, K] // need [..., MP, KP]
  // - dmask_b_rhs = dB_m [..., K, N] * B [..., K, N] // need [..., KP, NP]
  //
  // Observations:
  //  * If dmask_b_lhs is not needed, then dA can be calculated in one go as
  //    a block_masked_mm with mask_b_lhs as the out_mask, without needing to
  //    materialize the intermediate dA_m. Similarly for dB.
  //  * If dmask_b_lhs is needed, we need to materialize dA_m directly and
  //    then point-wise multiply with A. But the output needs to be padded to
  //    the block-aligned shape [..., MP, KP] and summed within each block to
  //    get the per-block mask gradient.
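  //
  // For example (hypothetical shapes, for illustration only): M = 70, N = 40,
  // K = 64 with block_size = 32 gives tm = 3, tn = 2, tk = 2, so MP = 96,
  // NP = 64, KP = 64, and mask_b_lhs is a [..., 3, 2] block mask expanded to
  // [..., 96, 64] elements.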

  std::vector<array> vjps;

  auto& cotan = cotangents[0];
  std::vector<int> reorder(cotan.ndim());
  std::iota(reorder.begin(), reorder.end(), 0);
  std::iter_swap(reorder.end() - 1, reorder.end() - 2);

  bool has_op_mask = primals.size() > 3;
  bool has_out_mask = primals.size() == 3 || primals.size() == 5;

  const int op_mask_idx = has_out_mask ? 3 : 2;
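  // The primals are laid out as {a, b}, {a, b, mask_out},
  // {a, b, mask_lhs, mask_rhs}, or {a, b, mask_out, mask_lhs, mask_rhs};
  // has_op_mask / has_out_mask / op_mask_idx decode which case applies from
  // primals.size().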

  bool needs_lhs_mask_vjp = false;
  bool needs_rhs_mask_vjp = false;
  bool needs_lhs_vjp = false;
  bool needs_rhs_vjp = false;

  for (auto arg : argnums) {
    needs_lhs_vjp |= arg == 0;
    needs_rhs_vjp |= arg == 1;
    needs_lhs_mask_vjp |= arg == op_mask_idx;
    needs_rhs_mask_vjp |= arg == op_mask_idx + 1;
  }

  if ((needs_lhs_mask_vjp && primals[op_mask_idx].dtype() == bool_) ||
      (needs_rhs_mask_vjp && primals[op_mask_idx + 1].dtype() == bool_)) {
    throw std::invalid_argument(
        "[BlockMaskedMM] Cannot calculate VJP with respect to boolean masks.");
  }

  auto expand_mask = [&](array mask, int Y, int X) {
    // Expand mask
    auto mask_reshape = mask.shape();
    mask = expand_dims(mask, {-3, -1}, stream());
    auto mask_shape = mask.shape();
    int mask_ndim = mask_shape.size();

    // Broadcast mask
    mask_shape[mask_ndim - 1] = block_size_;
    mask_shape[mask_ndim - 3] = block_size_;
    mask = broadcast_to(mask, mask_shape, stream());

    // Reshape mask to squeeze in broadcasted dims
    mask_ndim = mask_reshape.size();
    mask_reshape[mask_ndim - 2] *= block_size_;
    mask_reshape[mask_ndim - 1] *= block_size_;
    mask = reshape(mask, mask_reshape, stream());

    // Slice mask
    mask_reshape[mask_ndim - 2] = Y;
    mask_reshape[mask_ndim - 1] = X;
    mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());

    return mask;
  };
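
  // For instance, expand_mask takes a block-level mask [..., t_y, t_x],
  // expands it to [..., t_y, 1, t_x, 1], broadcasts to
  // [..., t_y, block_size_, t_x, block_size_], reshapes to the element-level
  // [..., t_y * block_size_, t_x * block_size_], and slices down to the
  // requested [..., Y, X].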

  array zero = array(0, cotan.dtype());

  auto multiply_pad_reduce = [&](array p, array q, int align_Y, int align_X) {
    // Multiply element-wise
    auto r = multiply(p, q, stream());

    // Pad if needed
    if ((align_Y != 0) || (align_X != 0)) {
      r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
    }

    // Reshape into blocks
    std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
    r_reshape.push_back(r.shape(-2) / block_size_);
    r_reshape.push_back(block_size_);
    r_reshape.push_back(r.shape(-1) / block_size_);
    r_reshape.push_back(block_size_);
    r = reshape(r, r_reshape, stream());

    // Reduce over the within-block axes
    return sum(r, {-3, -1}, false, stream());
  };
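
  // multiply_pad_reduce(p, q, align_Y, align_X) therefore yields the
  // per-block sum of p * q: it pads the product up to the block-aligned
  // shape and reduces each block_size_ x block_size_ tile to a scalar,
  // matching the block mask's shape. This is exactly the "pad and reduce"
  // step the observations above call for when a mask gradient is needed.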

  // Prepare for padding if needed
  int M = cotan.shape(-2);
  int N = cotan.shape(-1);
  int K = primals[0].shape(-1);
  int tm = (M + block_size_ - 1) / block_size_;
  int tn = (N + block_size_ - 1) / block_size_;
  int tk = (K + block_size_ - 1) / block_size_;
  int align_M = tm * block_size_ - M;
  int align_N = tn * block_size_ - N;
  int align_K = tk * block_size_ - K;
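  // E.g. with the hypothetical M = 70 and block_size_ = 32 from above,
  // tm = 3 and align_M = 96 - 70 = 26 rows of zero padding.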

  // Potential intermediates
  auto unmasked_lhs_grad = primals[0];
  auto unmasked_rhs_grad = primals[1];

  bool unmasked_lhs_grad_calculated = false;
  bool unmasked_rhs_grad_calculated = false;

  for (auto arg : argnums) {
    if (arg == 0) {
      // M X N * (K X N).T -> M X K
      auto b_t = transpose(primals[1], reorder, stream());
      auto out_mask =
          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
      auto lhs_mask = has_op_mask && !needs_lhs_mask_vjp
          ? std::make_optional<array>(primals[op_mask_idx])
          : std::nullopt;
      auto rhs_mask_t = has_op_mask
          ? std::make_optional<array>(
                transpose(primals[op_mask_idx + 1], reorder, stream()))
          : std::nullopt;

      auto grad = block_masked_mm(
          cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream());

      if (needs_lhs_mask_vjp) {
        unmasked_lhs_grad = grad;
        unmasked_lhs_grad_calculated = true;
        auto exp_mask = expand_mask(primals[op_mask_idx], M, K);
        grad = multiply(grad, exp_mask, stream());
      }

      vjps.push_back(grad);
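
      // Per the first observation above, dA = cotan @ B.T is a single
      // block_masked_mm whose output mask is the lhs mask; when the lhs mask
      // gradient is also needed, lhs_mask is withheld from the call so the
      // unmasked dA_m can be reused below.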

    } else if (arg == 1) {
      // (M X K).T * M X N -> K X N
      auto a_t = transpose(primals[0], reorder, stream());
      auto out_mask =
          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
      auto lhs_mask_t = has_op_mask
          ? std::make_optional<array>(
                transpose(primals[op_mask_idx], reorder, stream()))
          : std::nullopt;
      auto rhs_mask = has_op_mask && !needs_rhs_mask_vjp
          ? std::make_optional<array>(primals[op_mask_idx + 1])
          : std::nullopt;

      auto grad = block_masked_mm(
          a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream());

      if (needs_rhs_mask_vjp) {
        unmasked_rhs_grad = grad;
        unmasked_rhs_grad_calculated = true;
        auto exp_mask = expand_mask(primals[op_mask_idx + 1], K, N);
        grad = multiply(grad, exp_mask, stream());
      }

      vjps.push_back(grad);
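
      // Mirror of the arg == 0 case: dB = A.T @ cotan with the rhs mask as
      // the output mask, withheld when the rhs mask gradient is also needed.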

    } else if (arg == 2 && has_out_mask) {
      // Produce the forward result
      auto lhs_mask = has_op_mask
          ? std::make_optional<array>(primals[op_mask_idx])
          : std::nullopt;
      auto rhs_mask = has_op_mask
          ? std::make_optional<array>(primals[op_mask_idx + 1])
          : std::nullopt;

      auto C = block_masked_mm(
          primals[0],
          primals[1],
          block_size_,
          std::nullopt, // dmask_b_out = cotan * C wants C before the out mask
          lhs_mask,
          rhs_mask,
          stream());

      // Multiply, Pad and Reduce if needed
      auto grad = multiply_pad_reduce(cotan, C, align_M, align_N);
      vjps.push_back(grad);
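
      // This is dmask_b_out = cotan * C from the derivation above, padded
      // and summed per block when M or N is not block-aligned.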

    } else if (arg == op_mask_idx && has_op_mask) {
      if (!unmasked_lhs_grad_calculated) {
        // M X N * (K X N).T -> M X K
        auto b_t = transpose(primals[1], reorder, stream());
        auto out_mask =
            has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
        auto rhs_mask_t =
            transpose(primals[op_mask_idx + 1], reorder, stream());

        unmasked_lhs_grad = block_masked_mm(
            cotan,
            b_t,
            block_size_,
            std::nullopt,
            out_mask,
            rhs_mask_t,
            stream());

        unmasked_lhs_grad_calculated = true;
      }

      // Multiply, Pad and Reduce if needed
      auto grad =
          multiply_pad_reduce(primals[0], unmasked_lhs_grad, align_M, align_K);
      vjps.push_back(grad);
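
      // This realizes dmask_b_lhs = dA_m * A reduced per block; dA_m was
      // materialized without the lhs output mask so that every block,
      // including masked-out ones, receives its true gradient.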

    } else if (arg == op_mask_idx + 1 && has_op_mask) {
      if (!unmasked_rhs_grad_calculated) {
        // (M X K).T * M X N -> K X N
        auto a_t = transpose(primals[0], reorder, stream());
        auto out_mask =
            has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
        auto lhs_mask_t = transpose(primals[op_mask_idx], reorder, stream());

        unmasked_rhs_grad = block_masked_mm(
            a_t,
            cotan,
            block_size_,
            std::nullopt,
            lhs_mask_t,
            out_mask,
            stream());

        unmasked_rhs_grad_calculated = true;
      }

      // Multiply, Pad and Reduce if needed
      auto grad =
          multiply_pad_reduce(primals[1], unmasked_rhs_grad, align_K, align_N);
      vjps.push_back(grad);

    } else {
      throw std::invalid_argument(
          "[BlockMaskedMM] Cannot calculate VJP with respect to masks.");
    }
  }

  return vjps;
}