Use segmented_mm to calculate the MoE gradient

This commit is contained in:
Angelos Katharopoulos
2025-07-02 16:24:21 -07:00
parent 8f771efb82
commit a29fa053c6
4 changed files with 57 additions and 31 deletions

View File

@@ -57,7 +57,7 @@ template <typename T>
inline void segmented_mm( inline void segmented_mm(
const T* a, const T* a,
const T* b, const T* b,
const int32_t* segments, const uint32_t* segments,
T* out, T* out,
bool a_transposed, bool a_transposed,
bool b_transposed, bool b_transposed,
@@ -75,11 +75,10 @@ inline void segmented_mm(
Shape b_copy = b_shape; Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2]; int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1]; int32_t N = b_copy[ndim - 1];
int32_t k_start = 0;
for (int i = 0; i < num_segments; i++) { for (int i = 0; i < num_segments; i++) {
int32_t k_start = uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
int32_t k_end = uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
a_copy[ndim - 1] = k_end - k_start; a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start; b_copy[ndim - 2] = k_end - k_start;
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<double>( segmented_mm<double>(
a.data<double>(), a.data<double>(),
b.data<double>(), b.data<double>(),
segments.data<int32_t>(), segments.data<uint32_t>(),
static_cast<double*>(out_ptr), static_cast<double*>(out_ptr),
a_transposed, a_transposed,
b_transposed, b_transposed,
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float>( segmented_mm<float>(
a.data<float>(), a.data<float>(),
b.data<float>(), b.data<float>(),
segments.data<int32_t>(), segments.data<uint32_t>(),
static_cast<float*>(out_ptr), static_cast<float*>(out_ptr),
a_transposed, a_transposed,
b_transposed, b_transposed,
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float16_t>( segmented_mm<float16_t>(
a.data<float16_t>(), a.data<float16_t>(),
b.data<float16_t>(), b.data<float16_t>(),
segments.data<int32_t>(), segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr), static_cast<float16_t*>(out_ptr),
a_transposed, a_transposed,
b_transposed, b_transposed,
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<bfloat16_t>( segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(), a.data<bfloat16_t>(),
b.data<bfloat16_t>(), b.data<bfloat16_t>(),
segments.data<int32_t>(), segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr), static_cast<bfloat16_t*>(out_ptr),
a_transposed, a_transposed,
b_transposed, b_transposed,

View File

@@ -19,7 +19,7 @@ template <
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
const device T* A [[buffer(0)]], const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]], const device T* B [[buffer(1)]],
const device int32_t* segments [[buffer(2)]], const device uint32_t* segments [[buffer(2)]],
device T* C [[buffer(3)]], device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]], const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
@@ -68,7 +68,7 @@ template <
C += c_row_long * params->ldd + c_col_long; C += c_row_long * params->ldd + c_col_long;
// Move the pointers to the start of the segment // Move the pointers to the start of the segment
int32_t k_start, k_end; uint32_t k_start, k_end;
if (segments_contiguous) { if (segments_contiguous) {
k_start = segments[2 * tid.z]; k_start = segments[2 * tid.z];
k_end = segments[2 * tid.z + 1]; k_end = segments[2 * tid.z + 1];
@@ -92,7 +92,7 @@ template <
// Matrix level alignment so only check K // Matrix level alignment so only check K
if (align_M && align_N) { if (align_M && align_N) {
int k = k_start + BK; uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) { for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -109,7 +109,7 @@ template <
loader_a.next(); loader_a.next();
loader_b.next(); loader_b.next();
} }
short k_remain = k_end - (k - BK); short k_remain = BK - short(k - k_end);
const short2 tile_dims_A = const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B = const short2 tile_dims_B =
@@ -125,7 +125,7 @@ template <
} else { } else {
// Tile aligned do the same as above // Tile aligned do the same as above
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
int k = k_start + BK; uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) { for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -142,7 +142,7 @@ template <
loader_a.next(); loader_a.next();
loader_b.next(); loader_b.next();
} }
short k_remain = k_end - (k - BK); short k_remain = BK - short(k - k_end);
const short2 tile_dims_A = const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B = const short2 tile_dims_B =
@@ -159,7 +159,7 @@ template <
// Tile partially aligned check rows // Tile partially aligned check rows
else if (align_N || tgp_bn == BN) { else if (align_N || tgp_bn == BN) {
int k = k_start + BK; uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) { for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -177,7 +177,7 @@ template <
loader_a.next(); loader_a.next();
loader_b.next(); loader_b.next();
} }
short k_remain = k_end - (k - BK); short k_remain = BK - short(k - k_end);
const short2 tile_dims_A = const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B = const short2 tile_dims_B =
@@ -194,7 +194,7 @@ template <
// Tile partially aligned check cols // Tile partially aligned check cols
else if (align_M || tgp_bm == BM) { else if (align_M || tgp_bm == BM) {
int k = k_start + BK; uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) { for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -212,7 +212,7 @@ template <
loader_a.next(); loader_a.next();
loader_b.next(); loader_b.next();
} }
short k_remain = k_end - (k - BK); short k_remain = BK - short(k - k_end);
const short2 tile_dims_A = const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B = const short2 tile_dims_B =
@@ -229,7 +229,7 @@ template <
// Nothing aligned so check both rows and cols // Nothing aligned so check both rows and cols
else { else {
int k = k_start + BK; uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) { for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -248,7 +248,7 @@ template <
loader_a.next(); loader_a.next();
loader_b.next(); loader_b.next();
} }
short k_remain = k_end - (k - BK); short k_remain = BK - short(k - k_end);
const short2 tile_dims_A = const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B = const short2 tile_dims_B =

View File

@@ -4678,6 +4678,7 @@ array segmented_mm(
a = astype(a, out_type, s); a = astype(a, out_type, s);
b = astype(b, out_type, s); b = astype(b, out_type, s);
segments = astype(segments, uint32, s);
Shape out_shape = segments.shape(); Shape out_shape = segments.shape();
out_shape.pop_back(); out_shape.pop_back();

View File

@@ -5091,6 +5091,31 @@ std::vector<array> GatherMM::vjp(
vjps.push_back(reshape(gacc, base_shape, stream())); vjps.push_back(reshape(gacc, base_shape, stream()));
} else if (arg == 1) { } else if (arg == 1) {
if (sorted) {
// Make the segments based on the rhs_indices
int num_segments = primals[1].size() / K / N;
auto segments = zeros({num_segments}, uint32, stream());
segments = scatter_add_axis(
segments, rhs_indices, array(M, uint32), 0, stream());
segments = cumsum(segments, 0, false, true, stream());
segments =
concatenate({array({0}, {1}, uint32), segments}, 0, stream());
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());
// Reshape and transpose the inputs such that they are a big segmented
// matmul.
auto a = reshape(primals[0], {-1, K}, stream());
auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());
// Calculate the gradient.
// Since the gather mm is often used as x @ w.T we will calculate the
// gradient as c @ a and transpose it before returning it which should
// save a copy in that case.
auto g = segmented_mm(c, a, segments, stream());
g = swapaxes(g, 1, 2, stream());
vjps.push_back(reshape(g, primals[1].shape(), stream()));
} else {
// (M X K).T * M X N -> K X N // (M X K).T * M X N -> K X N
auto base = zeros_like(primals[1], stream()); auto base = zeros_like(primals[1], stream());
auto at = swapaxes(primals[0], -1, -2, stream()); auto at = swapaxes(primals[0], -1, -2, stream());
@@ -5105,6 +5130,7 @@ std::vector<array> GatherMM::vjp(
auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream())); vjps.push_back(reshape(gacc, base_shape, stream()));
}
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[GatherMM] Cannot calculate VJP with respect to indices."); "[GatherMM] Cannot calculate VJP with respect to indices.");