mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
24
mlx/ops.cpp
24
mlx/ops.cpp
@@ -194,6 +194,30 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return ones(a.shape(), a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (n <= 0 || m <= 0) {
|
||||
throw std::invalid_argument("N and M must be positive integers.");
|
||||
}
|
||||
array result = zeros({n * m}, dtype, s);
|
||||
if (k >= m || -k >= n) {
|
||||
return reshape(result, {n, m}, s);
|
||||
}
|
||||
|
||||
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
|
||||
int start_index = (k >= 0) ? k : -k * m;
|
||||
|
||||
array diag_indices_array = arange(
|
||||
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
|
||||
array ones_array = ones({diagonal_length, 1}, dtype, s);
|
||||
result = scatter(result, diag_indices_array, ones_array, 0, s);
|
||||
|
||||
return reshape(result, {n, m}, s);
|
||||
}
|
||||
|
||||
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return eye(n, n, 0, dtype, s);
|
||||
}
|
||||
|
||||
array reshape(
|
||||
const array& a,
|
||||
std::vector<int> shape,
|
||||
|
23
mlx/ops.h
23
mlx/ops.h
@@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
}
|
||||
array ones_like(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
|
||||
* k, and zeros everywhere else. */
|
||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return eye(n, n, 0, dtype, s);
|
||||
}
|
||||
inline array eye(int n, int m, StreamOrDevice s = {}) {
|
||||
return eye(n, m, 0, float32, s);
|
||||
}
|
||||
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
|
||||
return eye(n, m, k, float32, s);
|
||||
}
|
||||
inline array eye(int n, StreamOrDevice s = {}) {
|
||||
return eye(n, n, 0, float32, s);
|
||||
}
|
||||
|
||||
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
|
||||
* diagonal. */
|
||||
array identity(int n, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array identity(int n, StreamOrDevice s = {}) {
|
||||
return identity(n, float32, s);
|
||||
}
|
||||
|
||||
/** array manipulation */
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
|
Reference in New Issue
Block a user