mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
@@ -1095,6 +1095,27 @@ class Gather : public UnaryPrimitive {
|
||||
Shape slice_sizes_;
|
||||
};
|
||||
|
||||
class GatherAxis : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GatherAxis(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(GatherAxis)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
class Greater : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
|
||||
@@ -1786,6 +1807,41 @@ class Scatter : public UnaryPrimitive {
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
|
||||
class ScatterAxis : public UnaryPrimitive {
|
||||
public:
|
||||
enum ReduceType { Sum, None };
|
||||
|
||||
explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
os << "ScatterAxis";
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
os << " Sum";
|
||||
break;
|
||||
case None:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<ReduceType, int> state() const {
|
||||
return {reduce_type_, axis_};
|
||||
}
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
int axis_;
|
||||
};
|
||||
|
||||
class Sigmoid : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
Reference in New Issue
Block a user