mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add complex eigh (#2191)
This commit is contained in:
committed by
GitHub
parent
48ef3e74e2
commit
0654543dcc
@@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
|
||||
}
|
||||
}
|
||||
|
||||
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
|
||||
if (dtype != float32 && dtype != float64 && dtype != complex64) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Arrays must have type float32, float64 or complex64. "
|
||||
<< "Received array with type " << dtype << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
@@ -493,7 +502,7 @@ void validate_eig(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
check_float(a.dtype(), fname);
|
||||
check_float_or_complex(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -513,9 +522,10 @@ array eigvalsh(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eigvalsh]");
|
||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
eigval_type,
|
||||
std::make_shared<Eigh>(to_stream(s), UPLO, false),
|
||||
{a});
|
||||
}
|
||||
@@ -525,9 +535,10 @@ std::pair<array, array> eigh(
|
||||
std::string UPLO /* = "L" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eigh]");
|
||||
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
auto out = array::make_arrays(
|
||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{a.dtype(), a.dtype()},
|
||||
{eigval_type, a.dtype()},
|
||||
std::make_shared<Eigh>(to_stream(s), UPLO, true),
|
||||
{a});
|
||||
return std::make_pair(out[0], out[1]);
|
||||
|
||||
Reference in New Issue
Block a user