mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
added atleast *args input support (#710)
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
This commit is contained in:

committed by
GitHub

parent
3b661b7394
commit
08226ab491
@@ -2787,6 +2787,19 @@ TEST_CASE("test atleast_1d") {
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_1d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_1d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 1);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1});
|
||||
CHECK_EQ(out[1].ndim(), 1);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{3});
|
||||
CHECK_EQ(out[2].ndim(), 2);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_2d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_2d(x);
|
||||
@@ -2804,6 +2817,19 @@ TEST_CASE("test atleast_2d") {
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_2d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_2d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 2);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
|
||||
CHECK_EQ(out[1].ndim(), 2);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
|
||||
CHECK_EQ(out[2].ndim(), 2);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_3d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_3d(x);
|
||||
@@ -2820,3 +2846,16 @@ TEST_CASE("test atleast_3d") {
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_3d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_3d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 3);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
|
||||
CHECK_EQ(out[1].ndim(), 3);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
|
||||
CHECK_EQ(out[2].ndim(), 3);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
Reference in New Issue
Block a user