mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a @ b
|
||||
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
|
||||
|
||||
def test_diagonal(self):
|
||||
x = mx.array(
|
||||
[
|
||||
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],
|
||||
]
|
||||
)
|
||||
expected = [[0, 13], [4, 17], [8, 21]]
|
||||
|
||||
self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)
|
||||
|
||||
expected = [[1, 14], [5, 18], [9, 22]]
|
||||
self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)
|
||||
|
||||
def test_diag(self):
|
||||
# Test 1D input
|
||||
x = mx.array([1, 2, 3, 4])
|
||||
expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])
|
||||
result = mx.diag(x)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test 1D with offset
|
||||
x = mx.array([2, 6])
|
||||
result = mx.diag(x, k=5)
|
||||
expected = mx.array(np.diag(x, k=5))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test 2D input
|
||||
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
expected = mx.array([1, 5, 9])
|
||||
result = mx.diag(x)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test with offset
|
||||
expected = mx.array([2, 6])
|
||||
result = mx.diag(x, 1)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test non-square
|
||||
x = mx.array([[1, 2, 3], [4, 5, 6]])
|
||||
result = mx.diag(x)
|
||||
expected = mx.array(np.diag(x))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=10)
|
||||
expected = mx.array(np.diag(x, k=10))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=-10)
|
||||
expected = mx.array(np.diag(x, k=-10))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=-1)
|
||||
expected = mx.array(np.diag(x, k=-1))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user