Simplify the jacobian as well

This commit is contained in:
Angelos Katharopoulos
2025-07-02 18:23:48 -07:00
parent a29fa053c6
commit a8d7b74984

View File

@@ -5072,23 +5072,26 @@ std::vector<array> GatherMM::vjp(
int K = primals[0].shape(-1);
bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto base = zeros_like(primals[0], stream());
auto bt = swapaxes(primals[1], -1, -2, stream());
auto base_shape = base.shape();
base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K)
auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
if (sorted && no_broadcast) {
vjps.push_back(g);
} else {
g = expand_dims(g, -3, stream());
auto base = zeros_like(primals[0], stream());
auto base_shape = base.shape();
base = reshape(base, {-1, M, K}, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
}
} else if (arg == 1) {
if (sorted) {