add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger
2024-05-22 18:50:27 -04:00
committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
6 changed files with 161 additions and 0 deletions

View File

@@ -1228,6 +1228,22 @@ array diagonal(
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
/** Return the sum along a specified diagonal in the given array. */
array trace(
const array& a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s = {});
array trace(
const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s = {});
array trace(const array& a, StreamOrDevice s = {});
/**
* Implements the identity function but allows injecting dependencies to other
* arrays. This ensures that these other arrays will have been computed