mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
@@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
|
||||
expected = array({1, 2, 2, 1}, {2, 2});
|
||||
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye") {
|
||||
auto eye_3 = eye(3);
|
||||
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
||||
auto expected_eye_3 =
|
||||
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
|
||||
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
|
||||
|
||||
auto eye_3x2 = eye(3, 2);
|
||||
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
|
||||
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
|
||||
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test identity") {
|
||||
auto id_4 = identity(4);
|
||||
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
||||
auto expected_id_4 = array(
|
||||
{1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f},
|
||||
{4, 4});
|
||||
CHECK(array_equal(id_4, expected_id_4).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye with positive k offset") {
|
||||
auto eye_3_k1 = eye(3, 4, 1);
|
||||
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
|
||||
auto expected_eye_3_k1 = array(
|
||||
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{3, 4});
|
||||
CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye with negative k offset") {
|
||||
auto eye_4_k_minus1 = eye(4, 3, -1);
|
||||
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
|
||||
auto expected_eye_4_k_minus1 = array(
|
||||
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
Reference in New Issue
Block a user