Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") {
}
}
}
TEST_CASE("test grad dynamic slices") {
{
auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };
auto x = array({1, 2, 3, 4}, {2, 2});
auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;
CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());
}
{
auto fn = [](const std::vector<array>& inputs) {
const auto& x = inputs[0];
const auto& update = inputs[1];
return std::vector<array>{slice_update(x, update, array({0}), {0})};
};
auto x = zeros({2, 2});
auto update = array({3.f, 4.f}, {1, 2});
auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;
CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
}
}