From 3d4174cd371ec330fda15f9c129d09e52726e065 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 5 Jul 2025 00:58:17 -0700
Subject: [PATCH] Add gradient for the scales and biases in gather qmm

---
 mlx/primitives.cpp | 161 +++++++++++++++++++++++++++++++++------------
 1 file changed, 118 insertions(+), 43 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index b6860c9d4..3380fa08b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
   return {a, b, c, to_ax};
 }
 
+// Calculate the gradient wrt the weights of the following calculation
+//
+//   y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
+//
+// Note the transpose above. This function returns the gradient for w.T, so
+// if w was used instead, one needs to transpose the returned gradient.
+//
+// We define it as a separate function so it can be reused by both gather_mm
+// and gather_qmm.
+array gather_mm_grad(
+    const array& x,
+    const array& dy,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    bool sorted,
+    Shape batch_shape,
+    const Stream& s) {
+  int M = x.shape(-2);
+  int K = x.shape(-1);
+  int N = dy.shape(-1);
+  int num_segments = std::accumulate(
+      batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
+  batch_shape.push_back(N);
+  batch_shape.push_back(K);
+
+  // If the indices are sorted then we can do the whole gradient computation
+  // via a single segmented matmul. We just need to calculate the segment
+  // boundaries from the indices.
+  if (sorted) {
+    auto segments = zeros({num_segments}, uint32, s);
+    segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
+    segments = cumsum(segments, 0, false, true, s);
+    segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
+    segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
+
+    return reshape(
+        segmented_mm(
+            swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
+            flatten(x, 0, -2, s),
+            segments,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+
+  // Otherwise we need to gather-matmul dy and then scatter-add the result to
+  // the correct locations.
+  else {
+    // TODO: If the lhs indices weren't provided, this is always a sorted
+    // matmul so we should add that check.
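+    // Each output batch element i contributes dy_i.T @ x[lhs_indices[i]],
+    // computed for all i at once with a batched gather_mm; the products are
+    // then scatter-added into a zeros base at rhs_indices[i] along axis 0.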
+    auto dw = gather_mm(
+        swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
+    return reshape(
+        scatter_add(
+            zeros({num_segments, N, K}, dw.dtype(), s),
+            rhs_indices,
+            expand_dims(dw, -3, s),
+            0,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+}
+
 } // namespace
 
 std::vector<array> Primitive::jvp(
@@ -3181,7 +3245,6 @@ std::vector<array> QuantizedMatmul::vjp(
       vjps.push_back(sum(*dsb, -1, false, stream()));
     } else {
       // scales
-      auto s = stream();
       auto wq = dequantize(
           primals[1],
           ones_like(primals[2], stream()),
@@ -3259,6 +3322,7 @@ std::vector<array> GatherQMM::vjp(
   bool sorted = left_sorted_ || right_sorted_;
   bool no_broadcast = rhs_indices.size() * M * K == x.size();
 
+  std::optional<array> dsb = std::nullopt;
   for (auto arg : argnums) {
     // gradient wrt to x
@@ -3297,9 +3361,45 @@
     }
 
     // gradient wrt to w_q, scales or biases
-    else {
+    else if (arg == 1) {
       throw std::runtime_error(
-          "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
+          "GatherQMM::vjp no gradient wrt the quantized weights.");
+    } else {
+      if (!dsb) {
+        auto shape = w.shape();
+        shape.pop_back();
+        shape.pop_back();
+        dsb = unflatten(
+            gather_mm_grad(
+                x,
+                cotan,
+                lhs_indices,
+                rhs_indices,
+                sorted,
+                std::move(shape),
+                stream()),
+            -1,
+            {-1, group_size_},
+            stream());
+      }
+      if (arg == 3) {
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
+        vjps.push_back(
+            sum(multiply(
+                    *dsb,
+                    dequantize(
+                        w,
+                        ones_like(scales, stream()),
+                        zeros_like(biases, stream()),
+                        group_size_,
+                        bits_,
+                        stream()),
+                    stream()),
+                -1,
+                false,
+                stream()));
+      }
     }
   }
   return vjps;
@@ -5106,46 +5206,21 @@ std::vector<array> GatherMM::vjp(
               stream()));
       }
     } else if (arg == 1) {
-      if (sorted) {
-        // Make the segments based on the rhs_indices
-        int num_segments = primals[1].size() / K / N;
-        auto segments = zeros({num_segments}, uint32, stream());
-        segments = scatter_add_axis(
-            segments, rhs_indices, array(M, uint32), 0, stream());
-        segments = cumsum(segments, 0, false, true, stream());
-        segments =
-            concatenate({array({0}, {1}, uint32), segments}, 0, stream());
-        segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());
-
-        // Reshape and transpose the inputs such that they are a big segmented
-        // matmul.
-        auto a = reshape(primals[0], {-1, K}, stream());
-        auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());
-
-        // Calculate the gradient.
-        // Since the gather mm is often used as x @ w.T we will calculate the
-        // gradient as c @ a and transpose it before returning it which should
-        // save a copy in that case.
-        auto g = segmented_mm(c, a, segments, stream());
-        g = swapaxes(g, 1, 2, stream());
-
-        vjps.push_back(reshape(g, primals[1].shape(), stream()));
-      } else {
-        // (M X K).T * M X N -> K X N
-        auto base = zeros_like(primals[1], stream());
-        auto at = swapaxes(primals[0], -1, -2, stream());
-
-        auto base_shape = base.shape();
-        base = reshape(base, {-1, K, N}, stream());
-
-        // g : (out_batch_shape) + (K, N)
-        auto g =
-            gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
-        g = expand_dims(g, -3, stream());
-        auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
-
-        vjps.push_back(reshape(gacc, base_shape, stream()));
-      }
+      auto shape = b.shape();
+      shape.pop_back();
+      shape.pop_back();
+      vjps.push_back(swapaxes(
+          gather_mm_grad(
+              a,
+              cotan,
+              lhs_indices,
+              rhs_indices,
+              sorted,
+              std::move(shape),
+              stream()),
+          -1,
+          -2,
+          stream()));
     } else {
       throw std::invalid_argument(
           "[GatherMM] Cannot calculate VJP with respect to indices.");
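
Why sorted indices enable the segmented path: with y_i = x_i @ w[rhs_indices[i]].T, the weight cotangent is dw[j].T = sum of dy_i.T @ x_i over every i with rhs_indices[i] == j. When rhs_indices is sorted, the rows routed to expert j occupy one contiguous block of the flattened dy and x, so each dw[j].T is a single matmul over a row segment, which is what segmented_mm evaluates using the [start, end) pairs built by the cumsum/as_strided sequence above. The following standalone sketch checks that equivalence in plain C++ (hypothetical helper names, standard library only, not the mlx API):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Minimal row-major dense matrix helpers, only for this demo.
using Mat = std::vector<std::vector<double>>;

Mat zeros_mat(int r, int c) { return Mat(r, std::vector<double>(c, 0.0)); }

// out += a.T @ b, where a is m x n and b is m x k, so out is n x k.
void add_at_b(Mat& out, const Mat& a, const Mat& b) {
  int m = a.size(), n = a[0].size(), k = b[0].size();
  for (int i = 0; i < m; i++)
    for (int j = 0; j < n; j++)
      for (int l = 0; l < k; l++)
        out[j][l] += a[i][j] * b[i][l];
}

int main() {
  // 4 batch elements routed to 2 experts with *sorted* indices.
  int M = 2, K = 3, N = 2, num_experts = 2;
  std::vector<int> rhs_indices = {0, 0, 1, 1};
  int B = rhs_indices.size();

  // Arbitrary x (B x M x K) and upstream gradient dy (B x M x N).
  std::vector<Mat> x(B), dy(B);
  for (int i = 0; i < B; i++) {
    x[i] = zeros_mat(M, K);
    dy[i] = zeros_mat(M, N);
    for (int r = 0; r < M; r++) {
      for (int c = 0; c < K; c++) x[i][r][c] = 0.1 * (i + r + c + 1);
      for (int c = 0; c < N; c++) dy[i][r][c] = 0.2 * (i * r - c + 1);
    }
  }

  // Path 1: scatter-add. dwt[j] += dy_i.T @ x_i for every i routed to j.
  std::vector<Mat> dwt_scatter(num_experts, zeros_mat(N, K));
  for (int i = 0; i < B; i++) add_at_b(dwt_scatter[rhs_indices[i]], dy[i], x[i]);

  // Path 2: segmented. Flatten dy and x to (B*M) rows; each expert's gradient
  // is one matmul over the contiguous row segment [seg[j], seg[j+1]).
  Mat dy_flat = zeros_mat(B * M, N), x_flat = zeros_mat(B * M, K);
  for (int i = 0; i < B; i++)
    for (int r = 0; r < M; r++) {
      dy_flat[i * M + r] = dy[i][r];
      x_flat[i * M + r] = x[i][r];
    }
  // Segment boundaries: cumulative row counts per expert (indices are sorted,
  // and each index contributes M rows, mirroring the scatter_add_axis/cumsum).
  std::vector<int> seg(num_experts + 1, 0);
  for (int i = 0; i < B; i++) seg[rhs_indices[i] + 1] += M;
  for (int j = 0; j < num_experts; j++) seg[j + 1] += seg[j];

  for (int j = 0; j < num_experts; j++) {
    Mat dwt = zeros_mat(N, K);
    Mat a(dy_flat.begin() + seg[j], dy_flat.begin() + seg[j + 1]);
    Mat b(x_flat.begin() + seg[j], x_flat.begin() + seg[j + 1]);
    add_at_b(dwt, a, b);  // dwt = dy_seg.T @ x_seg
    for (int r = 0; r < N; r++)
      for (int c = 0; c < K; c++)
        assert(std::abs(dwt[r][c] - dwt_scatter[j][r][c]) < 1e-12);
  }
  std::puts("segmented and scatter-add gradients match");
}

This mirrors why the patch flattens dy and x to 2-D and builds segments in units of M rows per index before calling segmented_mm.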
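
For the scales and biases, the gradient follows from the dequantization formula itself: within each quantization group of group_size_ elements, w_hat = s * w_q + b with a single (s, b) pair per group. By the chain rule, dL/db is the sum of dL/dw_hat over the group (the arg == 3 branch, sum(*dsb, -1)), and dL/ds is the sum of dL/dw_hat * w_q over the group (the other branch, where dequantize with ones_like(scales) and zeros_like(biases) reconstructs the raw w_q values in floating point, the same trick QuantizedMatmul::vjp already uses). The unflatten call exposes the group axis so both reductions become a sum over the last dimension. A minimal numeric sketch of this per-group reduction, in plain C++ with made-up values rather than the mlx API:

#include <cstdio>
#include <vector>

int main() {
  // One row of already-dequantized integer weights: 8 elements,
  // group_size = 4, so 2 quantization groups.
  const int group_size = 4;
  std::vector<double> w_q = {1, 3, 0, 2, 7, 5, 6, 4};
  std::vector<double> dL_dw = {0.1, -0.2, 0.3, 0.4, -0.1, 0.2, 0.5, -0.3};

  const int num_groups = w_q.size() / group_size;
  std::vector<double> d_scales(num_groups, 0.0), d_biases(num_groups, 0.0);

  // w_hat[i] = s[g] * w_q[i] + b[g] for i in group g, hence by the chain rule:
  //   dL/db[g] = sum_i dL/dw_hat[i]           (the arg == 3 branch)
  //   dL/ds[g] = sum_i dL/dw_hat[i] * w_q[i]  (the scales branch)
  for (int g = 0; g < num_groups; g++) {
    for (int i = g * group_size; i < (g + 1) * group_size; i++) {
      d_biases[g] += dL_dw[i];
      d_scales[g] += dL_dw[i] * w_q[i];
    }
  }

  for (int g = 0; g < num_groups; g++)
    std::printf("group %d: dL/ds = %g, dL/db = %g\n",
                g, d_scales[g], d_biases[g]);
}

Caching dsb in a std::optional outside the argnums loop means the shared gather_mm_grad result is computed once even when gradients for both scales and biases are requested.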