Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") {
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
}
}
TEST_CASE("test vmap dynamic slices") {
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};
};
auto x = reshape(arange(12), {3, 4});
auto out = vmap(fun)({x})[0];
CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];
CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))
.item<bool>());
}
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{
slice_update(inputs[0], inputs[1], array({1}), {0})};
};
auto x = zeros({2, 2});
auto upd = ones({2, 1});
auto out = vmap(fun)({x, upd})[0];
CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];
CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());
}
}