mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Add tile op (#438)
This commit is contained in:
@@ -2343,6 +2343,32 @@ TEST_CASE("test repeat") {
|
||||
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("tile") {
|
||||
auto x = array({1, 2, 3}, {3});
|
||||
auto y = tile(x, {2});
|
||||
auto expected = array({1, 2, 3, 1, 2, 3}, {6});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
x = array({1, 2, 3, 4}, {2, 2});
|
||||
y = tile(x, {2});
|
||||
expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
x = array({1, 2, 3, 4}, {2, 2});
|
||||
y = tile(x, {4, 1});
|
||||
expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
|
||||
x = array({1, 2, 3, 4}, {2, 2});
|
||||
y = tile(x, {2, 2});
|
||||
expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
x = array({1, 2, 3}, {3});
|
||||
y = tile(x, {2, 2, 2});
|
||||
expected = array(
|
||||
{1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},
|
||||
{2, 2, 6});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("tensordot") {
|
||||
auto x = reshape(arange(60.), {3, 4, 5});
|
||||
auto y = reshape(arange(24.), {4, 3, 2});
|
||||
|
Reference in New Issue
Block a user