Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -102,8 +102,21 @@ inline array matrix_norm(
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented.");
row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
auto a_matrix = (row_axis > col_axis)
? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
a_matrix = svd(a_matrix, false, s).at(0);
a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)
: min(a_matrix, -1, false, s);
if (keepdims) {
std::vector<int> sorted_axes = (row_axis < col_axis)
? std::vector<int>{row_axis, col_axis}
: std::vector<int>{col_axis, row_axis};
a_matrix = expand_dims(a_matrix, sorted_axes, s);
}
return astype(a_matrix, dtype, s);
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
@@ -120,8 +133,19 @@ inline array matrix_norm(
if (ord == "f" || ord == "fro") {
return l2_norm(a, axis, keepdims, s);
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");
int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
auto a_matrix = (row_axis > col_axis)
? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);
if (keepdims) {
std::vector<int> sorted_axes = (row_axis < col_axis)
? std::vector<int>{row_axis, col_axis}
: std::vector<int>{col_axis, row_axis};
a_matrix = expand_dims(a_matrix, sorted_axes, s);
}
return a_matrix;
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
@@ -214,7 +238,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
return std::make_pair(out[0], out[1]);
}
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
std::vector<array>
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
check_float(a.dtype(), "[linalg::svd]");
@@ -230,14 +255,22 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
const auto n = a.shape(-1);
const auto rank = a.ndim();
auto u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
auto s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
if (!compute_uv) {
return {array(
std::move(s_shape),
std::move(a.dtype()),
std::make_shared<SVD>(to_stream(s), compute_uv),
{a})};
}
auto u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
auto vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
@@ -245,7 +278,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
std::make_shared<SVD>(to_stream(s)),
std::make_shared<SVD>(to_stream(s), compute_uv),
{a});
}
@@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
int m = a.shape(-2);
int n = a.shape(-1);
int k = std::min(m, n);
auto outs = linalg::svd(a, s);
auto outs = linalg::svd(a, true, s);
array U = outs[0];
array S = outs[1];
array V = outs[2];