From a8d7b749848ac49a54717be17e168393ec5a8cb2 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 2 Jul 2025 18:23:48 -0700 Subject: [PATCH] Simplify the jacobian as well --- mlx/primitives.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index d63d23a96..4af812570 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -5072,23 +5072,26 @@ std::vector GatherMM::vjp( int K = primals[0].shape(-1); bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == primals[0].size(); for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K - auto base = zeros_like(primals[0], stream()); auto bt = swapaxes(primals[1], -1, -2, stream()); - auto base_shape = base.shape(); - base = reshape(base, {-1, M, K}, stream()); - // g : (out_batch_shape) + (M, K) auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + g = expand_dims(g, -3, stream()); + auto base = zeros_like(primals[0], stream()); + auto base_shape = base.shape(); + base = reshape(base, {-1, M, K}, stream()); + auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); + vjps.push_back(reshape(gacc, base_shape, stream())); + } } else if (arg == 1) { if (sorted) {