Add tile op (#438)

This commit is contained in:
Diogo
2024-01-13 02:03:16 -05:00
committed by GitHub
parent 1b71487e1f
commit 2e29d0815b
7 changed files with 105 additions and 3 deletions

View File

@@ -2343,6 +2343,32 @@ TEST_CASE("test repeat") {
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
}
TEST_CASE("tile") {
auto x = array({1, 2, 3}, {3});
auto y = tile(x, {2});
auto expected = array({1, 2, 3, 1, 2, 3}, {6});
CHECK(array_equal(y, expected).item<bool>());
x = array({1, 2, 3, 4}, {2, 2});
y = tile(x, {2});
expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});
CHECK(array_equal(y, expected).item<bool>());
x = array({1, 2, 3, 4}, {2, 2});
y = tile(x, {4, 1});
expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});
CHECK(array_equal(y, expected).item<bool>());
x = array({1, 2, 3, 4}, {2, 2});
y = tile(x, {2, 2});
expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
x = array({1, 2, 3}, {3});
y = tile(x, {2, 2, 2});
expected = array(
{1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},
{2, 2, 6});
CHECK(array_equal(y, expected).item<bool>());
}
TEST_CASE("tensordot") {
auto x = reshape(arange(60.), {3, 4, 5});
auto y = reshape(arange(24.), {4, 3, 2});