mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -1031,6 +1031,28 @@ class FFT : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Flatten : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Flatten(Stream stream, int start_axis, int end_axis)
|
||||
: UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Flatten)
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int start_axis, int end_axis);
|
||||
|
||||
private:
|
||||
int start_axis_;
|
||||
int end_axis_;
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Floor : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
|
||||
@@ -1643,16 +1665,6 @@ class Reshape : public UnaryPrimitive {
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
static std::pair<bool, Strides> prepare_reshape(
|
||||
const array& in,
|
||||
const array& out);
|
||||
static void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
};
|
||||
|
||||
class Reduce : public UnaryPrimitive {
|
||||
@@ -2137,6 +2149,28 @@ class Tanh : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Unflatten : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Unflatten(Stream stream, int axis, Shape shape)
|
||||
: UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Unflatten)
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int axis, const Shape& shape);
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
Shape shape_;
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Uniform : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
Reference in New Issue
Block a user