Add the roll op (#1455)

This commit is contained in:
Angelos Katharopoulos 2024-10-07 17:21:42 -07:00 committed by GitHub
parent f374b6ca4d
commit 9b12093739
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 231 additions and 0 deletions

View File

@ -130,6 +130,7 @@ Operations
repeat
reshape
right_shift
roll
round
rsqrt
save

View File

@ -4567,4 +4567,94 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
}
/**
 * Roll the elements of `a` along one or more axes.
 *
 * For every position `i` in `axes`, the elements are moved by `shift[i]`
 * places along `axes[i]`; elements pushed past one end of the axis re-enter
 * at the other end. A positive shift moves elements towards higher indices, a
 * negative shift towards lower indices. Negative axes are counted from the
 * last dimension. The same axis may appear several times, in which case the
 * shifts are applied one after the other.
 *
 * `a` is returned unchanged when `axes` is empty. Entries of `shift` beyond
 * `axes.size()` are ignored.
 *
 * Throws std::invalid_argument when fewer shifts than axes are given or when
 * an axis is out of range for `a`.
 */
array roll(
    const array& a,
    const std::vector<int>& shift,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  if (axes.empty()) {
    return a;
  }
  if (shift.size() < axes.size()) {
    std::ostringstream msg;
    msg << "[roll] At least one shift value per axis is required, "
        << shift.size() << " provided for " << axes.size() << " axes.";
    throw std::invalid_argument(msg.str());
  }

  array result = a;
  const int num_axes = static_cast<int>(axes.size());
  for (int i = 0; i < num_axes; ++i) {
    int ax = axes[i];
    if (ax < 0) {
      ax += a.ndim();
    }
    if (ax < 0 || ax >= a.ndim()) {
      std::ostringstream msg;
      msg << "[roll] Invalid axis " << axes[i] << " for array with " << a.ndim()
          << " dimensions.";
      throw std::invalid_argument(msg.str());
    }

    const int size = a.shape(ax);
    if (size == 0) {
      // Nothing to roll along an empty axis; this also avoids a modulo by
      // zero below.
      continue;
    }

    // Reduce the shift first so that it lies in (-size, size). Negating the
    // reduced value can no longer overflow, unlike negating the raw shift.
    const int offset = shift[i] % size;
    if (offset == 0) {
      // A whole number of full rotations leaves the axis unchanged.
      continue;
    }

    // The element at `split_index` becomes the first element of the output.
    const int split_index = (offset < 0) ? -offset : size - offset;

    auto parts = split(result, std::vector<int>{split_index}, ax, s);
    std::swap(parts[0], parts[1]);
    result = concatenate(parts, ax, s);
  }
  return result;
}
/**
 * Roll `a` as if it were one-dimensional: the array is reshaped to a single
 * axis, rolled by `shift` along that axis and reshaped back to its original
 * shape.
 */
array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
  // Remember the shape so it can be restored after the 1-D roll.
  auto original_shape = a.shape();

  auto flat = reshape(a, std::vector<int>{-1}, s);
  auto rolled = roll(flat, std::vector<int>{shift}, std::vector<int>{0}, s);

  return reshape(rolled, std::move(original_shape), s);
}
/**
 * Roll `a` without an explicit axis using several shifts: the shifts are
 * added together and the single-shift overload is called with the sum.
 */
array roll(
    const array& a,
    const std::vector<int>& shift,
    StreamOrDevice s /* = {} */) {
  int combined = 0;
  for (auto it = shift.begin(); it != shift.end(); ++it) {
    combined += *it;
  }
  return roll(a, combined, s);
}
/** Roll `a` by `shift` places along the single axis `axis`. */
array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
  // Wrap both scalars so the general multi-axis overload does the work.
  const std::vector<int> shifts{shift};
  const std::vector<int> axes{axis};
  return roll(a, shifts, axes, s);
}
/** Roll `a` by the same `shift` along every axis listed in `axes`. */
array roll(
    const array& a,
    int shift,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  // One copy of `shift` per requested axis.
  return roll(a, std::vector<int>(axes.size(), shift), axes, s);
}
/**
 * Roll `a` along the single axis `axis` using several shifts: the shifts are
 * added together and applied as one roll along that axis.
 */
array roll(
    const array& a,
    const std::vector<int>& shift,
    int axis,
    StreamOrDevice s /* = {} */) {
  int combined = 0;
  for (auto it = shift.begin(); it != shift.end(); ++it) {
    combined += *it;
  }
  const std::vector<int> shifts{combined};
  const std::vector<int> axes{axis};
  return roll(a, shifts, axes, s);
}
} // namespace mlx::core

View File

@ -1456,6 +1456,30 @@ array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);
array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/**
 * Roll elements along an axis and introduce them on the other side.
 *
 * This overload takes no axis: the array is flattened, rolled by `shift`
 * and reshaped back to its original shape.
 */
array roll(const array& a, int shift, StreamOrDevice s = {});
/** Roll the flattened array by the sum of all the values in `shift`. */
array roll(
const array& a,
const std::vector<int>& shift,
StreamOrDevice s = {});
/** Roll by `shift` places along the single axis `axis`. */
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
/** Roll by the same `shift` along every axis listed in `axes`. */
array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Roll along `axis` by the sum of all the values in `shift`. */
array roll(
const array& a,
const std::vector<int>& shift,
int axis,
StreamOrDevice s = {});
/**
 * Roll by `shift[i]` places along `axes[i]` for every entry of `axes`.
 *
 * Negative axes are counted from the last dimension. Returns `a` unchanged
 * when `axes` is empty. Throws std::invalid_argument when fewer shifts than
 * axes are provided or when an axis is out of range.
 */
array roll(
const array& a,
const std::vector<int>& shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** @} */
} // namespace mlx::core

View File

@ -4795,4 +4795,54 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array.
)pbdoc");
// Python binding for `roll`. `shift` and `axis` are each an int, a sequence
// of ints or None; the visitor below picks the matching C++ overload.
m.def(
    "roll",
    [](const array& a,
       const IntOrVec& shift,
       const IntOrVec& axis,
       StreamOrDevice s) {
      return std::visit(
          [&](auto sh, auto ax) -> array {
            using T = decltype(ax);
            using V = decltype(sh);
            if constexpr (std::is_same_v<V, std::monostate>) {
              // `shift` was passed as None: there is nothing to roll by.
              throw std::invalid_argument(
                  "[roll] Expected two arguments but only one was given.");
            } else {
              if constexpr (std::is_same_v<T, std::monostate>) {
                // No axis given: use the overloads that roll the flattened
                // array.
                return roll(a, sh, s);
              } else {
                return roll(a, sh, ax, s);
              }
            }
          },
          shift,
          axis);
    },
    nb::arg(),
    "shift"_a,
    "axis"_a = nb::none(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    // Only `a` is bound without a name, so only `a` is positional-only;
    // `shift` and `axis` may also be passed by keyword.
    nb::sig(
        "def roll(a: array, /, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
Roll array elements along a given axis.
Elements that are rolled beyond the end of the array are introduced at
the beginning and vice-versa.
If the axis is not provided the array is flattened, rolled and then the
shape is restored.
Args:
a (array): Input array
shift (int or tuple(int)): The number of places by which elements
are shifted. If positive the array is rolled to the right, if
negative it is rolled to the left. If an int is provided but the
axis is a tuple then the same value is used for all axes.
axis (int or tuple(int), optional): The axis or axes along which to
roll the elements.
Returns:
array: The rolled array with the same shape as the input.
)pbdoc");
}

View File

@ -2641,6 +2641,40 @@ class TestOps(mlx_tests.MLXTestCase):
out_t = vht_t(xb)
np.testing.assert_allclose(out, out_t, atol=1e-4)
def test_roll(self):
    """Check ``mx.roll`` against ``np.roll`` with and without explicit axes."""
    x = mx.arange(10).reshape(2, 5)

    # Without an axis: first a scalar shift, then a tuple repeating it.
    for amount in (-2, -1, 0, 1, 2):
        for shift in (amount, (amount, amount, amount)):
            expected = np.roll(x, shift)
            actual = mx.roll(x, shift)
            self.assertTrue(mx.array_equal(expected, actual).item())

    # With axes: every pairing of the shift and axis specifications below,
    # including repeated axes and shifts larger than the axis size.
    shifts = [1, 2, -1, -2, (1, 1), (-1, 2), (33, 33)]
    axes = [0, 1, (1, 0), (0, 1), (0, 0), (1, 1)]
    for shift, axis in product(shifts, axes):
        expected = np.roll(x, shift, axis)
        actual = mx.roll(x, shift, axis)
        self.assertTrue(mx.array_equal(expected, actual).item())
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@ -3715,3 +3715,35 @@ TEST_CASE("test view") {
auto out = view(in, int32);
CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
}
TEST_CASE("test roll") {
  // x = [[0, 1, 2, 3, 4],
  //      [5, 6, 7, 8, 9]]
  auto x = reshape(arange(10), {2, 5});

  // No axis: the array is rolled as if it were flat.
  auto rolled = roll(x, 2);
  auto expected = array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  rolled = roll(x, -2);
  expected = array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  // Single axis: each row is rolled independently.
  rolled = roll(x, 2, 1);
  expected = array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  rolled = roll(x, -2, 1);
  expected = array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  // Repeated axes: the single shift is applied once per listed axis.
  rolled = roll(x, 2, {0, 0, 0});
  expected = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  rolled = roll(x, 1, {1, 1, 1});
  expected = array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());

  // One shift per axis.
  rolled = roll(x, {1, 2}, {0, 1});
  expected = array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5});
  CHECK(array_equal(rolled, expected).item<bool>());
}