non-symmetric eig and eigh (#2188)

This commit is contained in:
Awni Hannun
2025-05-15 13:01:44 -07:00
committed by GitHub
parent cf6c939e86
commit c1eb9d05d9
14 changed files with 423 additions and 5 deletions

View File

@@ -488,7 +488,7 @@ array cross(
return concatenate(outputs, axis, s);
}
void validate_eigh(
void validate_eig(
const array& a,
const StreamOrDevice& stream,
const std::string fname) {
@@ -511,7 +511,7 @@ array eigvalsh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, s, "[linalg::eigvalsh]");
validate_eig(a, s, "[linalg::eigvalsh]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
@@ -524,7 +524,7 @@ std::pair<array, array> eigh(
const array& a,
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, s, "[linalg::eigh]");
validate_eig(a, s, "[linalg::eigh]");
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
@@ -533,6 +533,26 @@ std::pair<array, array> eigh(
return std::make_pair(out[0], out[1]);
}
array eigvals(const array& a, StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigvals]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
complex64,
std::make_shared<Eig>(to_stream(s), false),
{a});
}
std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eig]");
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{complex64, complex64},
std::make_shared<Eig>(to_stream(s), true),
{a});
return std::make_pair(out[0], out[1]);
}
void validate_lu(
const array& a,
const StreamOrDevice& stream,