mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto n = a.shape(-1);
|
||||
const auto rank = a.ndim();
|
||||
|
||||
std::vector<int> u_shape = a.shape();
|
||||
auto u_shape = a.shape();
|
||||
u_shape[rank - 2] = m;
|
||||
u_shape[rank - 1] = m;
|
||||
|
||||
std::vector<int> s_shape = a.shape();
|
||||
auto s_shape = a.shape();
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
std::vector<int> vt_shape = a.shape();
|
||||
auto vt_shape = a.shape();
|
||||
vt_shape[rank - 2] = n;
|
||||
vt_shape[rank - 1] = n;
|
||||
|
||||
@@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
array S = outs[1];
|
||||
array V = outs[2];
|
||||
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
std::vector<int> ends = a.shape();
|
||||
Shape starts(a.ndim(), 0);
|
||||
auto ends = a.shape();
|
||||
int i = a.ndim() - 2;
|
||||
int j = a.ndim() - 1;
|
||||
|
||||
@@ -479,7 +479,7 @@ array eigvalsh(
|
||||
std::string UPLO /* = "L" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eigh(a, "[linalg::eigvalsh]");
|
||||
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
@@ -493,7 +493,7 @@ std::pair<array, array> eigh(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eigh(a, "[linalg::eigh]");
|
||||
auto out = array::make_arrays(
|
||||
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{a.dtype(), a.dtype()},
|
||||
std::make_shared<Eigh>(to_stream(s), UPLO, true),
|
||||
{a});
|
||||
|
||||
Reference in New Issue
Block a user