From 9b1209373992ef21a45b4b469bf0bfcb57dc533d Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 7 Oct 2024 17:21:42 -0700
Subject: [PATCH] Add the roll op (#1455)

---
 docs/src/python/ops.rst  |   1 +
 mlx/ops.cpp              | 111 +++++++++++++++++++++++++++++++++++++++
 mlx/ops.h                |  31 +++++++++++
 python/src/ops.cpp       |  55 +++++++++++++++++++
 python/tests/test_ops.py |  39 ++++++++++++++
 tests/ops_tests.cpp      |  37 +++++++++++++
 6 files changed, 274 insertions(+)

diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index e0d70ea14..e3c50e2ff 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -130,6 +130,7 @@ Operations
    repeat
    reshape
    right_shift
+   roll
    round
    rsqrt
    save
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 7bb9efe8c..c415e4504 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4567,4 +4567,115 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
       out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
 }
 
+/**
+ * Roll `a` by `shift[i]` places along `axes[i]`, one axis after the other.
+ * Surplus shift values are unused and empty axes are left untouched.
+ */
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s /* = {} */) {
+  if (axes.empty()) {
+    return a;
+  }
+
+  if (shift.size() < axes.size()) {
+    std::ostringstream msg;
+    msg << "[roll] At least one shift value per axis is required, "
+        << shift.size() << " provided for " << axes.size() << " axes.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  array result = a;
+  for (int i = 0; i < axes.size(); i++) {
+    int ax = axes[i];
+    if (ax < 0) {
+      ax += a.ndim();
+    }
+    if (ax < 0 || ax >= a.ndim()) {
+      std::ostringstream msg;
+      msg << "[roll] Invalid axis " << axes[i] << " for array with " << a.ndim()
+          << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+
+    // An empty axis has nothing to roll; skipping it also avoids `% 0`.
+    int size = a.shape(ax);
+    if (size == 0) {
+      continue;
+    }
+
+    // Reduce the shift to [0, size). Taking the remainder before any negation
+    // keeps the most negative int well defined.
+    int sh = shift[i] % size;
+    if (sh < 0) {
+      sh += size;
+    }
+    if (sh == 0) {
+      continue; // whole rotations leave the axis unchanged
+    }
+
+    // The trailing `sh` elements move to the front of the axis.
+    auto parts = split(result, std::vector<int>{size - sh}, ax, s);
+    std::swap(parts[0], parts[1]);
+    result = concatenate(parts, ax, s);
+  }
+
+  return result;
+}
+
+/** Roll the flattened array by `shift` and restore the original shape. */
+array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
+  auto shape = a.shape();
+  return reshape(
+      roll(
+          reshape(a, std::vector<int>{-1}, s),
+          std::vector<int>{shift},
+          std::vector<int>{0},
+          s),
+      std::move(shape),
+      s);
+}
+
+/** Roll the flattened array by the sum of all the values in `shift`. */
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    StreamOrDevice s /* = {} */) {
+  int total_shift = 0;
+  for (int sh : shift) {
+    total_shift += sh;
+  }
+  return roll(a, total_shift, s);
+}
+
+/** Roll by `shift` along a single `axis`. */
+array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
+  return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s);
+}
+
+/** Roll by the same `shift` along every axis in `axes`. */
+array roll(
+    const array& a,
+    int shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s /* = {} */) {
+  std::vector<int> shifts(axes.size(), shift);
+  return roll(a, shifts, axes, s);
+}
+
+/** Roll along a single `axis` by the sum of all the values in `shift`. */
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    int axis,
+    StreamOrDevice s /* = {} */) {
+  int total_shift = 0;
+  for (int sh : shift) {
+    total_shift += sh;
+  }
+  return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index 723e73e1a..32f77514f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1456,6 +1456,37 @@ array right_shift(const array& a, const array& b, StreamOrDevice s = {});
 array operator>>(const array& a, const array& b);
 
 array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
+
+/**
+ * Roll elements along one or more axes; elements pushed past one end of an
+ * axis are reintroduced at the other end.
+ *
+ * Without an axis the array is flattened, rolled and reshaped back. Several
+ * shift values given without an axis or with a single axis are summed; an int
+ * shift given with several axes is used for each of them.
+ */
+array roll(const array& a, int shift, StreamOrDevice s = {});
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    StreamOrDevice s = {});
+array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
+array roll(
+    const array& a,
+    int shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    int axis,
+    StreamOrDevice s = {});
+array roll(
+    const array& a,
+    const std::vector<int>& shift,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+
 /** @} */
 
 } // namespace mlx::core
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 8e74907b0..20278b40d 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4795,4 +4795,59 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The output array.
       )pbdoc");
+  m.def(
+      "roll",
+      [](const array& a,
+         const IntOrVec& shift,
+         const IntOrVec& axis,
+         StreamOrDevice s) {
+        return std::visit(
+            [&](auto sh, auto ax) -> array {
+              using T = decltype(ax);
+              using V = decltype(sh);
+
+              // `shift` shares its type with `axis`, which defaults to None,
+              // so an explicit None can reach this point; reject it.
+              if constexpr (std::is_same_v<V, std::monostate>) {
+                throw std::invalid_argument(
+                    "[roll] Expected two arguments but only one was given.");
+              } else {
+                if constexpr (std::is_same_v<T, std::monostate>) {
+                  return roll(a, sh, s);
+                } else {
+                  return roll(a, sh, ax, s);
+                }
+              }
+            },
+            shift,
+            axis);
+      },
+      nb::arg(),
+      "shift"_a,
+      "axis"_a = nb::none(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def roll(a: array, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Roll array elements along a given axis.
+
+        Elements that are rolled beyond the end of the array are introduced at
+        the beginning and vice-versa.
+
+        If the axis is not provided the array is flattened, rolled and then the
+        shape is restored.
+
+        Args:
+          a (array): Input array
+          shift (int or tuple(int)): The number of places by which elements
+            are shifted. If positive the array is rolled to the right, if
+            negative it is rolled to the left. If an int is provided but the
+            axis is a tuple then the same value is used for all axes.
+          axis (int or tuple(int), optional): The axis or axes along which to
+            roll the elements.
+
+        Returns:
+          array: The rolled array.
+      )pbdoc");
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 381f0e8ca..bc5f4e8b4 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2641,6 +2641,45 @@ class TestOps(mlx_tests.MLXTestCase):
         out_t = vht_t(xb)
         np.testing.assert_allclose(out, out_t, atol=1e-4)
 
+    def test_roll(self):
+        x = mx.arange(10).reshape(2, 5)
+
+        for s in [-2, -1, 0, 1, 2]:
+            y1 = np.roll(x, s)
+            y2 = mx.roll(x, s)
+            self.assertTrue(mx.array_equal(y1, y2).item())
+
+            y1 = np.roll(x, (s, s, s))
+            y2 = mx.roll(x, (s, s, s))
+            self.assertTrue(mx.array_equal(y1, y2).item())
+
+        shifts = [
+            1,
+            2,
+            -1,
+            -2,
+            (1, 1),
+            (-1, 2),
+            (33, 33),
+        ]
+        axes = [
+            0,
+            1,
+            (1, 0),
+            (0, 1),
+            (0, 0),
+            (1, 1),
+        ]
+        for s, a in product(shifts, axes):
+            y1 = np.roll(x, s, a)
+            y2 = mx.roll(x, s, a)
+            self.assertTrue(mx.array_equal(y1, y2).item())
+
+        # Rolling along an empty axis leaves the array unchanged
+        x = mx.arange(0).reshape(0, 5)
+        y = mx.roll(x, 3, 0)
+        self.assertTrue(mx.array_equal(x, y).item())
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 682f668c9..6bae16fad 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3715,3 +3715,40 @@ TEST_CASE("test view") {
   auto out = view(in, int32);
   CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
 }
+
+TEST_CASE("test roll") {
+  auto x = reshape(arange(10), {2, 5});
+
+  auto y = roll(x, 2);
+  CHECK(array_equal(y, array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, -2);
+  CHECK(array_equal(y, array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, 2, 1);
+  CHECK(array_equal(y, array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, -2, 1);
+  CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, 2, {0, 0, 0});
+  CHECK(array_equal(y, array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, 1, {1, 1, 1});
+  CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
+            .item<bool>());
+
+  y = roll(x, {1, 2}, {0, 1});
+  CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
+            .item<bool>());
+
+  // Rolling along an empty axis is a no-op (and must not divide by zero)
+  auto e = reshape(arange(0), {0, 5});
+  y = roll(e, 3, 0);
+  CHECK(array_equal(y, e).item<bool>());
+}