Add the roll op (#1455)

This commit is contained in:
Angelos Katharopoulos
2024-10-07 17:21:42 -07:00
committed by GitHub
parent f374b6ca4d
commit 9b12093739
6 changed files with 231 additions and 0 deletions

View File

@@ -3715,3 +3715,35 @@ TEST_CASE("test view") {
auto out = view(in, int32);
CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
}
TEST_CASE("test roll") {
auto x = reshape(arange(10), {2, 5});
auto y = roll(x, 2);
CHECK(array_equal(y, array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5}))
.item<bool>());
y = roll(x, -2);
CHECK(array_equal(y, array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5}))
.item<bool>());
y = roll(x, 2, 1);
CHECK(array_equal(y, array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5}))
.item<bool>());
y = roll(x, -2, 1);
CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
.item<bool>());
y = roll(x, 2, {0, 0, 0});
CHECK(array_equal(y, array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 5}))
.item<bool>());
y = roll(x, 1, {1, 1, 1});
CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
.item<bool>());
y = roll(x, {1, 2}, {0, 1});
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
.item<bool>());
}