mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
56
mlx/ops.cpp
56
mlx/ops.cpp
@@ -4113,6 +4113,62 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
}
|
||||
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
int ndim = a.ndim();
|
||||
if (ndim < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[trace] Array must have at least two dimensions, but got " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
|
||||
if (ax1 < 0 || ax1 >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[trace] Invalid axis1 " << axis1 << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
|
||||
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
|
||||
if (ax2 < 0 || ax2 >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[trace] Invalid axis2 " << axis2 << " for array with " << ndim
|
||||
<< " dimensions.";
|
||||
throw std::out_of_range(msg.str());
|
||||
}
|
||||
|
||||
if (ax1 == ax2) {
|
||||
throw std::invalid_argument(
|
||||
"[trace] axis1 and axis2 cannot be the same axis");
|
||||
}
|
||||
|
||||
return sum(
|
||||
astype(diagonal(a, offset, axis1, axis2, s), dtype, s),
|
||||
/* axis = */ -1,
|
||||
/* keepdims = */ false,
|
||||
s);
|
||||
}
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = a.dtype();
|
||||
return trace(a, offset, axis1, axis2, dtype, s);
|
||||
}
|
||||
array trace(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = a.dtype();
|
||||
return trace(a, 0, 0, 1, dtype, s);
|
||||
}
|
||||
|
||||
std::vector<array> depends(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& dependencies) {
|
||||
|
16
mlx/ops.h
16
mlx/ops.h
@@ -1228,6 +1228,22 @@ array diagonal(
|
||||
/** Extract diagonal from a 2d array or create a diagonal matrix. */
|
||||
array diag(const array& a, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** Return the sum along a specified diagonal in the given array. */
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s = {});
|
||||
array trace(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Implements the identity function but allows injecting dependencies to other
|
||||
* arrays. This ensures that these other arrays will have been computed
|
||||
|
Reference in New Issue
Block a user