mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1486,5 +1486,26 @@ void init_array(py::module_& m) {
|
||||
"decimals"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
"See :func:`round`.");
|
||||
"See :func:`round`.")
|
||||
.def(
|
||||
"diagonal",
|
||||
[](const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
|
||||
"offset"_a = 0,
|
||||
"axis1"_a = 0,
|
||||
"axis2"_a = 1,
|
||||
"stream"_a = none,
|
||||
"See :func:`diagonal`.")
|
||||
.def(
|
||||
"diag",
|
||||
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
|
||||
"k"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Extract a diagonal or construct a diagonal matrix.
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user