Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -395,7 +395,7 @@ class array {
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and a the list of array's inputs for the primitive.
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};

View File

@@ -227,7 +227,7 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
if (n <= 0 || m <= 0) {
throw std::invalid_argument("N and M must be positive integers.");
throw std::invalid_argument("[eye] N and M must be positive integers.");
}
array result = zeros({n, m}, dtype, s);
if (k >= m || -k >= n) {
@@ -3251,4 +3251,80 @@ array addmm(
return out;
}
array diagonal(
const array& a,
int offset /* = 0 */,
int axis1 /* = 0 */,
int axis2 /* = 1 */,
StreamOrDevice s /* = {} */
) {
int ndim = a.ndim();
if (ndim < 2) {
std::ostringstream msg;
msg << "[diagonal] Array must have at least two dimensions, but got "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
if (ax1 < 0 || ax1 >= ndim) {
std::ostringstream msg;
msg << "[diagonal] Invalid axis1 " << axis1 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
if (ax2 < 0 || ax2 >= ndim) {
std::ostringstream msg;
msg << "[diagonal] Invalid axis2 " << axis2 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
if (ax1 == ax2) {
throw std::invalid_argument(
"[diagonal] axis1 and axis2 cannot be the same axis");
}
auto off1 = std::max(-offset, 0);
auto off2 = std::max(offset, 0);
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
diag_size = std::max(diag_size, 0);
std::vector<array> indices = {
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
std::vector<int> slice_sizes = a.shape();
slice_sizes[ax1] = 1;
slice_sizes[ax2] = 1;
auto out = gather(a, indices, {ax1, ax2}, slice_sizes, s);
return moveaxis(squeeze(out, {ax1 + 1, ax2 + 1}, s), 0, -1, s);
}
array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
if (a.ndim() == 1) {
int a_size = a.size();
int n = a_size + std::abs(k);
auto res = zeros({n, n}, a.dtype(), s);
std::vector<array> indices;
auto s1 = std::max(0, -k);
auto s2 = std::max(0, k);
indices.push_back(arange(s1, a_size + s1, uint32, s));
indices.push_back(arange(s2, a_size + s2, uint32, s));
return scatter(res, indices, reshape(a, {a_size, 1, 1}, s), {0, 1}, s);
} else if (a.ndim() == 2) {
return diagonal(a, k, 0, 1, s);
} else {
std::ostringstream msg;
msg << "[diag] array must be 1-D or 2-D, got array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
}
} // namespace mlx::core

View File

@@ -1105,4 +1105,15 @@ array addmm(
const float& beta = 1.f,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,
int offset = 0,
int axis1 = 0,
int axis2 = 1,
StreamOrDevice s = {});
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
} // namespace mlx::core