Real and Imag (#1490)

* real and imag

* fix

* fix
This commit is contained in:
Awni Hannun
2024-10-15 16:23:15 -07:00
committed by GitHub
parent 2b8ace6a03
commit 3f86399922
21 changed files with 275 additions and 46 deletions

View File

@@ -1106,6 +1106,20 @@ class Hadamard : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Imag : public UnaryPrimitive {
public:
explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Imag)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};
class Less : public UnaryPrimitive {
public:
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
@@ -1561,6 +1575,20 @@ class RandomBits : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Real : public UnaryPrimitive {
public:
explicit Real(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Real)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};
class Reshape : public UnaryPrimitive {
public:
explicit Reshape(Stream stream, const std::vector<int>& shape)