Added mx.stack c++ frontend impl (#123)

* stack C++ operation + python bindings
This commit is contained in:
Jason
2023-12-14 16:21:19 -05:00
committed by GitHub
parent e5851e52b1
commit e28b57e371
9 changed files with 191 additions and 4 deletions

View File

@@ -1989,6 +1989,35 @@ TEST_CASE("test where") {
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
}
TEST_CASE("test stack") {
auto x = array({});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
x = array({1, 2, 3}, {3});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
auto y = array({4, 5, 6}, {3});
auto z = std::vector<array>{x, y};
CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
x = array({1, 2, 3}, {3}, float16);
y = array({4, 5, 6}, {3}, int32);
CHECK_EQ(stack({x, y}, 0).dtype(), float16);
x = array({1, 2, 3}, {3}, int32);
y = array({4, 5, 6, 7}, {4}, int32);
CHECK_THROWS_MESSAGE(
stack({x, y}, 0), "All arrays must have the same shape and dtype");
}
TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});

View File

@@ -25,3 +25,38 @@ TEST_CASE("test type promotion") {
CHECK_EQ(result_type(arrs), float32);
}
}
TEST_CASE("test normalize axis") {
struct TestCase {
int axis;
int ndim;
int expected;
};
std::vector<TestCase> testCases = {
{0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};
for (const auto& tc : testCases) {
CHECK_EQ(normalize_axis(tc.axis, tc.ndim), tc.expected);
}
CHECK_THROWS(normalize_axis(3, 3));
CHECK_THROWS(normalize_axis(-4, 3));
}
TEST_CASE("test is same size and shape") {
struct TestCase {
std::vector<array> a;
bool expected;
};
std::vector<TestCase> testCases = {
{{array({}), array({})}, true},
{{array({1}), array({1})}, true},
{{array({1, 2, 3}), array({1, 2, 4})}, true},
{{array({1, 2, 3}), array({1, 2})}, false}};
for (const auto& tc : testCases) {
CHECK_EQ(is_same_shape(tc.a), tc.expected);
}
}