Added eye/identity ops (#119)

`eye` and `identity` C++ and Python ops
This commit is contained in:
Cyril Zakka, MD
2023-12-11 12:38:17 -08:00
committed by GitHub
parent 69505b4e9b
commit e080290ba4
6 changed files with 175 additions and 0 deletions

View File

@@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
expected = array({1, 2, 2, 1}, {2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
}
TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
auto expected_eye_3 =
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
auto eye_3x2 = eye(3, 2);
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
}
TEST_CASE("test identity") {
auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
auto expected_id_4 = array(
{1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f},
{4, 4});
CHECK(array_equal(id_4, expected_id_4).item<bool>());
}
TEST_CASE("test eye with positive k offset") {
auto eye_3_k1 = eye(3, 4, 1);
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
auto expected_eye_3_k1 = array(
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{3, 4});
CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
}
TEST_CASE("test eye with negative k offset") {
auto eye_4_k_minus1 = eye(4, 3, -1);
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
auto expected_eye_4_k_minus1 = array(
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
}