From 3f7aba849819e07b7b39f0c61f15a3ce074ea79a Mon Sep 17 00:00:00 2001 From: Jacket <44538064+PRESIDENT810@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:45:48 -0600 Subject: [PATCH] Implement diagonal operator (#562) * Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad Co-authored-by: Awni Hannun --- docs/src/python/ops.rst | 2 + mlx/array.h | 2 +- mlx/ops.cpp | 78 ++++++++++++++++++++++++++++++++++++- mlx/ops.h | 11 ++++++ python/src/array.cpp | 23 ++++++++++- python/src/ops.cpp | 57 +++++++++++++++++++++++++++ python/tests/test_ops.py | 56 +++++++++++++++++++++++++++ tests/ops_tests.cpp | 84 +++++++++++++++++++++++++++++++++++++++- 8 files changed, 309 insertions(+), 4 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 649724a34..09e2d5f71 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -35,6 +35,8 @@ Operations cos cosh dequantize + diag + diagonal divide divmod equal diff --git a/mlx/array.h b/mlx/array.h index 3b70153e3..6e8375a71 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -395,7 +395,7 @@ class array { // The ArrayDesc contains the details of the materialized array including the // shape, strides, the data type. It also includes // the primitive which knows how to compute the array's data from its inputs - // and a the list of array's inputs for the primitive. + // and the list of array's inputs for the primitive. 
std::shared_ptr array_desc_{nullptr}; }; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d2768d32c..57ebda494 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -227,7 +227,7 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) { array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) { if (n <= 0 || m <= 0) { - throw std::invalid_argument("N and M must be positive integers."); + throw std::invalid_argument("[eye] N and M must be positive integers."); } array result = zeros({n, m}, dtype, s); if (k >= m || -k >= n) { @@ -3251,4 +3251,80 @@ array addmm( return out; } +array diagonal( + const array& a, + int offset /* = 0 */, + int axis1 /* = 0 */, + int axis2 /* = 1 */, + StreamOrDevice s /* = {} */ +) { + int ndim = a.ndim(); + if (ndim < 2) { + std::ostringstream msg; + msg << "[diagonal] Array must have at least two dimensions, but got " + << ndim << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1; + if (ax1 < 0 || ax1 >= ndim) { + std::ostringstream msg; + msg << "[diagonal] Invalid axis1 " << axis1 << " for array with " << ndim + << " dimensions."; + throw std::out_of_range(msg.str()); + } + + auto ax2 = (axis2 < 0) ? 
axis2 + ndim : axis2; + if (ax2 < 0 || ax2 >= ndim) { + std::ostringstream msg; + msg << "[diagonal] Invalid axis2 " << axis2 << " for array with " << ndim + << " dimensions."; + throw std::out_of_range(msg.str()); + } + + if (ax1 == ax2) { + throw std::invalid_argument( + "[diagonal] axis1 and axis2 cannot be the same axis"); + } + + auto off1 = std::max(-offset, 0); + auto off2 = std::max(offset, 0); + + auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2); + diag_size = std::max(diag_size, 0); + + std::vector indices = { + arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)}; + + std::vector slice_sizes = a.shape(); + slice_sizes[ax1] = 1; + slice_sizes[ax2] = 1; + + auto out = gather(a, indices, {ax1, ax2}, slice_sizes, s); + return moveaxis(squeeze(out, {ax1 + 1, ax2 + 1}, s), 0, -1, s); +} + +array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) { + if (a.ndim() == 1) { + int a_size = a.size(); + int n = a_size + std::abs(k); + auto res = zeros({n, n}, a.dtype(), s); + + std::vector indices; + auto s1 = std::max(0, -k); + auto s2 = std::max(0, k); + indices.push_back(arange(s1, a_size + s1, uint32, s)); + indices.push_back(arange(s2, a_size + s2, uint32, s)); + + return scatter(res, indices, reshape(a, {a_size, 1, 1}, s), {0, 1}, s); + } else if (a.ndim() == 2) { + return diagonal(a, k, 0, 1, s); + } else { + std::ostringstream msg; + msg << "[diag] array must be 1-D or 2-D, got array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 0d17f2d2c..506a29d84 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1105,4 +1105,15 @@ array addmm( const float& beta = 1.f, StreamOrDevice s = {}); +/** Extract the diagonals of an array along two given axes. */ +array diagonal( + const array& a, + int offset = 0, + int axis1 = 0, + int axis2 = 1, + StreamOrDevice s = {}); + +/** Extract diagonal from a 2d array or create a diagonal
matrix. */ +array diag(const array& a, int k = 0, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/src/array.cpp b/python/src/array.cpp index 9115ada6e..acb4f8edc 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1486,5 +1486,26 @@ void init_array(py::module_& m) { "decimals"_a = 0, py::kw_only(), "stream"_a = none, - "See :func:`round`."); + "See :func:`round`.") + .def( + "diagonal", + [](const array& a, + int offset, + int axis1, + int axis2, + StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); }, + "offset"_a = 0, + "axis1"_a = 0, + "axis2"_a = 1, + "stream"_a = none, + "See :func:`diagonal`.") + .def( + "diag", + [](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); }, + "k"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Extract a diagonal or construct a diagonal matrix. + )pbdoc"); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 10eeac27a..02a401543 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3577,4 +3577,61 @@ void init_ops(py::module_& m) { Returns: array: ``alpha * (a @ b) + beta * c`` )pbdoc"); + m.def( + "diagonal", + &diagonal, + "a"_a, + "offset"_a = 0, + "axis1"_a = 0, + "axis2"_a = 1, + "stream"_a = none, + R"pbdoc( + diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array + + Return specified diagonals. + + If ``a`` is 2-D, then a 1-D array containing the diagonal at the given + ``offset`` is returned. + + If ``a`` has more than two dimensions, then ``axis1`` and ``axis2`` + determine the 2D subarrays from which diagonals are extracted. The new + shape is the original shape with ``axis1`` and ``axis2`` removed and a + new dimension inserted at the end corresponding to the diagonal. + + Args: + a (array): Input array + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Default: ``0``. 
+ axis1 (int, optional): The first axis of the 2-D sub-arrays from which + the diagonals should be taken. Default: ``0``. + axis2 (int, optional): The second axis of the 2-D sub-arrays from which + the diagonals should be taken. Default: ``1``. + + Returns: + array: The diagonals of the array. + )pbdoc"); + m.def( + "diag", + &diag, + "a"_a, + py::pos_only(), + "k"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array + + Extract a diagonal or construct a diagonal matrix. + If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the + :math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is + returned. + + Args: + a (array): 1-D or 2-D input array. + k (int, optional): The diagonal to extract or construct. + Default: ``0``. + + Returns: + array: The extracted diagonal or the constructed diagonal matrix. + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0797c85eb..c82d9b5c5 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1785,6 +1785,62 @@ class TestOps(mlx_tests.MLXTestCase): out = a @ b self.assertTrue(mx.array_equal(out, mx.zeros((10, 10)))) + def test_diagonal(self): + x = mx.array( + [ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], + ] + ) + expected = [[0, 13], [4, 17], [8, 21]] + + self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected) + + expected = [[1, 14], [5, 18], [9, 22]] + self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected) + + def test_diag(self): + # Test 1D input + x = mx.array([1, 2, 3, 4]) + expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]) + result = mx.diag(x) + self.assertTrue(mx.array_equal(result, expected)) + + # Test 1D with offset + x = mx.array([2, 6]) + result = mx.diag(x, k=5) + expected = mx.array(np.diag(x, k=5)) + 
self.assertTrue(mx.array_equal(result, expected)) + + # Test 2D input + x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + expected = mx.array([1, 5, 9]) + result = mx.diag(x) + self.assertTrue(mx.array_equal(result, expected)) + + # Test with offset + expected = mx.array([2, 6]) + result = mx.diag(x, 1) + self.assertTrue(mx.array_equal(result, expected)) + + # Test non-square + x = mx.array([[1, 2, 3], [4, 5, 6]]) + result = mx.diag(x) + expected = mx.array(np.diag(x)) + self.assertTrue(mx.array_equal(result, expected)) + + result = mx.diag(x, k=10) + expected = mx.array(np.diag(x, k=10)) + self.assertTrue(mx.array_equal(result, expected)) + + result = mx.diag(x, k=-10) + expected = mx.array(np.diag(x, k=-10)) + self.assertTrue(mx.array_equal(result, expected)) + + result = mx.diag(x, k=-1) + expected = mx.array(np.diag(x, k=-1)) + self.assertTrue(mx.array_equal(result, expected)) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 9e5f9e277..e52c1294f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1,6 +1,5 @@ // Copyright © 2023 Apple Inc. 
#include -#include // TODO #include #include "doctest/doctest.h" @@ -2634,3 +2633,86 @@ TEST_CASE("test divmod") { eval(out_holder); CHECK_EQ(out_holder[0].item(), 1.0); } + +TEST_CASE("test diagonal") { + auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2}); + auto out = diagonal(x); + CHECK(array_equal(out, array({0, 3}, {2})).item()); + + CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range); + CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range); + + x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4}); + out = diagonal(x, 2, 1, 0); + CHECK(array_equal(out, array({8}, {1})).item()); + + out = diagonal(x, -1, 0, 1); + CHECK(array_equal(out, array({4, 9}, {2})).item()); + + out = diagonal(x, -5, 0, 1); + eval(out); + CHECK_EQ(out.shape(), std::vector{0}); + + x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2}); + out = diagonal(x, 1, 0, 1); + CHECK(array_equal(out, array({2, 3}, {2, 1})).item()); + + out = diagonal(x, 0, 2, 0); + CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item()); + + out = diagonal(x, 1, -1, 0); + CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item()); + + x = reshape(arange(16), {2, 2, 2, 2}); + out = diagonal(x, 0, 0, 1); + CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2})) + .item()); + + CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument); + + x = array({0, 1}, {2}); + CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument); +} + +TEST_CASE("test diag") { + // Too few or too many dimensions + CHECK_THROWS(diag(array(0.0))); + CHECK_THROWS(diag(array({0.0}, {1, 1, 1}))); + + // Test with 1D array + auto x = array({0, 1, 2, 3}, {4}); + auto out = diag(x, 0); + CHECK( + array_equal( + out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4})) + .item()); + + out = diag(x, 1); + CHECK(array_equal( + out, + array( + {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0}, + {5, 5})) + .item()); + + out = diag(x, -1); + CHECK(array_equal( + out, +
array( + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0}, + {5, 5})) + .item()); + + // Test with 2D array + x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3}); + out = diag(x, 0); + CHECK(array_equal(out, array({0, 4, 8}, {3})).item()); + + out = diag(x, 1); + CHECK(array_equal(out, array({1, 5}, {2})).item()); + + out = diag(x, -1); + CHECK(array_equal(out, array({3, 7}, {2})).item()); +}