Added eye/identity ops (#119)

`eye` and `identity` C++ and Python ops
This commit is contained in:
Cyril Zakka, MD
2023-12-11 12:38:17 -08:00
committed by GitHub
parent 69505b4e9b
commit e080290ba4
6 changed files with 175 additions and 0 deletions

View File

@@ -194,6 +194,30 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
return ones(a.shape(), a.dtype(), to_stream(s));
}
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
if (n <= 0 || m <= 0) {
throw std::invalid_argument("N and M must be positive integers.");
}
array result = zeros({n * m}, dtype, s);
if (k >= m || -k >= n) {
return reshape(result, {n, m}, s);
}
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
int start_index = (k >= 0) ? k : -k * m;
array diag_indices_array = arange(
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
array ones_array = ones({diagonal_length, 1}, dtype, s);
result = scatter(result, diag_indices_array, ones_array, 0, s);
return reshape(result, {n, m}, s);
}
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
return eye(n, n, 0, dtype, s);
}
array reshape(
const array& a,
std::vector<int> shape,

View File

@@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
}
array ones_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
* k, and zeros everywhere else. */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
return eye(n, n, 0, dtype, s);
}
inline array eye(int n, int m, StreamOrDevice s = {}) {
return eye(n, m, 0, float32, s);
}
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
return eye(n, m, k, float32, s);
}
inline array eye(int n, StreamOrDevice s = {}) {
return eye(n, n, 0, float32, s);
}
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
* diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
inline array identity(int n, StreamOrDevice s = {}) {
return identity(n, float32, s);
}
/** array manipulation */
/** Reshape an array to the given shape. */