added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -2031,6 +2031,78 @@ TEST_CASE("test eye") {
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
}
TEST_CASE("test tri") {
auto _tri = tri(4, 4, 0, float32);
CHECK_EQ(_tri.shape(), std::vector<int>{4, 4});
auto expected_tri = array(
{1.0f,
0.0f,
0.0f,
0.0f,
1.0f,
1.0f,
0.0f,
0.0f,
1.0f,
1.0f,
1.0f,
0.0f,
1.0f,
1.0f,
1.0f,
1.0f},
{4, 4});
CHECK(array_equal(_tri, expected_tri).item<bool>());
}
TEST_CASE("test tril") {
auto _tril = tril(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_tril.shape(), std::vector<int>{4, 4});
auto expected_tri = array(
{2.0f,
0.0f,
0.0f,
0.0f,
2.0f,
2.0f,
0.0f,
0.0f,
2.0f,
2.0f,
2.0f,
0.0f,
2.0f,
2.0f,
2.0f,
2.0f},
{4, 4});
CHECK(array_equal(_tril, expected_tri).item<bool>());
}
TEST_CASE("test triu") {
auto _triu = triu(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_triu.shape(), std::vector<int>{4, 4});
auto expected_tri = array(
{2.0f,
2.0f,
2.0f,
2.0f,
0.0f,
2.0f,
2.0f,
2.0f,
0.0f,
0.0f,
2.0f,
2.0f,
0.0f,
0.0f,
0.0f,
2.0f},
{4, 4});
CHECK(array_equal(_triu, expected_tri).item<bool>());
}
TEST_CASE("test identity") {
auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});