Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase):
out = a @ b
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
def test_diagonal(self):
x = mx.array(
[
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],
]
)
expected = [[0, 13], [4, 17], [8, 21]]
self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)
expected = [[1, 14], [5, 18], [9, 22]]
self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)
def test_diag(self):
# Test 1D input
x = mx.array([1, 2, 3, 4])
expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])
result = mx.diag(x)
self.assertTrue(mx.array_equal(result, expected))
# Test 1D with offset
x = mx.array([2, 6])
result = mx.diag(x, k=5)
expected = mx.array(np.diag(x, k=5))
self.assertTrue(mx.array_equal(result, expected))
# Test 2D input
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
expected = mx.array([1, 5, 9])
result = mx.diag(x)
self.assertTrue(mx.array_equal(result, expected))
# Test with offset
expected = mx.array([2, 6])
result = mx.diag(x, 1)
self.assertTrue(mx.array_equal(result, expected))
# Test non-square
x = mx.array([[1, 2, 3], [4, 5, 6]])
result = mx.diag(x)
expected = mx.array(np.diag(x))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=10)
expected = mx.array(np.diag(x, k=10))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=-10)
expected = mx.array(np.diag(x, k=-10))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=-1)
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
if __name__ == "__main__":
unittest.main()