add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger
2024-05-22 18:50:27 -04:00
committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
6 changed files with 161 additions and 0 deletions

View File

@@ -4113,6 +4113,62 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
}
}
array trace(
const array& a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s /* = {} */) {
int ndim = a.ndim();
if (ndim < 2) {
std::ostringstream msg;
msg << "[trace] Array must have at least two dimensions, but got " << ndim
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
if (ax1 < 0 || ax1 >= ndim) {
std::ostringstream msg;
msg << "[trace] Invalid axis1 " << axis1 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
if (ax2 < 0 || ax2 >= ndim) {
std::ostringstream msg;
msg << "[trace] Invalid axis2 " << axis2 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
if (ax1 == ax2) {
throw std::invalid_argument(
"[trace] axis1 and axis2 cannot be the same axis");
}
return sum(
astype(diagonal(a, offset, axis1, axis2, s), dtype, s),
/* axis = */ -1,
/* keepdims = */ false,
s);
}
array trace(
const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s /* = {} */) {
auto dtype = a.dtype();
return trace(a, offset, axis1, axis2, dtype, s);
}
array trace(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = a.dtype();
return trace(a, 0, 0, 1, dtype, s);
}
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies) {

View File

@@ -1228,6 +1228,22 @@ array diagonal(
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
/** Return the sum along a specified diagonal in the given array. */
array trace(
const array& a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s = {});
array trace(
const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s = {});
array trace(const array& a, StreamOrDevice s = {});
/**
* Implements the identity function but allows injecting dependencies to other
* arrays. This ensures that these other arrays will have been computed