Eigenvalues and eigenvectors (#1334)

* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Kashif Rasul
2024-10-22 21:18:48 +02:00
committed by GitHub
parent c26208f67d
commit 3ddc07e936
23 changed files with 434 additions and 86 deletions

View File

@@ -2196,4 +2196,44 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};
class Eigh : public Primitive {
public:
explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
: Primitive(stream),
uplo_(std::move(uplo)),
compute_eigenvectors_(compute_eigenvectors) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_PRINT(Eigh)
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
auto shape = inputs[0].shape();
shape.pop_back(); // Remove last dimension for eigenvalues
if (compute_eigenvectors_) {
return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
} else {
return {shape}; // Only eigenvalues
}
}
bool is_equivalent(const Primitive& other) const override {
if (auto* p = dynamic_cast<const Eigh*>(&other)) {
return uplo_ == p->uplo_ &&
compute_eigenvectors_ == p->compute_eigenvectors_;
}
return false;
}
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::string uplo_;
bool compute_eigenvectors_;
};
} // namespace mlx::core