Use segmented_mm to calculate the MoE gradient

commit a29fa053c6 (parent 8f771efb82)
Author: Angelos Katharopoulos
Date: 2025-07-02 16:24:21 -07:00
4 changed files with 57 additions and 31 deletions
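
At a high level: for a mixture-of-experts layer the gradient of the expert weights is, per expert e, the sum of x_t * delta_t^T over the tokens t routed to e, i.e. dW_e = X_e^T @ dY_e. When the tokens are sorted by expert, each X_e and dY_e is a contiguous row range of the flattened inputs, which is exactly the shape of work a segmented matmul handles in one call. The changes below therefore (1) switch the segment dtype from int32_t to uint32_t in the CPU and Metal segmented_mm kernels and in the op, and (2) rewrite the sorted branch of GatherMM::vjp to build [start, end) row ranges from rhs_indices and compute the weight gradient with a single segmented_mm instead of a gather_mm followed by a scatter_add.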

View File

@@ -57,7 +57,7 @@ template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
- const int32_t* segments,
+ const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
@@ -75,11 +75,10 @@ inline void segmented_mm(
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
- int32_t k_start = 0;
for (int i = 0; i < num_segments; i++) {
- int32_t k_start =
+ uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
- int32_t k_end =
+ uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
- segments.data<int32_t>(),
+ segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
- segments.data<int32_t>(),
+ segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
- segments.data<int32_t>(),
+ segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
- segments.data<int32_t>(),
+ segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
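
For orientation, here is a small standalone sketch (not the MLX code itself) of what the CPU kernel above computes, consistent with these hunks: segments now stores uint32_t (start, end) pairs into the shared K dimension, and each segment produces one M x N output block. Assumptions: row-major data, no transposes, float inputs; the real kernel instead adjusts the shape copies and dispatches an optimized matmul per segment.

#include <cstddef>
#include <cstdint>

// Reference semantics: for each segment s with bounds [k_start, k_end),
//   out[s] = A[:, k_start:k_end] @ B[k_start:k_end, :]   -> an M x N block.
void segmented_mm_ref(
    const float* a,            // M x K
    const float* b,            // K x N
    const uint32_t* segments,  // num_segments (start, end) pairs into K
    float* out,                // num_segments x M x N
    int M,
    int N,
    int K,
    int num_segments) {
  for (int s = 0; s < num_segments; ++s) {
    const uint32_t k_start = segments[2 * s];
    const uint32_t k_end = segments[2 * s + 1];
    float* o = out + static_cast<size_t>(s) * M * N;
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (uint32_t k = k_start; k < k_end; ++k) {
          acc += a[static_cast<size_t>(i) * K + k] *
              b[static_cast<size_t>(k) * N + j];
        }
        o[static_cast<size_t>(i) * N + j] = acc;
      }
    }
  }
}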

View File

@@ -19,7 +19,7 @@ template <
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
- const device int32_t* segments [[buffer(2)]],
+ const device uint32_t* segments [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
@@ -68,7 +68,7 @@ template <
C += c_row_long * params->ldd + c_col_long;
// Move the pointers to the start of the segment
- int32_t k_start, k_end;
+ uint32_t k_start, k_end;
if (segments_contiguous) {
k_start = segments[2 * tid.z];
k_end = segments[2 * tid.z + 1];
@@ -92,7 +92,7 @@ template <
// Matrix level alignment so only check K
if (align_M && align_N) {
- int k = k_start + BK;
+ uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -109,7 +109,7 @@ template <
loader_a.next();
loader_b.next();
}
- short k_remain = k_end - (k - BK);
+ short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
@@ -125,7 +125,7 @@ template <
} else {
// Tile aligned do the same as above
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
- int k = k_start + BK;
+ uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -142,7 +142,7 @@ template <
loader_a.next();
loader_b.next();
}
- short k_remain = k_end - (k - BK);
+ short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
@@ -159,7 +159,7 @@ template <
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
- int k = k_start + BK;
+ uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -177,7 +177,7 @@ template <
loader_a.next();
loader_b.next();
}
- short k_remain = k_end - (k - BK);
+ short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
@@ -194,7 +194,7 @@ template <
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
- int k = k_start + BK;
+ uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -212,7 +212,7 @@ template <
loader_a.next();
loader_b.next();
}
- short k_remain = k_end - (k - BK);
+ short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
@@ -229,7 +229,7 @@ template <
// Nothing aligned so check both rows and cols
else {
- int k = k_start + BK;
+ uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -248,7 +248,7 @@ template <
loader_a.next();
loader_b.next();
}
- short k_remain = k_end - (k - BK);
+ short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
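
Two things change in the kernel above: the segment bounds and the K loop counter become uint32_t, and the tail-tile size is computed as BK - short(k - k_end) instead of k_end - (k - BK). The two expressions are algebraically identical; the new form just performs the narrowing to short on the small difference k - k_end, which is the safer order now that k and k_end are unsigned 32-bit. A tiny standalone check of the equivalence, with hypothetical values:

#include <cassert>
#include <cstdint>

int main() {
  const int BK = 32;                       // tile depth, as in the kernel's template
  const uint32_t k_start = 5, k_end = 70;  // hypothetical segment bounds
  uint32_t k = k_start + BK;
  while (k <= k_end) {                     // mirrors the main K loop above
    k += BK;
  }
  short old_remain = short(k_end - (k - BK));  // expression before this commit
  short new_remain = BK - short(k - k_end);    // expression after this commit
  assert(old_remain == new_remain);            // both equal 1 for these values
  return 0;
}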

View File

@@ -4678,6 +4678,7 @@ array segmented_mm(
a = astype(a, out_type, s);
b = astype(b, out_type, s);
+ segments = astype(segments, uint32, s);
Shape out_shape = segments.shape();
out_shape.pop_back();
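
The op-level change simply casts segments to uint32 before the primitive is built, matching the kernel signatures above; the output shape starts from the segments shape with its trailing axis of (start, end) pairs dropped. A hypothetical usage sketch of the public C++ op follows. The argument order (a, b, segments, stream) is inferred from the GatherMM::vjp call below, and the include path and constructor forms mirror those used elsewhere in the diff; treat the details as assumptions.

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int M = 4, K = 12, N = 8;
  auto a = mx::ones({M, K});
  auto b = mx::ones({K, N});
  // Two segments splitting the shared K dimension into [0, 5) and [5, 12).
  auto segments = mx::array({0, 5, 5, 12}, {2, 2}, mx::uint32);
  // Expected output shape: (2, M, N), one block per segment.
  auto out = mx::segmented_mm(a, b, segments);
  out.eval();
  return 0;
}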

View File

@@ -5091,20 +5091,46 @@ std::vector<array> GatherMM::vjp(
vjps.push_back(reshape(gacc, base_shape, stream()));
} else if (arg == 1) {
// (M X K).T * M X N -> K X N
- auto base = zeros_like(primals[1], stream());
- auto at = swapaxes(primals[0], -1, -2, stream());
+ if (sorted) {
+ // Make the segments based on the rhs_indices
+ int num_segments = primals[1].size() / K / N;
+ auto segments = zeros({num_segments}, uint32, stream());
+ segments = scatter_add_axis(
+ segments, rhs_indices, array(M, uint32), 0, stream());
+ segments = cumsum(segments, 0, false, true, stream());
+ segments =
+ concatenate({array({0}, {1}, uint32), segments}, 0, stream());
+ segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());

- auto base_shape = base.shape();
- base = reshape(base, {-1, K, N}, stream());
+ // Reshape and transpose the inputs such that they are a big segmented
+ // matmul.
+ auto a = reshape(primals[0], {-1, K}, stream());
+ auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());

- // g : (out_batch_shape) + (K, N)
- auto g =
- gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
- g = expand_dims(g, -3, stream());
- auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
+ // Calculate the gradient.
+ // Since the gather mm is often used as x @ w.T we will calculate the
+ // gradient as c @ a and transpose it before returning it which should
+ // save a copy in that case.
+ auto g = segmented_mm(c, a, segments, stream());
+ g = swapaxes(g, 1, 2, stream());

- vjps.push_back(reshape(gacc, base_shape, stream()));
+ vjps.push_back(reshape(g, primals[1].shape(), stream()));
+ } else {
+ // (M X K).T * M X N -> K X N
+ auto base = zeros_like(primals[1], stream());
+ auto at = swapaxes(primals[0], -1, -2, stream());
+ auto base_shape = base.shape();
+ base = reshape(base, {-1, K, N}, stream());
+ // g : (out_batch_shape) + (K, N)
+ auto g =
+ gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
+ g = expand_dims(g, -3, stream());
+ auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
+ vjps.push_back(reshape(gacc, base_shape, stream()));
+ }
} else {
throw std::invalid_argument(
"[GatherMM] Cannot calculate VJP with respect to indices.");