non-symmetric eig and eigh (#2188)

This commit is contained in:
Awni Hannun
2025-05-15 13:01:44 -07:00
committed by GitHub
parent cf6c939e86
commit c1eb9d05d9
14 changed files with 423 additions and 5 deletions

View File

@@ -2381,6 +2381,29 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};
class Eig : public Primitive {
public:
explicit Eig(Stream stream, bool compute_eigenvectors)
: Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_PRINT(Eig)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return compute_eigenvectors_;
}
private:
bool compute_eigenvectors_;
};
class Eigh : public Primitive {
public:
explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)