contiguous op / prim (#1612)

This commit is contained in:
Awni Hannun
2024-11-21 19:51:49 -08:00
committed by GitHub
parent 0d5e7716ad
commit dcca0d7477
11 changed files with 104 additions and 25 deletions

View File

@@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Contiguous : public UnaryPrimitive {
public:
explicit Contiguous(Stream stream, bool allow_col_major)
: UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Contiguous)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
bool allow_col_major_;
};
class Convolution : public UnaryPrimitive {
public:
explicit Convolution(