Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") {
}
TEST_CASE("test vmap SVD") {
auto fun = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), Device::cpu);
auto svd_full = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), true, Device::cpu);
};
auto svd_singular = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), false, Device::cpu);
};
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
// vmap over the second axis.
{
auto out = vmap(fun, /* in_axes = */ {1})({a});
auto out = vmap(svd_full, /* in_axes = */ {1})({a});
const auto& U = out.at(0);
const auto& S = out.at(1);
const auto& Vt = out.at(2);
@@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") {
// vmap over the third axis.
{
auto out = vmap(fun, /* in_axes = */ {2})({a});
auto out = vmap(svd_full, /* in_axes = */ {2})({a});
const auto& U = out.at(0);
const auto& S = out.at(1);
const auto& Vt = out.at(2);
@@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") {
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
}
// test singular values
{
auto out = vmap(svd_singular, /* in_axes = */ {1})({a});
const auto& S = out.at(0);
CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
}
{
auto out = vmap(svd_singular, /* in_axes = */ {2})({a});
const auto& S = out.at(0);
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
}
}
TEST_CASE("test vmap dynamic slices") {