mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 23:15:09 +08:00
@@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") {
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap SVD") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), Device::cpu);
|
||||
auto svd_full = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), true, Device::cpu);
|
||||
};
|
||||
|
||||
auto svd_singular = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), false, Device::cpu);
|
||||
};
|
||||
|
||||
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
|
||||
|
||||
// vmap over the second axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {1})({a});
|
||||
auto out = vmap(svd_full, /* in_axes = */ {1})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
@@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") {
|
||||
|
||||
// vmap over the third axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {2})({a});
|
||||
auto out = vmap(svd_full, /* in_axes = */ {2})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
@@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") {
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
|
||||
// test singular values
|
||||
{
|
||||
auto out = vmap(svd_singular, /* in_axes = */ {1})({a});
|
||||
const auto& S = out.at(0);
|
||||
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
|
||||
}
|
||||
|
||||
{
|
||||
auto out = vmap(svd_singular, /* in_axes = */ {2})({a});
|
||||
const auto& S = out.at(0);
|
||||
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap dynamic slices") {
|
||||
|
||||
Reference in New Issue
Block a user