Add vmap for SVD and inverse (#849)

This commit is contained in:
nicolov
2024-03-21 21:18:27 +01:00
committed by GitHub
parent 53e6a9367c
commit 105d236889
7 changed files with 116 additions and 5 deletions

View File

@@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
}
}
TEST_CASE("test vmap SVD") {
auto fun = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), Device::cpu);
};
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
// vmap over the second axis.
{
auto out = vmap(fun, /* in_axes = */ {1})({a});
const auto& U = out.at(0);
const auto& S = out.at(1);
const auto& Vt = out.at(2);
CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
}
// vmap over the third axis.
{
auto out = vmap(fun, /* in_axes = */ {2})({a});
const auto& U = out.at(0);
const auto& S = out.at(1);
const auto& Vt = out.at(2);
CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
}
}