mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1486,5 +1486,26 @@ void init_array(py::module_& m) {
|
||||
"decimals"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`round`.");
|
||||
"See :func:`round`.")
|
||||
.def(
|
||||
"diagonal",
|
||||
[](const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
|
||||
"offset"_a = 0,
|
||||
"axis1"_a = 0,
|
||||
"axis2"_a = 1,
|
||||
"stream"_a = none,
|
||||
"See :func:`diagonal`.")
|
||||
.def(
|
||||
"diag",
|
||||
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
|
||||
"k"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Extract a diagonal or construct a diagonal matrix.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -3577,4 +3577,61 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: ``alpha * (a @ b) + beta * c``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diagonal",
|
||||
&diagonal,
|
||||
"a"_a,
|
||||
"offset"_a = 0,
|
||||
"axis1"_a = 0,
|
||||
"axis2"_a = 1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return specified diagonals.
|
||||
|
||||
If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
|
||||
``offset`` is returned.
|
||||
|
||||
If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
|
||||
determine the 2D subarrays from which diagonals are extracted. The new
|
||||
shape is the original shape with ``axis1`` and ``axis2`` removed and a
|
||||
new dimension inserted at the end corresponding to the diagonal.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
offset (int, optional): Offset of the diagonal from the main diagonal.
|
||||
Can be positive or negative. Default: ``0``.
|
||||
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``0``.
|
||||
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The diagonals of the array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diag",
|
||||
&diag,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"k"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Extract a diagonal or construct a diagonal matrix.
|
||||
If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
|
||||
:math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
|
||||
returned.
|
||||
|
||||
Args:
|
||||
a (array): 1-D or 2-D input array.
|
||||
k (int, optional): The diagonal to extract or construct.
|
||||
Default: ``0``.
|
||||
|
||||
Returns:
|
||||
array: The extracted diagonal or the constructed diagonal matrix.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a @ b
|
||||
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
|
||||
|
||||
def test_diagonal(self):
|
||||
x = mx.array(
|
||||
[
|
||||
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],
|
||||
]
|
||||
)
|
||||
expected = [[0, 13], [4, 17], [8, 21]]
|
||||
|
||||
self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)
|
||||
|
||||
expected = [[1, 14], [5, 18], [9, 22]]
|
||||
self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)
|
||||
|
||||
def test_diag(self):
|
||||
# Test 1D input
|
||||
x = mx.array([1, 2, 3, 4])
|
||||
expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])
|
||||
result = mx.diag(x)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test 1D with offset
|
||||
x = mx.array([2, 6])
|
||||
result = mx.diag(x, k=5)
|
||||
expected = mx.array(np.diag(x, k=5))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test 2D input
|
||||
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
expected = mx.array([1, 5, 9])
|
||||
result = mx.diag(x)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test with offset
|
||||
expected = mx.array([2, 6])
|
||||
result = mx.diag(x, 1)
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
# Test non-square
|
||||
x = mx.array([[1, 2, 3], [4, 5, 6]])
|
||||
result = mx.diag(x)
|
||||
expected = mx.array(np.diag(x))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=10)
|
||||
expected = mx.array(np.diag(x, k=10))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=-10)
|
||||
expected = mx.array(np.diag(x, k=-10))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
result = mx.diag(x, k=-1)
|
||||
expected = mx.array(np.diag(x, k=-1))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user