mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
65d0b8df9f
commit
3f7aba8498
@ -35,6 +35,8 @@ Operations
|
|||||||
cos
|
cos
|
||||||
cosh
|
cosh
|
||||||
dequantize
|
dequantize
|
||||||
|
diag
|
||||||
|
diagonal
|
||||||
divide
|
divide
|
||||||
divmod
|
divmod
|
||||||
equal
|
equal
|
||||||
|
@ -395,7 +395,7 @@ class array {
|
|||||||
// The ArrayDesc contains the details of the materialized array including the
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
// shape, strides, the data type. It also includes
|
// shape, strides, the data type. It also includes
|
||||||
// the primitive which knows how to compute the array's data from its inputs
|
// the primitive which knows how to compute the array's data from its inputs
|
||||||
// and a the list of array's inputs for the primitive.
|
// and the list of array's inputs for the primitive.
|
||||||
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
78
mlx/ops.cpp
78
mlx/ops.cpp
@ -227,7 +227,7 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
if (n <= 0 || m <= 0) {
|
if (n <= 0 || m <= 0) {
|
||||||
throw std::invalid_argument("N and M must be positive integers.");
|
throw std::invalid_argument("[eye] N and M must be positive integers.");
|
||||||
}
|
}
|
||||||
array result = zeros({n, m}, dtype, s);
|
array result = zeros({n, m}, dtype, s);
|
||||||
if (k >= m || -k >= n) {
|
if (k >= m || -k >= n) {
|
||||||
@ -3251,4 +3251,80 @@ array addmm(
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array diagonal(
|
||||||
|
const array& a,
|
||||||
|
int offset /* = 0 */,
|
||||||
|
int axis1 /* = 0 */,
|
||||||
|
int axis2 /* = 1 */,
|
||||||
|
StreamOrDevice s /* = {} */
|
||||||
|
) {
|
||||||
|
int ndim = a.ndim();
|
||||||
|
if (ndim < 2) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[diagonal] Array must have at least two dimensions, but got "
|
||||||
|
<< ndim << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
|
||||||
|
if (ax1 < 0 || ax1 >= ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[diagonal] Invalid axis1 " << axis1 << " for array with " << ndim
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::out_of_range(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
|
||||||
|
if (ax2 < 0 || ax2 >= ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[diagonal] Invalid axis2 " << axis2 << " for array with " << ndim
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::out_of_range(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ax1 == ax2) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[diagonal] axis1 and axis2 cannot be the same axis");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto off1 = std::max(-offset, 0);
|
||||||
|
auto off2 = std::max(offset, 0);
|
||||||
|
|
||||||
|
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
|
||||||
|
diag_size = std::max(diag_size, 0);
|
||||||
|
|
||||||
|
std::vector<array> indices = {
|
||||||
|
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
|
||||||
|
|
||||||
|
std::vector<int> slice_sizes = a.shape();
|
||||||
|
slice_sizes[ax1] = 1;
|
||||||
|
slice_sizes[ax2] = 1;
|
||||||
|
|
||||||
|
auto out = gather(a, indices, {ax1, ax2}, slice_sizes, s);
|
||||||
|
return moveaxis(squeeze(out, {ax1 + 1, ax2 + 1}, s), 0, -1, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
|
||||||
|
if (a.ndim() == 1) {
|
||||||
|
int a_size = a.size();
|
||||||
|
int n = a_size + std::abs(k);
|
||||||
|
auto res = zeros({n, n}, a.dtype(), s);
|
||||||
|
|
||||||
|
std::vector<array> indices;
|
||||||
|
auto s1 = std::max(0, -k);
|
||||||
|
auto s2 = std::max(0, k);
|
||||||
|
indices.push_back(arange(s1, a_size + s1, uint32, s));
|
||||||
|
indices.push_back(arange(s2, a_size + s2, uint32, s));
|
||||||
|
|
||||||
|
return scatter(res, indices, reshape(a, {a_size, 1, 1}, s), {0, 1}, s);
|
||||||
|
} else if (a.ndim() == 2) {
|
||||||
|
return diagonal(a, k, 0, 1, s);
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[diag] array must be 1-D or 2-D, got array with " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
11
mlx/ops.h
11
mlx/ops.h
@ -1105,4 +1105,15 @@ array addmm(
|
|||||||
const float& beta = 1.f,
|
const float& beta = 1.f,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Extract a diagonal or construct a diagonal array */
|
||||||
|
array diagonal(
|
||||||
|
const array& a,
|
||||||
|
int offset = 0,
|
||||||
|
int axis1 = 0,
|
||||||
|
int axis2 = 1,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Extract diagonal from a 2d array or create a diagonal matrix. */
|
||||||
|
array diag(const array& a, int k = 0, StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1486,5 +1486,26 @@ void init_array(py::module_& m) {
|
|||||||
"decimals"_a = 0,
|
"decimals"_a = 0,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
"See :func:`round`.");
|
"See :func:`round`.")
|
||||||
|
.def(
|
||||||
|
"diagonal",
|
||||||
|
[](const array& a,
|
||||||
|
int offset,
|
||||||
|
int axis1,
|
||||||
|
int axis2,
|
||||||
|
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
|
||||||
|
"offset"_a = 0,
|
||||||
|
"axis1"_a = 0,
|
||||||
|
"axis2"_a = 1,
|
||||||
|
"stream"_a = none,
|
||||||
|
"See :func:`diagonal`.")
|
||||||
|
.def(
|
||||||
|
"diag",
|
||||||
|
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
|
||||||
|
"k"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
Extract a diagonal or construct a diagonal matrix.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -3577,4 +3577,61 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: ``alpha * (a @ b) + beta * c``
|
array: ``alpha * (a @ b) + beta * c``
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"diagonal",
|
||||||
|
&diagonal,
|
||||||
|
"a"_a,
|
||||||
|
"offset"_a = 0,
|
||||||
|
"axis1"_a = 0,
|
||||||
|
"axis2"_a = 1,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Return specified diagonals.
|
||||||
|
|
||||||
|
If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
|
||||||
|
``offset`` is returned.
|
||||||
|
|
||||||
|
If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
|
||||||
|
determine the 2D subarrays from which diagonals are extracted. The new
|
||||||
|
shape is the original shape with ``axis1`` and ``axis2`` removed and a
|
||||||
|
new dimension inserted at the end corresponding to the diagonal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
offset (int, optional): Offset of the diagonal from the main diagonal.
|
||||||
|
Can be positive or negative. Default: ``0``.
|
||||||
|
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
||||||
|
the diagonals should be taken. Default: ``0``.
|
||||||
|
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
||||||
|
the diagonals should be taken. Default: ``1``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The diagonals of the array.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"diag",
|
||||||
|
&diag,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"k"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Extract a diagonal or construct a diagonal matrix.
|
||||||
|
If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
|
||||||
|
:math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
|
||||||
|
returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): 1-D or 2-D input array.
|
||||||
|
k (int, optional): The diagonal to extract or construct.
|
||||||
|
Default: ``0``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The extracted diagonal or the constructed diagonal matrix.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = a @ b
|
out = a @ b
|
||||||
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
|
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
|
||||||
|
|
||||||
|
def test_diagonal(self):
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
|
||||||
|
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected = [[0, 13], [4, 17], [8, 21]]
|
||||||
|
|
||||||
|
self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)
|
||||||
|
|
||||||
|
expected = [[1, 14], [5, 18], [9, 22]]
|
||||||
|
self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)
|
||||||
|
|
||||||
|
def test_diag(self):
|
||||||
|
# Test 1D input
|
||||||
|
x = mx.array([1, 2, 3, 4])
|
||||||
|
expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])
|
||||||
|
result = mx.diag(x)
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
# Test 1D with offset
|
||||||
|
x = mx.array([2, 6])
|
||||||
|
result = mx.diag(x, k=5)
|
||||||
|
expected = mx.array(np.diag(x, k=5))
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
# Test 2D input
|
||||||
|
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
expected = mx.array([1, 5, 9])
|
||||||
|
result = mx.diag(x)
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
# Test with offset
|
||||||
|
expected = mx.array([2, 6])
|
||||||
|
result = mx.diag(x, 1)
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
# Test non-square
|
||||||
|
x = mx.array([[1, 2, 3], [4, 5, 6]])
|
||||||
|
result = mx.diag(x)
|
||||||
|
expected = mx.array(np.diag(x))
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
result = mx.diag(x, k=10)
|
||||||
|
expected = mx.array(np.diag(x, k=10))
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
result = mx.diag(x, k=-10)
|
||||||
|
expected = mx.array(np.diag(x, k=-10))
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
result = mx.diag(x, k=-1)
|
||||||
|
expected = mx.array(np.diag(x, k=-1))
|
||||||
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream> // TODO
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
@ -2634,3 +2633,86 @@ TEST_CASE("test divmod") {
|
|||||||
eval(out_holder);
|
eval(out_holder);
|
||||||
CHECK_EQ(out_holder[0].item<float>(), 1.0);
|
CHECK_EQ(out_holder[0].item<float>(), 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test diagonal") {
|
||||||
|
auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});
|
||||||
|
auto out = diagonal(x);
|
||||||
|
CHECK(array_equal(out, array({0, 3}, {2})).item<bool>());
|
||||||
|
|
||||||
|
CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range);
|
||||||
|
CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range);
|
||||||
|
|
||||||
|
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4});
|
||||||
|
out = diagonal(x, 2, 1, 0);
|
||||||
|
CHECK(array_equal(out, array({8}, {1})).item<bool>());
|
||||||
|
|
||||||
|
out = diagonal(x, -1, 0, 1);
|
||||||
|
CHECK(array_equal(out, array({4, 9}, {2})).item<bool>());
|
||||||
|
|
||||||
|
out = diagonal(x, -5, 0, 1);
|
||||||
|
eval(out);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{0});
|
||||||
|
|
||||||
|
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
|
||||||
|
out = diagonal(x, 1, 0, 1);
|
||||||
|
CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>());
|
||||||
|
|
||||||
|
out = diagonal(x, 0, 2, 0);
|
||||||
|
CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>());
|
||||||
|
|
||||||
|
out = diagonal(x, 1, -1, 0);
|
||||||
|
CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>());
|
||||||
|
|
||||||
|
x = reshape(arange(16), {2, 2, 2, 2});
|
||||||
|
out = diagonal(x, 0, 0, 1);
|
||||||
|
CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument);
|
||||||
|
|
||||||
|
x = array({0, 1}, {2});
|
||||||
|
CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test diag") {
|
||||||
|
// To few or too many dimensions
|
||||||
|
CHECK_THROWS(diag(array(0.0)));
|
||||||
|
CHECK_THROWS(diag(array({0.0}, {1, 1, 1})));
|
||||||
|
|
||||||
|
// Test with 1D array
|
||||||
|
auto x = array({0, 1, 2, 3}, {4});
|
||||||
|
auto out = diag(x, 0);
|
||||||
|
CHECK(
|
||||||
|
array_equal(
|
||||||
|
out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
out = diag(x, 1);
|
||||||
|
CHECK(array_equal(
|
||||||
|
out,
|
||||||
|
array(
|
||||||
|
{0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||||
|
2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0},
|
||||||
|
{5, 5}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
out = diag(x, -1);
|
||||||
|
CHECK(array_equal(
|
||||||
|
out,
|
||||||
|
array(
|
||||||
|
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
|
||||||
|
0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0},
|
||||||
|
{5, 5}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
// Test with 2D array
|
||||||
|
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3});
|
||||||
|
out = diag(x, 0);
|
||||||
|
CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>());
|
||||||
|
|
||||||
|
out = diag(x, 1);
|
||||||
|
CHECK(array_equal(out, array({1, 5}, {2})).item<bool>());
|
||||||
|
|
||||||
|
out = diag(x, -1);
|
||||||
|
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user