Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -1486,5 +1486,26 @@ void init_array(py::module_& m) {
"decimals"_a = 0,
py::kw_only(),
"stream"_a = none,
"See :func:`round`.");
"See :func:`round`.")
.def(
"diagonal",
[](const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"stream"_a = none,
"See :func:`diagonal`.")
.def(
"diag",
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
}

View File

@@ -3577,4 +3577,61 @@ void init_ops(py::module_& m) {
Returns:
array: ``alpha * (a @ b) + beta * c``
)pbdoc");
m.def(
"diagonal",
&diagonal,
"a"_a,
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"stream"_a = none,
R"pbdoc(
diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array
Return specified diagonals.
If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
``offset`` is returned.
If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
determine the 2D subarrays from which diagonals are extracted. The new
shape is the original shape with ``axis1`` and ``axis2`` removed and a
new dimension inserted at the end corresponding to the diagonal.
Args:
a (array): Input array
offset (int, optional): Offset of the diagonal from the main diagonal.
Can be positive or negative. Default: ``0``.
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``0``.
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``1``.
Returns:
array: The diagonals of the array.
)pbdoc");
m.def(
"diag",
&diag,
"a"_a,
py::pos_only(),
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Extract a diagonal or construct a diagonal matrix.
If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
:math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
returned.
Args:
a (array): 1-D or 2-D input array.
k (int, optional): The diagonal to extract or construct.
Default: ``0``.
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
}

View File

@@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase):
out = a @ b
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
def test_diagonal(self):
x = mx.array(
[
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],
]
)
expected = [[0, 13], [4, 17], [8, 21]]
self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)
expected = [[1, 14], [5, 18], [9, 22]]
self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)
def test_diag(self):
# Test 1D input
x = mx.array([1, 2, 3, 4])
expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])
result = mx.diag(x)
self.assertTrue(mx.array_equal(result, expected))
# Test 1D with offset
x = mx.array([2, 6])
result = mx.diag(x, k=5)
expected = mx.array(np.diag(x, k=5))
self.assertTrue(mx.array_equal(result, expected))
# Test 2D input
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
expected = mx.array([1, 5, 9])
result = mx.diag(x)
self.assertTrue(mx.array_equal(result, expected))
# Test with offset
expected = mx.array([2, 6])
result = mx.diag(x, 1)
self.assertTrue(mx.array_equal(result, expected))
# Test non-square
x = mx.array([[1, 2, 3], [4, 5, 6]])
result = mx.diag(x)
expected = mx.array(np.diag(x))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=10)
expected = mx.array(np.diag(x, k=10))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=-10)
expected = mx.array(np.diag(x, k=-10))
self.assertTrue(mx.array_equal(result, expected))
result = mx.diag(x, k=-1)
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
if __name__ == "__main__":
unittest.main()