mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added support for atleast_1d, atleast_2d, atleast_3d (#694)
This commit is contained in:

committed by
GitHub

parent
e1bdf6a8d9
commit
f883fcede0
@@ -2716,3 +2716,54 @@ TEST_CASE("test diag") {
|
||||
out = diag(x, -1);
|
||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_1d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_1d(x);
|
||||
CHECK_EQ(out.ndim(), 1);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1});
|
||||
|
||||
x = array({1, 2, 3}, {3});
|
||||
out = atleast_1d(x);
|
||||
CHECK_EQ(out.ndim(), 1);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3});
|
||||
|
||||
x = array({1, 2, 3}, {3, 1});
|
||||
out = atleast_1d(x);
|
||||
CHECK_EQ(out.ndim(), 2);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_2d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_2d(x);
|
||||
CHECK_EQ(out.ndim(), 2);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 1});
|
||||
|
||||
x = array({1, 2, 3}, {3});
|
||||
out = atleast_2d(x);
|
||||
CHECK_EQ(out.ndim(), 2);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 3});
|
||||
|
||||
x = array({1, 2, 3}, {3, 1});
|
||||
out = atleast_2d(x);
|
||||
CHECK_EQ(out.ndim(), 2);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_3d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_3d(x);
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
|
||||
|
||||
x = array({1, 2, 3}, {3});
|
||||
out = atleast_3d(x);
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
|
||||
|
||||
x = array({1, 2, 3}, {3, 1});
|
||||
out = atleast_3d(x);
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
Reference in New Issue
Block a user