mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
ExpandDims primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ExpandDims : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ExpandDims(Stream stream, std::vector<int> axes)
|
||||
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ExpandDims)
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
|
||||
class FFT : public UnaryPrimitive {
|
||||
public:
|
||||
explicit FFT(
|
||||
@@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Gather(
|
||||
Stream stream,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes)
|
||||
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
|
||||
std::vector<int> axes,
|
||||
std::vector<int> slice_sizes)
|
||||
: UnaryPrimitive(stream),
|
||||
axes_(std::move(axes)),
|
||||
slice_sizes_(std::move(slice_sizes)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Squeeze : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Squeeze(Stream stream, std::vector<int> axes)
|
||||
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Squeeze)
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
|
||||
class Tan : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
Reference in New Issue
Block a user