MoE backward improvements (#2335)

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:59:53 -07:00
committed by GitHub
parent a4fcc893cd
commit 4a9b29a875
22 changed files with 1130 additions and 60 deletions

View File

@@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
bool right_sorted_;
};
// Primitive deriving from UnaryPrimitive, i.e. it writes into a single
// output array. Judging by the name this is a segmented matrix multiply;
// the actual computation lives in the out-of-line eval_cpu / eval_gpu
// definitions, which are not part of this header.
//
// The class adds no data members of its own and, in this declaration,
// overrides nothing beyond the two evaluation entry points below.
class SegmentedMM : public UnaryPrimitive {
 public:
  // Forwards `stream` unchanged to the UnaryPrimitive base. `explicit`
  // prevents a Stream from implicitly converting to a SegmentedMM.
  explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}

  // Backend-specific evaluation: each consumes `inputs` and fills `out`.
  // Declared here, defined elsewhere.
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  // Project macro; presumably generates the primitive's print/name hook
  // for "SegmentedMM" -- confirm against the macro definition.
  DEFINE_PRINT(SegmentedMM)
};
class BroadcastAxes : public UnaryPrimitive {
public:
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})