mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 06:53:18 +08:00
Add the roll op (#1455)
This commit is contained in:
committed by
GitHub
parent
f374b6ca4d
commit
9b12093739
@@ -3715,3 +3715,35 @@ TEST_CASE("test view") {
|
||||
auto out = view(in, int32);
|
||||
CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test roll") {
|
||||
auto x = reshape(arange(10), {2, 5});
|
||||
|
||||
auto y = roll(x, 2);
|
||||
CHECK(array_equal(y, array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, -2);
|
||||
CHECK(array_equal(y, array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, 2, 1);
|
||||
CHECK(array_equal(y, array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, -2, 1);
|
||||
CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, 2, {0, 0, 0});
|
||||
CHECK(array_equal(y, array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, 1, {1, 1, 1});
|
||||
CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
|
||||
.item<bool>());
|
||||
|
||||
y = roll(x, {1, 2}, {0, 1});
|
||||
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user