Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -353,6 +353,50 @@ TEST_CASE("test slice update") {
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
}
TEST_CASE("test dynamic slice") {
auto src = reshape(arange(6), {2, 3});
CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));
CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));
auto out = slice(src, array({1}), {0}, {1, 2});
auto expected = array({3, 4}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
out = slice(src, array({1, 1}), {0, 1}, {1, 2});
expected = array({4, 5}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test dynamic slice update") {
auto src = zeros({2, 3}, int32);
auto upd = ones({1, 2}, int32);
CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1}), {3}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));
upd = ones({4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
upd = ones({1, 4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));
upd = ones({1, 2}, int32);
auto out = slice_update(src, upd, array({1}), {0});
auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
upd = ones({1, 2}, int32);
out = slice_update(src, upd, array({1, 1}), {0, 1});
expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test split") {
array x = array(1);
CHECK_THROWS(split(x, 0));
@@ -720,7 +764,7 @@ TEST_CASE("test is inf") {
CHECK_FALSE(any(isinf(z)).item<bool>());
array w = array({1.0f, inf, 2.0f});
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
array a(1.0f, bfloat16);
CHECK_FALSE(isinf(a).item<bool>());