Eigenvalues and eigenvectors (#1334)

* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Kashif Rasul
2024-10-22 21:18:48 +02:00
committed by GitHub
parent c26208f67d
commit 3ddc07e936
23 changed files with 434 additions and 86 deletions

View File

@@ -454,4 +454,50 @@ array cross(
return concatenate(outputs, axis, s);
}
void validate_eigh(const array& a, const std::string fname) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must have type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << fname << " Arrays must have >= 2 dimensions. Received array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(fname + " Only defined for square matrices.");
}
}
array eigvalsh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigvalsh]");
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<Eigh>(to_stream(s), UPLO, false),
{a});
}
std::pair<array, array> eigh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigh]");
auto out = array::make_arrays(
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a});
return std::make_pair(out[0], out[1]);
}
} // namespace mlx::core::linalg