mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter * updates * vmap updates + a few more tests * bug fix --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -414,6 +414,94 @@ TEST_CASE("test vmap gather") {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap scatter") {
|
||||
auto make_scatter_fn = [](const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes) {
|
||||
return [=](const std::vector<array>& inputs) {
|
||||
auto a = inputs.at(0);
|
||||
return std::vector<array>{scatter(a, indices, updates, axes)};
|
||||
};
|
||||
};
|
||||
|
||||
{
|
||||
// vmap nothing.
|
||||
auto a = zeros({3, 4});
|
||||
auto indices = array({1});
|
||||
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
|
||||
|
||||
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
|
||||
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
|
||||
auto expected =
|
||||
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
|
||||
// Non-vmapped function output.
|
||||
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
// vmap src on axis 0, scatter on axis 0.
|
||||
auto a = zeros({2, 3, 4});
|
||||
auto indices = array({1});
|
||||
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
|
||||
|
||||
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
|
||||
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
|
||||
auto expected = array(
|
||||
{0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},
|
||||
{2, 3, 4},
|
||||
float32);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
// vmap src on axis 1, scatter on axis 0.
|
||||
auto a = zeros({3, 2, 4});
|
||||
auto indices = array({1});
|
||||
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
|
||||
|
||||
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
|
||||
auto out = vmap(func, /* in_axes = */ {1}, /* out_axes = */ {1})({a})[0];
|
||||
auto expected = array(
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,
|
||||
1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{3, 2, 4},
|
||||
float32);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
// vmap src on axis 0, scatter on axis 1.
|
||||
auto a = zeros({2, 3, 4});
|
||||
auto indices = array({1});
|
||||
auto updates = reshape(array({1, 2}, float32), {1, 2, 1});
|
||||
|
||||
auto func = make_scatter_fn({indices}, updates, std::vector<int>{1});
|
||||
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
|
||||
auto expected = array(
|
||||
{0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},
|
||||
{2, 3, 4},
|
||||
float32);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
// vmap src on axis 2, scatter on axes (0, 1).
|
||||
auto a = zeros({2, 3, 2});
|
||||
auto indices = {array({1}), array({2})};
|
||||
auto axes = {0, 1};
|
||||
auto updates = reshape(array({1}, float32), {1, 1, 1});
|
||||
|
||||
auto func = make_scatter_fn(indices, updates, axes);
|
||||
auto out = vmap(func, /* in_axes = */ {2}, /* out_axes = */ {2})({a})[0];
|
||||
auto expected =
|
||||
array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap SVD") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), Device::cpu);
|
||||
|
Reference in New Issue
Block a user