Added eye/identity ops (#119)

`eye` and `identity` C++ and Python ops
This commit is contained in:
Cyril Zakka, MD
2023-12-11 12:38:17 -08:00
committed by GitHub
parent 69505b4e9b
commit e080290ba4
6 changed files with 175 additions and 0 deletions

View File

@@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self):
eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test for non-square matrix
eye_matrix = mx.eye(3, 4)
np_eye_matrix = np.eye(3, 4)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with positive k parameter
eye_matrix = mx.eye(3, 4, k=1)
np_eye_matrix = np.eye(3, 4, k=1)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with negative k parameter
eye_matrix = mx.eye(5, 6, k=-2)
np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
if __name__ == "__main__":
unittest.main()