Add complex eigh (#2191)

This commit is contained in:
Angelos Katharopoulos
2025-05-18 00:18:43 -07:00
committed by GitHub
parent 48ef3e74e2
commit 0654543dcc
5 changed files with 190 additions and 55 deletions

View File

@@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
}
}
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
if (dtype != float32 && dtype != float64 && dtype != complex64) {
std::ostringstream msg;
msg << prefix << " Arrays must have type float32, float64 or complex64. "
<< "Received array with type " << dtype << ".";
throw std::invalid_argument(msg.str());
}
}
Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
@@ -493,7 +502,7 @@ void validate_eig(
const StreamOrDevice& stream,
const std::string fname) {
check_cpu_stream(stream, fname);
check_float(a.dtype(), fname);
check_float_or_complex(a.dtype(), fname);
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -513,9 +522,10 @@ array eigvalsh(
StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigvalsh]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
return array(
std::move(out_shape),
a.dtype(),
eigval_type,
std::make_shared<Eigh>(to_stream(s), UPLO, false),
{a});
}
@@ -525,9 +535,10 @@ std::pair<array, array> eigh(
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigh]");
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
{eigval_type, a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a});
return std::make_pair(out[0], out[1]);