mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
@@ -1989,6 +1989,35 @@ TEST_CASE("test where") {
|
||||
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test stack") {
|
||||
auto x = array({});
|
||||
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
|
||||
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
|
||||
|
||||
x = array({1, 2, 3}, {3});
|
||||
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
|
||||
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
|
||||
|
||||
auto y = array({4, 5, 6}, {3});
|
||||
auto z = std::vector<array>{x, y};
|
||||
CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
|
||||
CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
|
||||
CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
|
||||
CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
|
||||
CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
|
||||
|
||||
CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
|
||||
|
||||
x = array({1, 2, 3}, {3}, float16);
|
||||
y = array({4, 5, 6}, {3}, int32);
|
||||
CHECK_EQ(stack({x, y}, 0).dtype(), float16);
|
||||
|
||||
x = array({1, 2, 3}, {3}, int32);
|
||||
y = array({4, 5, 6, 7}, {4}, int32);
|
||||
CHECK_THROWS_MESSAGE(
|
||||
stack({x, y}, 0), "All arrays must have the same shape and dtype");
|
||||
}
|
||||
|
||||
TEST_CASE("test eye") {
|
||||
auto eye_3 = eye(3);
|
||||
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
||||
|
Reference in New Issue
Block a user