mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use segmented_mm to calculate the MoE gradient
This commit is contained in:
@@ -57,7 +57,7 @@ template <typename T>
|
|||||||
inline void segmented_mm(
|
inline void segmented_mm(
|
||||||
const T* a,
|
const T* a,
|
||||||
const T* b,
|
const T* b,
|
||||||
const int32_t* segments,
|
const uint32_t* segments,
|
||||||
T* out,
|
T* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
@@ -75,11 +75,10 @@ inline void segmented_mm(
|
|||||||
Shape b_copy = b_shape;
|
Shape b_copy = b_shape;
|
||||||
int32_t M = a_copy[ndim - 2];
|
int32_t M = a_copy[ndim - 2];
|
||||||
int32_t N = b_copy[ndim - 1];
|
int32_t N = b_copy[ndim - 1];
|
||||||
int32_t k_start = 0;
|
|
||||||
for (int i = 0; i < num_segments; i++) {
|
for (int i = 0; i < num_segments; i++) {
|
||||||
int32_t k_start =
|
uint32_t k_start =
|
||||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||||
int32_t k_end =
|
uint32_t k_end =
|
||||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||||
a_copy[ndim - 1] = k_end - k_start;
|
a_copy[ndim - 1] = k_end - k_start;
|
||||||
b_copy[ndim - 2] = k_end - k_start;
|
b_copy[ndim - 2] = k_end - k_start;
|
||||||
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
segmented_mm<double>(
|
segmented_mm<double>(
|
||||||
a.data<double>(),
|
a.data<double>(),
|
||||||
b.data<double>(),
|
b.data<double>(),
|
||||||
segments.data<int32_t>(),
|
segments.data<uint32_t>(),
|
||||||
static_cast<double*>(out_ptr),
|
static_cast<double*>(out_ptr),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
b_transposed,
|
b_transposed,
|
||||||
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
segmented_mm<float>(
|
segmented_mm<float>(
|
||||||
a.data<float>(),
|
a.data<float>(),
|
||||||
b.data<float>(),
|
b.data<float>(),
|
||||||
segments.data<int32_t>(),
|
segments.data<uint32_t>(),
|
||||||
static_cast<float*>(out_ptr),
|
static_cast<float*>(out_ptr),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
b_transposed,
|
b_transposed,
|
||||||
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
segmented_mm<float16_t>(
|
segmented_mm<float16_t>(
|
||||||
a.data<float16_t>(),
|
a.data<float16_t>(),
|
||||||
b.data<float16_t>(),
|
b.data<float16_t>(),
|
||||||
segments.data<int32_t>(),
|
segments.data<uint32_t>(),
|
||||||
static_cast<float16_t*>(out_ptr),
|
static_cast<float16_t*>(out_ptr),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
b_transposed,
|
b_transposed,
|
||||||
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
segmented_mm<bfloat16_t>(
|
segmented_mm<bfloat16_t>(
|
||||||
a.data<bfloat16_t>(),
|
a.data<bfloat16_t>(),
|
||||||
b.data<bfloat16_t>(),
|
b.data<bfloat16_t>(),
|
||||||
segments.data<int32_t>(),
|
segments.data<uint32_t>(),
|
||||||
static_cast<bfloat16_t*>(out_ptr),
|
static_cast<bfloat16_t*>(out_ptr),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
b_transposed,
|
b_transposed,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ template <
|
|||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
|
||||||
const device T* A [[buffer(0)]],
|
const device T* A [[buffer(0)]],
|
||||||
const device T* B [[buffer(1)]],
|
const device T* B [[buffer(1)]],
|
||||||
const device int32_t* segments [[buffer(2)]],
|
const device uint32_t* segments [[buffer(2)]],
|
||||||
device T* C [[buffer(3)]],
|
device T* C [[buffer(3)]],
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
@@ -68,7 +68,7 @@ template <
|
|||||||
C += c_row_long * params->ldd + c_col_long;
|
C += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
// Move the pointers to the start of the segment
|
// Move the pointers to the start of the segment
|
||||||
int32_t k_start, k_end;
|
uint32_t k_start, k_end;
|
||||||
if (segments_contiguous) {
|
if (segments_contiguous) {
|
||||||
k_start = segments[2 * tid.z];
|
k_start = segments[2 * tid.z];
|
||||||
k_end = segments[2 * tid.z + 1];
|
k_end = segments[2 * tid.z + 1];
|
||||||
@@ -92,7 +92,7 @@ template <
|
|||||||
|
|
||||||
// Matrix level alignment so only check K
|
// Matrix level alignment so only check K
|
||||||
if (align_M && align_N) {
|
if (align_M && align_N) {
|
||||||
int k = k_start + BK;
|
uint32_t k = k_start + BK;
|
||||||
for (; k <= k_end; k += BK) {
|
for (; k <= k_end; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ template <
|
|||||||
loader_a.next();
|
loader_a.next();
|
||||||
loader_b.next();
|
loader_b.next();
|
||||||
}
|
}
|
||||||
short k_remain = k_end - (k - BK);
|
short k_remain = BK - short(k - k_end);
|
||||||
const short2 tile_dims_A =
|
const short2 tile_dims_A =
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
const short2 tile_dims_B =
|
const short2 tile_dims_B =
|
||||||
@@ -125,7 +125,7 @@ template <
|
|||||||
} else {
|
} else {
|
||||||
// Tile aligned do the same as above
|
// Tile aligned do the same as above
|
||||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||||
int k = k_start + BK;
|
uint32_t k = k_start + BK;
|
||||||
for (; k <= k_end; k += BK) {
|
for (; k <= k_end; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ template <
|
|||||||
loader_a.next();
|
loader_a.next();
|
||||||
loader_b.next();
|
loader_b.next();
|
||||||
}
|
}
|
||||||
short k_remain = k_end - (k - BK);
|
short k_remain = BK - short(k - k_end);
|
||||||
const short2 tile_dims_A =
|
const short2 tile_dims_A =
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
const short2 tile_dims_B =
|
const short2 tile_dims_B =
|
||||||
@@ -159,7 +159,7 @@ template <
|
|||||||
|
|
||||||
// Tile partially aligned check rows
|
// Tile partially aligned check rows
|
||||||
else if (align_N || tgp_bn == BN) {
|
else if (align_N || tgp_bn == BN) {
|
||||||
int k = k_start + BK;
|
uint32_t k = k_start + BK;
|
||||||
for (; k <= k_end; k += BK) {
|
for (; k <= k_end; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@@ -177,7 +177,7 @@ template <
|
|||||||
loader_a.next();
|
loader_a.next();
|
||||||
loader_b.next();
|
loader_b.next();
|
||||||
}
|
}
|
||||||
short k_remain = k_end - (k - BK);
|
short k_remain = BK - short(k - k_end);
|
||||||
const short2 tile_dims_A =
|
const short2 tile_dims_A =
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
const short2 tile_dims_B =
|
const short2 tile_dims_B =
|
||||||
@@ -194,7 +194,7 @@ template <
|
|||||||
|
|
||||||
// Tile partially aligned check cols
|
// Tile partially aligned check cols
|
||||||
else if (align_M || tgp_bm == BM) {
|
else if (align_M || tgp_bm == BM) {
|
||||||
int k = k_start + BK;
|
uint32_t k = k_start + BK;
|
||||||
for (; k <= k_end; k += BK) {
|
for (; k <= k_end; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ template <
|
|||||||
loader_a.next();
|
loader_a.next();
|
||||||
loader_b.next();
|
loader_b.next();
|
||||||
}
|
}
|
||||||
short k_remain = k_end - (k - BK);
|
short k_remain = BK - short(k - k_end);
|
||||||
const short2 tile_dims_A =
|
const short2 tile_dims_A =
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
const short2 tile_dims_B =
|
const short2 tile_dims_B =
|
||||||
@@ -229,7 +229,7 @@ template <
|
|||||||
|
|
||||||
// Nothing aligned so check both rows and cols
|
// Nothing aligned so check both rows and cols
|
||||||
else {
|
else {
|
||||||
int k = k_start + BK;
|
uint32_t k = k_start + BK;
|
||||||
for (; k <= k_end; k += BK) {
|
for (; k <= k_end; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@@ -248,7 +248,7 @@ template <
|
|||||||
loader_a.next();
|
loader_a.next();
|
||||||
loader_b.next();
|
loader_b.next();
|
||||||
}
|
}
|
||||||
short k_remain = k_end - (k - BK);
|
short k_remain = BK - short(k - k_end);
|
||||||
const short2 tile_dims_A =
|
const short2 tile_dims_A =
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
const short2 tile_dims_B =
|
const short2 tile_dims_B =
|
||||||
|
|||||||
@@ -4678,6 +4678,7 @@ array segmented_mm(
|
|||||||
|
|
||||||
a = astype(a, out_type, s);
|
a = astype(a, out_type, s);
|
||||||
b = astype(b, out_type, s);
|
b = astype(b, out_type, s);
|
||||||
|
segments = astype(segments, uint32, s);
|
||||||
|
|
||||||
Shape out_shape = segments.shape();
|
Shape out_shape = segments.shape();
|
||||||
out_shape.pop_back();
|
out_shape.pop_back();
|
||||||
|
|||||||
@@ -5091,6 +5091,31 @@ std::vector<array> GatherMM::vjp(
|
|||||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||||
|
|
||||||
} else if (arg == 1) {
|
} else if (arg == 1) {
|
||||||
|
if (sorted) {
|
||||||
|
// Make the segments based on the rhs_indices
|
||||||
|
int num_segments = primals[1].size() / K / N;
|
||||||
|
auto segments = zeros({num_segments}, uint32, stream());
|
||||||
|
segments = scatter_add_axis(
|
||||||
|
segments, rhs_indices, array(M, uint32), 0, stream());
|
||||||
|
segments = cumsum(segments, 0, false, true, stream());
|
||||||
|
segments =
|
||||||
|
concatenate({array({0}, {1}, uint32), segments}, 0, stream());
|
||||||
|
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());
|
||||||
|
|
||||||
|
// Reshape and transpose the inputs such that they are a big segmented
|
||||||
|
// matmul.
|
||||||
|
auto a = reshape(primals[0], {-1, K}, stream());
|
||||||
|
auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());
|
||||||
|
|
||||||
|
// Calculate the gradient.
|
||||||
|
// Since the gather mm is often used as x @ w.T we will calculate the
|
||||||
|
// gradient as c @ a and transpose it before returning it which should
|
||||||
|
// save a copy in that case.
|
||||||
|
auto g = segmented_mm(c, a, segments, stream());
|
||||||
|
g = swapaxes(g, 1, 2, stream());
|
||||||
|
|
||||||
|
vjps.push_back(reshape(g, primals[1].shape(), stream()));
|
||||||
|
} else {
|
||||||
// (M X K).T * M X N -> K X N
|
// (M X K).T * M X N -> K X N
|
||||||
auto base = zeros_like(primals[1], stream());
|
auto base = zeros_like(primals[1], stream());
|
||||||
auto at = swapaxes(primals[0], -1, -2, stream());
|
auto at = swapaxes(primals[0], -1, -2, stream());
|
||||||
@@ -5105,6 +5130,7 @@ std::vector<array> GatherMM::vjp(
|
|||||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||||
|
|
||||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[GatherMM] Cannot calculate VJP with respect to indices.");
|
"[GatherMM] Cannot calculate VJP with respect to indices.");
|
||||||
|
|||||||
Reference in New Issue
Block a user