From a29fa053c6d2823b29caa666dd383586a740c5d6 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 2 Jul 2025 16:24:21 -0700 Subject: [PATCH] Use segmented_mm to calculate the MoE gradient --- mlx/backend/cpu/masked_mm.cpp | 15 +++--- .../steel/gemm/kernels/steel_gemm_segmented.h | 24 +++++----- mlx/ops.cpp | 1 + mlx/primitives.cpp | 48 ++++++++++++++----- 4 files changed, 57 insertions(+), 31 deletions(-) diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 6fcf25b15..cd0680131 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -57,7 +57,7 @@ template inline void segmented_mm( const T* a, const T* b, - const int32_t* segments, + const uint32_t* segments, T* out, bool a_transposed, bool b_transposed, @@ -75,11 +75,10 @@ inline void segmented_mm( Shape b_copy = b_shape; int32_t M = a_copy[ndim - 2]; int32_t N = b_copy[ndim - 1]; - int32_t k_start = 0; for (int i = 0; i < num_segments; i++) { - int32_t k_start = + uint32_t k_start = segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; - int32_t k_end = + uint32_t k_end = segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; a_copy[ndim - 1] = k_end - k_start; b_copy[ndim - 2] = k_end - k_start; @@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { segmented_mm( a.data(), b.data(), - segments.data(), + segments.data(), static_cast(out_ptr), a_transposed, b_transposed, @@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { segmented_mm( a.data(), b.data(), - segments.data(), + segments.data(), static_cast(out_ptr), a_transposed, b_transposed, @@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { segmented_mm( a.data(), b.data(), - segments.data(), + segments.data(), static_cast(out_ptr), a_transposed, b_transposed, @@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { segmented_mm( a.data(), b.data(), - segments.data(), + segments.data(), static_cast(out_ptr), a_transposed, b_transposed, diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h index d1258f0a9..b915eb343 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h @@ -19,7 +19,7 @@ template < [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], - const device int32_t* segments [[buffer(2)]], + const device uint32_t* segments [[buffer(2)]], device T* C [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], uint simd_lane_id [[thread_index_in_simdgroup]], @@ -68,7 +68,7 @@ template < C += c_row_long * params->ldd + c_col_long; // Move the pointers to the start of the segment - int32_t k_start, k_end; + uint32_t k_start, k_end; if (segments_contiguous) { k_start = segments[2 * tid.z]; k_end = segments[2 * tid.z + 1]; @@ -92,7 +92,7 @@ template < // Matrix level alignment so only check K if (align_M && align_N) { - int k = k_start + BK; + uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -109,7 +109,7 @@ template < loader_a.next(); loader_b.next(); } - short k_remain = k_end - (k - BK); + short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = @@ -125,7 +125,7 @@ template < } else { // Tile aligned do the same as above if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - int k = k_start + BK; + uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -142,7 +142,7 @@ template < loader_a.next(); loader_b.next(); } - short k_remain = k_end - (k - BK); + short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = @@ -159,7 +159,7 @@ template < // Tile partially aligned check rows else if (align_N || tgp_bn == BN) { - int k = k_start + BK; + uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -177,7 +177,7 @@ template < loader_a.next(); loader_b.next(); } - short k_remain = k_end - (k - BK); + short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = @@ -194,7 +194,7 @@ template < // Tile partially aligned check cols else if (align_M || tgp_bm == BM) { - int k = k_start + BK; + uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -212,7 +212,7 @@ template < loader_a.next(); loader_b.next(); } - short k_remain = k_end - (k - BK); + short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = @@ -229,7 +229,7 @@ template < // Nothing aligned so check both rows and cols else { - int k = k_start + BK; + uint32_t k = k_start + BK; for (; k <= k_end; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -248,7 +248,7 @@ template < loader_a.next(); loader_b.next(); } - short k_remain = k_end - (k - BK); + short k_remain = BK - short(k - k_end); const short2 tile_dims_A = transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); const short2 tile_dims_B = diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e35241c03..255b0307c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4678,6 +4678,7 @@ array segmented_mm( a = astype(a, out_type, s); b = astype(b, out_type, s); + segments = astype(segments, uint32, s); Shape out_shape = segments.shape(); out_shape.pop_back(); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 5f2bfdda4..d63d23a96 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -5091,20 +5091,46 @@ std::vector GatherMM::vjp( vjps.push_back(reshape(gacc, base_shape, stream())); } else if (arg == 1) { - // (M X K).T * M X N -> K X N - auto base = zeros_like(primals[1], stream()); - auto at = swapaxes(primals[0], -1, -2, stream()); + if (sorted) { + // Make the segments based on the rhs_indices + int num_segments = primals[1].size() / K / N; + auto segments = zeros({num_segments}, uint32, stream()); + segments = scatter_add_axis( + segments, rhs_indices, array(M, uint32), 0, stream()); + segments = cumsum(segments, 0, false, true, stream()); + segments = + concatenate({array({0}, {1}, uint32), segments}, 0, stream()); + segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream()); - auto base_shape = base.shape(); - base = reshape(base, {-1, K, N}, stream()); + // Reshape and transpose the inputs such that they are a big segmented + // matmul. + auto a = reshape(primals[0], {-1, K}, stream()); + auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream()); - // g : (out_batch_shape) + (K, N) - auto g = - gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); + // Calculate the gradient. + // Since the gather mm is often used as x @ w.T we will calculate the + // gradient as c @ a and transpose it before returning it which should + // save a copy in that case. + auto g = segmented_mm(c, a, segments, stream()); + g = swapaxes(g, 1, 2, stream()); - vjps.push_back(reshape(gacc, base_shape, stream())); + vjps.push_back(reshape(g, primals[1].shape(), stream())); + } else { + // (M X K).T * M X N -> K X N + auto base = zeros_like(primals[1], stream()); + auto at = swapaxes(primals[0], -1, -2, stream()); + + auto base_shape = base.shape(); + base = reshape(base, {-1, K, N}, stream()); + + // g : (out_batch_shape) + (K, N) + auto g = + gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); + g = expand_dims(g, -3, stream()); + auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); + + vjps.push_back(reshape(gacc, base_shape, stream())); + } } else { throw std::invalid_argument( "[GatherMM] Cannot calculate VJP with respect to indices.");