add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger
2024-05-22 18:50:27 -04:00
committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
6 changed files with 161 additions and 0 deletions

View File

@@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
m.def(
"trace",
[](const array& a,
int offset,
int axis1,
int axis2,
std::optional<Dtype> dtype,
StreamOrDevice s) {
if (!dtype.has_value()) {
return trace(a, offset, axis1, axis2, s);
}
return trace(a, offset, axis1, axis2, dtype.value(), s);
},
nb::arg(),
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the sum along a specified diagonal in the given array.
Args:
a (array): Input array
offset (int, optional): Offset of the diagonal from the main diagonal.
Can be positive or negative. Default: ``0``.
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``0``.
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``1``.
dtype (Dtype, optional): Data type of the output array. If
unspecified the output type is inferred from the input array.
Returns:
array: Sum of specified diagonal.
)pbdoc");
m.def(
"atleast_1d",
[](const nb::args& arys, StreamOrDevice s) -> nb::object {