scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -1095,6 +1095,27 @@ class Gather : public UnaryPrimitive {
Shape slice_sizes_;
};
class GatherAxis : public UnaryPrimitive {
public:
explicit GatherAxis(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(GatherAxis)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return axis_;
}
private:
int axis_;
};
class Greater : public UnaryPrimitive {
public:
explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
@@ -1786,6 +1807,41 @@ class Scatter : public UnaryPrimitive {
std::vector<int> axes_;
};
class ScatterAxis : public UnaryPrimitive {
public:
enum ReduceType { Sum, None };
explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
void print(std::ostream& os) override {
os << "ScatterAxis";
switch (reduce_type_) {
case Sum:
os << " Sum";
break;
case None:
break;
}
}
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<ReduceType, int> state() const {
return {reduce_type_, axis_};
}
private:
ReduceType reduce_type_;
int axis_;
};
class Sigmoid : public UnaryPrimitive {
public:
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}