mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	| @@ -1926,3 +1926,58 @@ TEST_CASE("test where") { | ||||
|   expected = array({1, 2, 2, 1}, {2, 2}); | ||||
|   CHECK(array_equal(where(condition, x, y), expected).item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test eye") { | ||||
|   auto eye_3 = eye(3); | ||||
|   CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3}); | ||||
|   auto expected_eye_3 = | ||||
|       array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3}); | ||||
|   CHECK(array_equal(eye_3, expected_eye_3).item<bool>()); | ||||
|  | ||||
|   auto eye_3x2 = eye(3, 2); | ||||
|   CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2}); | ||||
|   auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2}); | ||||
|   CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test identity") { | ||||
|   auto id_4 = identity(4); | ||||
|   CHECK_EQ(id_4.shape(), std::vector<int>{4, 4}); | ||||
|   auto expected_id_4 = array( | ||||
|       {1.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        1.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        1.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        0.0f, | ||||
|        1.0f}, | ||||
|       {4, 4}); | ||||
|   CHECK(array_equal(id_4, expected_id_4).item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test eye with positive k offset") { | ||||
|   auto eye_3_k1 = eye(3, 4, 1); | ||||
|   CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4}); | ||||
|   auto expected_eye_3_k1 = array( | ||||
|       {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}, | ||||
|       {3, 4}); | ||||
|   CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test eye with negative k offset") { | ||||
|   auto eye_4_k_minus1 = eye(4, 3, -1); | ||||
|   CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3}); | ||||
|   auto expected_eye_4_k_minus1 = array( | ||||
|       {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, | ||||
|       {4, 3}); | ||||
|   CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>()); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Cyril Zakka, MD
					Cyril Zakka, MD