Gather mm new kernel and small refactoring (#2040)

This commit is contained in:
Angelos Katharopoulos
2025-04-14 16:37:36 -07:00
committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions

View File

@@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive {
class GatherMM : public UnaryPrimitive {
public:
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
// Constructs the primitive on `stream` and records the two sortedness flags.
// Both flags default to false.
// NOTE(review): presumably the flags state whether the left / right gather
// indices are already sorted -- confirm against the call sites.
explicit GatherMM(
Stream stream,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive {
const std::vector<array>& outputs) override;
DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT()
bool is_equivalent(const Primitive& other) const override;
// Returns the primitive's configuration as a
// (left_sorted_, right_sorted_) pair of bools, by value.
// NOTE(review): presumably consumed by state-based comparison/serialization
// helpers -- confirm where state() is called.
auto state() const {
return std::make_pair(left_sorted_, right_sorted_);
}
private:
// Flag set from the constructor's `left_sorted` argument.
bool left_sorted_;
// Flag set from the constructor's `right_sorted` argument.
bool right_sorted_;
};
class BroadcastAxes : public UnaryPrimitive {