mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 15:28:10 +08:00
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
@@ -25,3 +25,38 @@ TEST_CASE("test type promotion") {
|
||||
CHECK_EQ(result_type(arrs), float32);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test normalize axis") {
|
||||
struct TestCase {
|
||||
int axis;
|
||||
int ndim;
|
||||
int expected;
|
||||
};
|
||||
|
||||
std::vector<TestCase> testCases = {
|
||||
{0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};
|
||||
|
||||
for (const auto& tc : testCases) {
|
||||
CHECK_EQ(normalize_axis(tc.axis, tc.ndim), tc.expected);
|
||||
}
|
||||
|
||||
CHECK_THROWS(normalize_axis(3, 3));
|
||||
CHECK_THROWS(normalize_axis(-4, 3));
|
||||
}
|
||||
|
||||
TEST_CASE("test is same size and shape") {
|
||||
struct TestCase {
|
||||
std::vector<array> a;
|
||||
bool expected;
|
||||
};
|
||||
|
||||
std::vector<TestCase> testCases = {
|
||||
{{array({}), array({})}, true},
|
||||
{{array({1}), array({1})}, true},
|
||||
{{array({1, 2, 3}), array({1, 2, 4})}, true},
|
||||
{{array({1, 2, 3}), array({1, 2})}, false}};
|
||||
|
||||
for (const auto& tc : testCases) {
|
||||
CHECK_EQ(is_same_shape(tc.a), tc.expected);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user