Fix vmap constant output size (#1524)

* use inputs to determine output size

* remove noop vmap tests
This commit is contained in:
Alex Barron
2024-10-30 16:16:53 -07:00
committed by GitHub
parent 917252a5a1
commit 048fabdabd
3 changed files with 38 additions and 36 deletions

View File

@@ -34,12 +34,8 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
auto vfun = vmap(fun, -1, -1);
auto x = zeros({2});
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
vfun = vmap(fun);
x = zeros({3, 2});
auto vfun = vmap(fun);
auto x = zeros({3, 2});
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
vfun = vmap(fun, 0, 1);
@@ -121,16 +117,9 @@ TEST_CASE("test simple vmap") {
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
x = array(1.);
y = array(1.);
vfun = vmap(fun, {-1, -1}, {-1});
out = vfun({x, y})[0];
CHECK(array_equal(out, array(2.)).item<bool>());
x = ones({3, 2, 1});
y = ones({3, 2, 1});
vfun = vmap(vmap(fun));
@@ -187,13 +176,6 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
cond = array({true, false});
x = array(1.);
y = array(2.);
vfun = vmap(fun, {-1, -1, -1}, {-1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
x = ones({3, 2, 1});
y = full({3, 2, 1}, 2);
@@ -424,21 +406,6 @@ TEST_CASE("test vmap scatter") {
};
};
{
// vmap nothing.
auto a = zeros({3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
auto expected =
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
// Non-vmapped function output.
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 0.
auto a = zeros({2, 3, 4});