diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4af812570..b6860c9d4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3253,34 +3253,41 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = x.shape(-1); + bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == x.size(); for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { - vjps.push_back(reshape( - scatter_add( - flatten(zeros_like(x, stream()), 0, -3, stream()), - lhs_indices, - expand_dims( - gather_qmm( - cotan, - w, - scales, - biases, - std::nullopt, - rhs_indices, - !transpose_, - group_size_, - bits_, - sorted, - stream()), - -3, - stream()), - 0, - stream()), - x.shape(), - stream())); + auto g = gather_qmm( + cotan, + w, + scales, + biases, + std::nullopt, + rhs_indices, + !transpose_, + group_size_, + bits_, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(x, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + x.shape(), + stream())); + } } // gradient wrt to the indices is undefined @@ -5064,6 +5071,8 @@ std::vector GatherMM::vjp( std::vector vjps; auto& cotan = cotangents[0]; + auto& a = primals[0]; + auto& b = primals[1]; auto& lhs_indices = primals[2]; auto& rhs_indices = primals[3]; @@ -5076,23 +5085,26 @@ std::vector GatherMM::vjp( for (auto arg : argnums) { if (arg == 0) { - // M X N * (K X N).T -> M X K - auto bt = swapaxes(primals[1], -1, -2, stream()); - - // g : (out_batch_shape) + (M, K) - auto g = - gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); + auto g = gather_mm( + cotan, + swapaxes(b, -1, -2, stream()), + std::nullopt, + rhs_indices, + sorted, + stream()); if (sorted && no_broadcast) { vjps.push_back(g); } else { - g = expand_dims(g, -3, stream()); - auto base = zeros_like(primals[0], stream()); - auto base_shape = base.shape(); - base = reshape(base, {-1, M, K}, stream()); - auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); - vjps.push_back(reshape(gacc, base_shape, stream())); + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(a, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + a.shape(), + stream())); } - } else if (arg == 1) { if (sorted) { // Make the segments based on the rhs_indices