Use segmented_mm to calculate the MoE gradient
@@ -57,7 +57,7 @@ template <typename T>
 inline void segmented_mm(
     const T* a,
     const T* b,
-    const int32_t* segments,
+    const uint32_t* segments,
     T* out,
     bool a_transposed,
     bool b_transposed,
@@ -75,11 +75,10 @@ inline void segmented_mm(
   Shape b_copy = b_shape;
   int32_t M = a_copy[ndim - 2];
   int32_t N = b_copy[ndim - 1];
-  int32_t k_start = 0;
   for (int i = 0; i < num_segments; i++) {
-    int32_t k_start =
+    uint32_t k_start =
         segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
-    int32_t k_end =
+    uint32_t k_end =
         segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
     a_copy[ndim - 1] = k_end - k_start;
     b_copy[ndim - 2] = k_end - k_start;
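For orientation, the semantics of this CPU routine can be sketched in plain C++: each (k_start, k_end) pair read from segments selects a slice of the K axis shared by a and b, and each segment produces one M x N output slice. The sketch below is illustrative only (row-major, no transposes, no strides, names are mine), not the MLX implementation:

#include <cstdint>

// Minimal scalar sketch of a segmented matmul: for each segment s with
// bounds [k_start, k_end), compute out[s] = a[:, k_start:k_end] @
// b[k_start:k_end, :]. a is M x K, b is K x N, out is num_segments x M x N,
// all row-major. The real kernel additionally handles transposes, strides,
// and tiling.
void segmented_mm_ref(
    const float* a,
    const float* b,
    const uint32_t* segments, // 2 * num_segments entries: (start, end) pairs
    float* out,
    int num_segments,
    int M,
    int N,
    int K) {
  for (int s = 0; s < num_segments; s++) {
    uint32_t k_start = segments[2 * s];
    uint32_t k_end = segments[2 * s + 1];
    float* o = out + static_cast<int64_t>(s) * M * N;
    for (int m = 0; m < M; m++) {
      for (int n = 0; n < N; n++) {
        float acc = 0.0f;
        for (uint32_t k = k_start; k < k_end; k++) {
          acc += a[m * K + k] * b[k * N + n];
        }
        o[m * N + n] = acc;
      }
    }
  }
}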
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       segmented_mm<double>(
           a.data<double>(),
           b.data<double>(),
-          segments.data<int32_t>(),
+          segments.data<uint32_t>(),
           static_cast<double*>(out_ptr),
           a_transposed,
           b_transposed,
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       segmented_mm<float>(
           a.data<float>(),
           b.data<float>(),
-          segments.data<int32_t>(),
+          segments.data<uint32_t>(),
           static_cast<float*>(out_ptr),
           a_transposed,
           b_transposed,
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       segmented_mm<float16_t>(
           a.data<float16_t>(),
           b.data<float16_t>(),
-          segments.data<int32_t>(),
+          segments.data<uint32_t>(),
           static_cast<float16_t*>(out_ptr),
           a_transposed,
           b_transposed,
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       segmented_mm<bfloat16_t>(
           a.data<bfloat16_t>(),
           b.data<bfloat16_t>(),
-          segments.data<int32_t>(),
+          segments.data<uint32_t>(),
           static_cast<bfloat16_t*>(out_ptr),
           a_transposed,
           b_transposed,
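The four hunks above are the same one-line change repeated once per dtype instantiation: eval_cpu branches on the output dtype and calls the template with matching pointer types. A minimal sketch of that dispatch pattern, with hypothetical names and only two dtypes (the enclosing MLX code is not shown in this diff):

#include <cstdint>

// Hypothetical sketch of per-dtype dispatch around a templated kernel,
// mirroring the segmented_mm<double>/<float>/... call sites above.
enum class Dtype { float64, float32 };

template <typename T>
void segmented_mm(const T* a, const T* b, const uint32_t* segments, T* out) {
  // kernel body elided; see the reference sketch earlier
}

void eval_segmented_mm(
    Dtype dtype,
    const void* a,
    const void* b,
    const uint32_t* segments,
    void* out) {
  switch (dtype) {
    case Dtype::float64:
      segmented_mm(
          static_cast<const double*>(a),
          static_cast<const double*>(b),
          segments,
          static_cast<double*>(out));
      break;
    case Dtype::float32:
      segmented_mm(
          static_cast<const float*>(a),
          static_cast<const float*>(b),
          segments,
          static_cast<float*>(out));
      break;
  }
}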
@@ -19,7 +19,7 @@ template <
 [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
     const device T* A [[buffer(0)]],
     const device T* B [[buffer(1)]],
-    const device int32_t* segments [[buffer(2)]],
+    const device uint32_t* segments [[buffer(2)]],
     device T* C [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
@@ -68,7 +68,7 @@ template <
   C += c_row_long * params->ldd + c_col_long;

   // Move the pointers to the start of the segment
-  int32_t k_start, k_end;
+  uint32_t k_start, k_end;
   if (segments_contiguous) {
     k_start = segments[2 * tid.z];
     k_end = segments[2 * tid.z + 1];
@@ -92,7 +92,7 @@ template <

   // Matrix level alignment so only check K
   if (align_M && align_N) {
-    int k = k_start + BK;
+    uint32_t k = k_start + BK;
     for (; k <= k_end; k += BK) {
       threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -109,7 +109,7 @@ template <
       loader_a.next();
       loader_b.next();
     }
-    short k_remain = k_end - (k - BK);
+    short k_remain = BK - short(k - k_end);
     const short2 tile_dims_A =
         transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
     const short2 tile_dims_B =
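The two k_remain expressions are algebraically identical: at loop exit k is the first BK-sized step past k_end, so k - k_end lies in (0, BK] and k_end - (k - BK) = BK - (k - k_end), a value in [0, BK). The likely point of the rewrite is that with k and k_end now uint32_t, the old form implicitly narrows a uint32_t expression to short, while the new form casts only the small difference k - k_end, which provably fits. The same rewrite is applied in each of the alignment branches below. A standalone check of the equivalence, assuming BK = 16 as a typical tile depth:

#include <cassert>
#include <cstdint>

// Verify that the old and new k_remain formulations agree for every value
// the loop can produce. At loop exit, k is the first multiple-of-BK step
// past k_end, so k - k_end lies in (0, BK] and the short cast cannot
// overflow. Illustrative check only; BK = 16 is an assumption.
int main() {
  const uint32_t BK = 16;
  for (uint32_t k_start = 0; k_start < 64; k_start++) {
    for (uint32_t k_end = k_start; k_end < 128; k_end++) {
      uint32_t k = k_start + BK;
      while (k <= k_end) {
        k += BK; // main-loop body elided
      }
      short old_remain = static_cast<short>(k_end - (k - BK));
      short new_remain = static_cast<short>(BK) - short(k - k_end);
      assert(old_remain == new_remain);
    }
  }
  return 0;
}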
@@ -125,7 +125,7 @@ template <
   } else {
     // Tile aligned do the same as above
     if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
-      int k = k_start + BK;
+      uint32_t k = k_start + BK;
       for (; k <= k_end; k += BK) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -142,7 +142,7 @@ template <
         loader_a.next();
         loader_b.next();
       }
-      short k_remain = k_end - (k - BK);
+      short k_remain = BK - short(k - k_end);
       const short2 tile_dims_A =
           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
       const short2 tile_dims_B =
@@ -159,7 +159,7 @@ template <

     // Tile partially aligned check rows
     else if (align_N || tgp_bn == BN) {
-      int k = k_start + BK;
+      uint32_t k = k_start + BK;
       for (; k <= k_end; k += BK) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -177,7 +177,7 @@ template <
         loader_a.next();
         loader_b.next();
       }
-      short k_remain = k_end - (k - BK);
+      short k_remain = BK - short(k - k_end);
       const short2 tile_dims_A =
           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
       const short2 tile_dims_B =
@@ -194,7 +194,7 @@ template <

     // Tile partially aligned check cols
     else if (align_M || tgp_bm == BM) {
-      int k = k_start + BK;
+      uint32_t k = k_start + BK;
       for (; k <= k_end; k += BK) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -212,7 +212,7 @@ template <
         loader_a.next();
         loader_b.next();
       }
-      short k_remain = k_end - (k - BK);
+      short k_remain = BK - short(k - k_end);
       const short2 tile_dims_A =
           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
       const short2 tile_dims_B =
@@ -229,7 +229,7 @@ template <

     // Nothing aligned so check both rows and cols
     else {
-      int k = k_start + BK;
+      uint32_t k = k_start + BK;
       for (; k <= k_end; k += BK) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -248,7 +248,7 @@ template <
         loader_a.next();
         loader_b.next();
       }
-      short k_remain = k_end - (k - BK);
+      short k_remain = BK - short(k - k_end);
       const short2 tile_dims_A =
           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
       const short2 tile_dims_B =
@@ -4678,6 +4678,7 @@ array segmented_mm(

   a = astype(a, out_type, s);
   b = astype(b, out_type, s);
+  segments = astype(segments, uint32, s);

   Shape out_shape = segments.shape();
   out_shape.pop_back();
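With the added astype, callers can pass segments in any integer dtype and the op normalizes it to uint32 before the backends read it, which is what lets the kernels above assume uint32_t. A minimal usage sketch of the op through the MLX C++ API; the shapes, values, and the default-stream call are my assumptions, not taken from this diff:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // a: (M, K) = (4, 12), b: (K, N) = (12, 8).
  auto a = ones({4, 12});
  auto b = ones({12, 8});
  // Three [start, end) segments over the shared K axis. Passing int32 is
  // fine since the op now casts segments to uint32 internally.
  auto segments = array({0, 4, 4, 9, 9, 12}, {3, 2}, int32);
  // out has shape segments.shape() minus its last axis, plus (M, N),
  // i.e. (3, 4, 8); out[i] = a[:, start_i:end_i] @ b[start_i:end_i, :].
  auto out = segmented_mm(a, b, segments);
  out.eval();
  return 0;
}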
@@ -5091,20 +5091,46 @@ std::vector<array> GatherMM::vjp(
       vjps.push_back(reshape(gacc, base_shape, stream()));
     } else if (arg == 1) {
-      // (M X K).T * M X N -> K X N
-      auto base = zeros_like(primals[1], stream());
-      auto at = swapaxes(primals[0], -1, -2, stream());
+      if (sorted) {
+        // Make the segments based on the rhs_indices
+        int num_segments = primals[1].size() / K / N;
+        auto segments = zeros({num_segments}, uint32, stream());
+        segments = scatter_add_axis(
+            segments, rhs_indices, array(M, uint32), 0, stream());
+        segments = cumsum(segments, 0, false, true, stream());
+        segments =
+            concatenate({array({0}, {1}, uint32), segments}, 0, stream());
+        segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());

-      auto base_shape = base.shape();
-      base = reshape(base, {-1, K, N}, stream());
+        // Reshape and transpose the inputs such that they are a big segmented
+        // matmul.
+        auto a = reshape(primals[0], {-1, K}, stream());
+        auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());

-      // g : (out_batch_shape) + (K, N)
-      auto g =
-          gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
-      g = expand_dims(g, -3, stream());
-      auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
+        // Calculate the gradient.
+        // Since the gather mm is often used as x @ w.T we will calculate the
+        // gradient as c @ a and transpose it before returning it which should
+        // save a copy in that case.
+        auto g = segmented_mm(c, a, segments, stream());
+        g = swapaxes(g, 1, 2, stream());

-      vjps.push_back(reshape(gacc, base_shape, stream()));
+        vjps.push_back(reshape(g, primals[1].shape(), stream()));
+      } else {
+        // (M X K).T * M X N -> K X N
+        auto base = zeros_like(primals[1], stream());
+        auto at = swapaxes(primals[0], -1, -2, stream());
+
+        auto base_shape = base.shape();
+        base = reshape(base, {-1, K, N}, stream());
+
+        // g : (out_batch_shape) + (K, N)
+        auto g =
+            gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
+        g = expand_dims(g, -3, stream());
+        auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
+
+        vjps.push_back(reshape(gacc, base_shape, stream()));
+      }
     } else {
       throw std::invalid_argument(
           "[GatherMM] Cannot calculate VJP with respect to indices.");
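The sorted branch builds its segments array in four steps: scatter_add_axis counts M rows per expert, the inclusive cumsum turns counts into segment ends, concatenating a leading zero supplies the starts, and as_strided with strides {1, 1} reads off overlapping (start, end) pairs. A plain C++ walk-through of the same arithmetic on a toy case; M = 2 and rhs_indices = {0, 0, 1, 2} are illustrative values of mine:

#include <cstdint>
#include <iostream>
#include <vector>

// Worked sketch of the segment construction in GatherMM::vjp for sorted
// rhs_indices. Each index occurrence contributes M rows to its expert's
// segment of the flattened cotangent.
int main() {
  const uint32_t M = 2; // rows produced per gathered matmul
  const int num_segments = 3; // number of expert weight matrices
  std::vector<uint32_t> rhs_indices = {0, 0, 1, 2}; // sorted expert ids

  // scatter_add_axis: add M to segments[e] for every occurrence of e.
  std::vector<uint32_t> segments(num_segments, 0);
  for (auto e : rhs_indices) {
    segments[e] += M;
  }
  // cumsum(..., /*reverse=*/false, /*inclusive=*/true): running totals
  // are the segment ends.
  for (int i = 1; i < num_segments; i++) {
    segments[i] += segments[i - 1];
  }
  // concatenate({0}, ends) then as_strided({num_segments, 2}, {1, 1}, 0):
  // consecutive pairs [ends[i-1], ends[i]) are the row ranges per expert.
  std::vector<uint32_t> bounds = {0};
  bounds.insert(bounds.end(), segments.begin(), segments.end());
  for (int i = 0; i < num_segments; i++) {
    std::cout << "expert " << i << ": rows [" << bounds[i] << ", "
              << bounds[i + 1] << ")\n"; // prints [0,4) [4,6) [6,8)
  }
  return 0;
}

With these bounds, segmented_mm(c, a, segments) multiplies each expert's (N, rows) slice of the transposed cotangent by its (rows, K) slice of the inputs, giving (num_segments, N, K); the swapaxes(g, 1, 2) that follows yields (num_segments, K, N), matching the weight shape. As the new comment in the diff notes, computing c @ a and transposing should save a copy in the common x @ w.T usage.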