mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 00:54:37 +08:00
@@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||
|
||||
|
||||
def test_eye(self):
|
||||
eye_matrix = mx.eye(3)
|
||||
np_eye_matrix = np.eye(3)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test for non-square matrix
|
||||
eye_matrix = mx.eye(3, 4)
|
||||
np_eye_matrix = np.eye(3, 4)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with positive k parameter
|
||||
eye_matrix = mx.eye(3, 4, k=1)
|
||||
np_eye_matrix = np.eye(3, 4, k=1)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with negative k parameter
|
||||
eye_matrix = mx.eye(5, 6, k=-2)
|
||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user