Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -414,6 +414,94 @@ TEST_CASE("test vmap gather") {
}
}
TEST_CASE("test vmap scatter") {
auto make_scatter_fn = [](const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes) {
return [=](const std::vector<array>& inputs) {
auto a = inputs.at(0);
return std::vector<array>{scatter(a, indices, updates, axes)};
};
};
{
// vmap nothing.
auto a = zeros({3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
auto expected =
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
// Non-vmapped function output.
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 0.
auto a = zeros({2, 3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
auto expected = array(
{0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},
{2, 3, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 1, scatter on axis 0.
auto a = zeros({3, 2, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {1}, /* out_axes = */ {1})({a})[0];
auto expected = array(
{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,
1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{3, 2, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 1.
auto a = zeros({2, 3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 2, 1});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{1});
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
auto expected = array(
{0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},
{2, 3, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 2, scatter on axes (0, 1).
auto a = zeros({2, 3, 2});
auto indices = {array({1}), array({2})};
auto axes = {0, 1};
auto updates = reshape(array({1}, float32), {1, 1, 1});
auto func = make_scatter_fn(indices, updates, axes);
auto out = vmap(func, /* in_axes = */ {2}, /* out_axes = */ {2})({a})[0];
auto expected =
array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test vmap SVD") {
auto fun = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), Device::cpu);