mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
16
mlx/ops.h
16
mlx/ops.h
@@ -1228,6 +1228,22 @@ array diagonal(
|
||||
/** Extract diagonal from a 2d array or create a diagonal matrix. */
|
||||
array diag(const array& a, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** Return the sum along a specified diagonal in the given array. */
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s = {});
|
||||
array trace(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Implements the identity function but allows injecting dependencies to other
|
||||
* arrays. This ensures that these other arrays will have been computed
|
||||
|
||||
Reference in New Issue
Block a user