mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add vmap for SVD and inverse (#849)
This commit is contained in:
@@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap SVD") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), Device::cpu);
|
||||
};
|
||||
|
||||
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
|
||||
|
||||
// vmap over the second axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {1})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
|
||||
}
|
||||
|
||||
// vmap over the third axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {2})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user