More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
const auto n = a.shape(-1);
const auto rank = a.ndim();
std::vector<int> u_shape = a.shape();
auto u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
std::vector<int> s_shape = a.shape();
auto s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
std::vector<int> vt_shape = a.shape();
auto vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
@@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
array S = outs[1];
array V = outs[2];
std::vector<int> starts(a.ndim(), 0);
std::vector<int> ends = a.shape();
Shape starts(a.ndim(), 0);
auto ends = a.shape();
int i = a.ndim() - 2;
int j = a.ndim() - 1;
@@ -479,7 +479,7 @@ array eigvalsh(
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigvalsh]");
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
a.dtype(),
@@ -493,7 +493,7 @@ std::pair<array, array> eigh(
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigh]");
auto out = array::make_arrays(
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a});