Use segmented_mm to calculate the MoE gradient

This commit is contained in:
Angelos Katharopoulos
2025-07-02 16:24:21 -07:00
parent 8f771efb82
commit a29fa053c6
4 changed files with 57 additions and 31 deletions

View File

@@ -57,7 +57,7 @@ template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const int32_t* segments,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
@@ -75,11 +75,10 @@ inline void segmented_mm(
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
int32_t k_start = 0;
for (int i = 0; i < num_segments; i++) {
int32_t k_start =
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
int32_t k_end =
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<int32_t>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<int32_t>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<int32_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<int32_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,