Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Simplify the jacobian as well
This commit is contained in:
@@ -5072,23 +5072,26 @@ std::vector<array> GatherMM::vjp(
|
||||
int K = primals[0].shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
}
|
||||
|
||||
} else if (arg == 1) {
|
||||
if (sorted) {
|
||||
|
||||
Reference in New Issue
Block a user