mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-04 10:38:10 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
@@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The extracted diagonal or the constructed diagonal matrix.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"trace",
|
||||
[](const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
std::optional<Dtype> dtype,
|
||||
StreamOrDevice s) {
|
||||
if (!dtype.has_value()) {
|
||||
return trace(a, offset, axis1, axis2, s);
|
||||
}
|
||||
return trace(a, offset, axis1, axis2, dtype.value(), s);
|
||||
},
|
||||
nb::arg(),
|
||||
"offset"_a = 0,
|
||||
"axis1"_a = 0,
|
||||
"axis2"_a = 1,
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the sum along a specified diagonal in the given array.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
offset (int, optional): Offset of the diagonal from the main diagonal.
|
||||
Can be positive or negative. Default: ``0``.
|
||||
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``0``.
|
||||
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``1``.
|
||||
dtype (Dtype, optional): Data type of the output array. If
|
||||
unspecified the output type is inferred from the input array.
|
||||
|
||||
Returns:
|
||||
array: Sum of specified diagonal.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"atleast_1d",
|
||||
[](const nb::args& arys, StreamOrDevice s) -> nb::object {
|
||||
|
||||
Reference in New Issue
Block a user