mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
non-symmetric eig and eigh (#2188)
This commit is contained in:
@@ -488,7 +488,7 @@ array cross(
|
||||
return concatenate(outputs, axis, s);
|
||||
}
|
||||
|
||||
void validate_eigh(
|
||||
void validate_eig(
|
||||
const array& a,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string fname) {
|
||||
@@ -511,7 +511,7 @@ array eigvalsh(
|
||||
const array& a,
|
||||
std::string UPLO /* = "L" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eigh(a, s, "[linalg::eigvalsh]");
|
||||
validate_eig(a, s, "[linalg::eigvalsh]");
|
||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
@@ -524,7 +524,7 @@ std::pair<array, array> eigh(
|
||||
const array& a,
|
||||
std::string UPLO /* = "L" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eigh(a, s, "[linalg::eigh]");
|
||||
validate_eig(a, s, "[linalg::eigh]");
|
||||
auto out = array::make_arrays(
|
||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{a.dtype(), a.dtype()},
|
||||
@@ -533,6 +533,26 @@ std::pair<array, array> eigh(
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
array eigvals(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eigvals]");
|
||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
complex64,
|
||||
std::make_shared<Eig>(to_stream(s), false),
|
||||
{a});
|
||||
}
|
||||
|
||||
std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eig]");
|
||||
auto out = array::make_arrays(
|
||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{complex64, complex64},
|
||||
std::make_shared<Eig>(to_stream(s), true),
|
||||
{a});
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_lu(
|
||||
const array& a,
|
||||
const StreamOrDevice& stream,
|
||||
|
||||
Reference in New Issue
Block a user