mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
@@ -395,7 +395,7 @@ TEST_CASE("test split") {
|
||||
CHECK_EQ(out[1].shape(), Shape{8, 4});
|
||||
CHECK_EQ(out[2].shape(), Shape{8, 4});
|
||||
|
||||
out = split(x, std::vector<int>{});
|
||||
out = split(x, Shape{});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
CHECK_EQ(out[0].shape(), x.shape());
|
||||
|
||||
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
|
||||
CHECK_EQ(out[1].shape(), Shape{4, 12});
|
||||
CHECK_EQ(out[2].shape(), Shape{1, 12});
|
||||
|
||||
out = split(x, std::vector<int>{20});
|
||||
out = split(x, Shape{20});
|
||||
CHECK_EQ(out.size(), 2);
|
||||
CHECK_EQ(out[0].shape(), Shape{8, 12});
|
||||
CHECK_EQ(out[1].shape(), Shape{0, 12});
|
||||
|
||||
// Negative indices
|
||||
out = split(x, std::vector<int>{-5});
|
||||
out = split(x, Shape{-5});
|
||||
CHECK_EQ(out[0].shape(), Shape{3, 12});
|
||||
CHECK_EQ(out[1].shape(), Shape{5, 12});
|
||||
|
||||
// Different axis
|
||||
out = split(x, std::vector<int>{2, 8}, 1);
|
||||
out = split(x, {2, 8}, 1);
|
||||
CHECK_EQ(out[0].shape(), Shape{8, 2});
|
||||
CHECK_EQ(out[1].shape(), Shape{8, 6});
|
||||
CHECK_EQ(out[2].shape(), Shape{8, 4});
|
||||
|
||||
// Out of order indices
|
||||
x = arange(5);
|
||||
out = split(x, std::vector<int>{2, 1, 2});
|
||||
out = split(x, {2, 1, 2});
|
||||
CHECK(array_equal(out[0], array({0, 1})).item<bool>());
|
||||
CHECK(array_equal(out[1], array({})).item<bool>());
|
||||
CHECK(array_equal(out[2], array({1})).item<bool>());
|
||||
|
Reference in New Issue
Block a user