Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -1031,6 +1031,28 @@ class FFT : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Flatten : public UnaryPrimitive {
public:
explicit Flatten(Stream stream, int start_axis, int end_axis)
: UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Flatten)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int start_axis, int end_axis);
private:
int start_axis_;
int end_axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class Floor : public UnaryPrimitive {
public:
explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
@@ -1643,16 +1665,6 @@ class Reshape : public UnaryPrimitive {
private:
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
static std::pair<bool, Strides> prepare_reshape(
const array& in,
const array& out);
static void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
};
class Reduce : public UnaryPrimitive {
@@ -2137,6 +2149,28 @@ class Tanh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Unflatten : public UnaryPrimitive {
public:
explicit Unflatten(Stream stream, int axis, Shape shape)
: UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Unflatten)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int axis, const Shape& shape);
private:
int axis_;
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
};
class Uniform : public UnaryPrimitive {
public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}