mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
contiguous op / prim (#1612)
This commit is contained in:
@@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Contiguous : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Contiguous(Stream stream, bool allow_col_major)
|
||||
: UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Contiguous)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
bool allow_col_major_;
|
||||
};
|
||||
|
||||
class Convolution : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Convolution(
|
||||
|
||||
Reference in New Issue
Block a user