ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun
2024-12-10 16:39:07 -08:00
committed by GitHub
parent 310ad8d9db
commit f76a49e555
13 changed files with 373 additions and 184 deletions

View File

@@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class ExpandDims : public UnaryPrimitive {
public:
explicit ExpandDims(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ExpandDims)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class FFT : public UnaryPrimitive {
public:
explicit FFT(
@@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive {
public:
explicit Gather(
Stream stream,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes)
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
std::vector<int> axes,
std::vector<int> slice_sizes)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_sizes_(std::move(slice_sizes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Squeeze : public UnaryPrimitive {
public:
explicit Squeeze(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Squeeze)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class Tan : public UnaryPrimitive {
public:
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}