mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use segmented_mm to calculate the MoE gradient
This commit is contained in:
@@ -57,7 +57,7 @@ template <typename T>
|
||||
inline void segmented_mm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
const int32_t* segments,
|
||||
const uint32_t* segments,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
@@ -75,11 +75,10 @@ inline void segmented_mm(
|
||||
Shape b_copy = b_shape;
|
||||
int32_t M = a_copy[ndim - 2];
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
int32_t k_start = 0;
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
int32_t k_start =
|
||||
uint32_t k_start =
|
||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||
int32_t k_end =
|
||||
uint32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
@@ -529,7 +528,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
segmented_mm<double>(
|
||||
a.data<double>(),
|
||||
b.data<double>(),
|
||||
segments.data<int32_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<double*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
@@ -547,7 +546,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
segmented_mm<float>(
|
||||
a.data<float>(),
|
||||
b.data<float>(),
|
||||
segments.data<int32_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
@@ -565,7 +564,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
segmented_mm<float16_t>(
|
||||
a.data<float16_t>(),
|
||||
b.data<float16_t>(),
|
||||
segments.data<int32_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
@@ -583,7 +582,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
segmented_mm<bfloat16_t>(
|
||||
a.data<bfloat16_t>(),
|
||||
b.data<bfloat16_t>(),
|
||||
segments.data<int32_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<bfloat16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
|
||||
Reference in New Issue
Block a user