Eigenvalues and eigenvectors (#1334)

* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Kashif Rasul
2024-10-22 21:18:48 +02:00
committed by GitHub
parent c26208f67d
commit 3ddc07e936
23 changed files with 434 additions and 86 deletions

View File

@@ -767,6 +767,27 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}
std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
auto ax = needs_move ? 0 : axes[0];
std::vector<array> outputs;
if (compute_eigenvectors_) {
auto [values, vectors] = linalg::eigh(a, uplo_, stream());
outputs = {values, vectors};
} else {
outputs = {linalg::eigvalsh(a, uplo_, stream())};
}
return {outputs, std::vector<int>(outputs.size(), ax)};
}
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,