Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -1486,5 +1486,26 @@ void init_array(py::module_& m) {
"decimals"_a = 0,
py::kw_only(),
"stream"_a = none,
"See :func:`round`.");
"See :func:`round`.")
.def(
"diagonal",
[](const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"stream"_a = none,
"See :func:`diagonal`.")
.def(
"diag",
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
}

View File

@@ -3577,4 +3577,61 @@ void init_ops(py::module_& m) {
Returns:
array: ``alpha * (a @ b) + beta * c``
)pbdoc");
m.def(
"diagonal",
&diagonal,
"a"_a,
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"stream"_a = none,
R"pbdoc(
diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array
Return specified diagonals.
If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
``offset`` is returned.
If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
determine the 2D subarrays from which diagonals are extracted. The new
shape is the original shape with ``axis1`` and ``axis2`` removed and a
new dimension inserted at the end corresponding to the diagonal.
Args:
a (array): Input array
offset (int, optional): Offset of the diagonal from the main diagonal.
Can be positive or negative. Default: ``0``.
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``0``.
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``1``.
Returns:
array: The diagonals of the array.
)pbdoc");
m.def(
"diag",
&diag,
"a"_a,
py::pos_only(),
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Extract a diagonal or construct a diagonal matrix.
If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
:math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
returned.
Args:
a (array): 1-D or 2-D input array.
k (int, optional): The diagonal to extract or construct.
Default: ``0``.
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
}