mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2031,6 +2031,78 @@ TEST_CASE("test eye") {
|
||||
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test tri") {
|
||||
auto _tri = tri(4, 4, 0, float32);
|
||||
CHECK_EQ(_tri.shape(), std::vector<int>{4, 4});
|
||||
auto expected_tri = array(
|
||||
{1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
1.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
1.0f,
|
||||
1.0f,
|
||||
1.0f},
|
||||
{4, 4});
|
||||
CHECK(array_equal(_tri, expected_tri).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test tril") {
|
||||
auto _tril = tril(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
|
||||
CHECK_EQ(_tril.shape(), std::vector<int>{4, 4});
|
||||
auto expected_tri = array(
|
||||
{2.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
0.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
2.0f},
|
||||
{4, 4});
|
||||
CHECK(array_equal(_tril, expected_tri).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test triu") {
|
||||
auto _triu = triu(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
|
||||
CHECK_EQ(_triu.shape(), std::vector<int>{4, 4});
|
||||
auto expected_tri = array(
|
||||
{2.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
0.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
2.0f,
|
||||
2.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
2.0f},
|
||||
{4, 4});
|
||||
CHECK(array_equal(_triu, expected_tri).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test identity") {
|
||||
auto id_4 = identity(4);
|
||||
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
||||
|
Reference in New Issue
Block a user