From 6389b9d37adcf2d2fe8cba48987d1727ed671426 Mon Sep 17 00:00:00 2001 From: Aashiq Dheeraj Date: Tue, 29 Apr 2025 23:04:09 -0400 Subject: [PATCH] fix fp error in roll + add test --- mlx/ops.cpp | 7 +++++-- python/tests/test_ops.py | 5 +++++ tests/ops_tests.cpp | 3 +++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f8308c2d5e..e7abe12db5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -5025,8 +5025,11 @@ array roll( } auto sh = shift[i]; - auto split_index = - (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); + auto size = a.shape(ax); + if (size == 0) { + continue; // skip rolling this axis if it has size 0 + } + auto split_index = (sh < 0) ? (-sh) % size : size - sh % size; auto parts = split(result, Shape{split_index}, ax, s); std::swap(parts[0], parts[1]); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 47fec31672..d840eac7d1 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2961,6 +2961,11 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) self.assertTrue(mx.array_equal(y1, y2).item()) + def test_roll_errors(self): + x = mx.array([]) + result = mx.roll(x, [0], [0]) + self.assertTrue(mx.array_equal(result, x)) + def test_real_imag(self): x = mx.random.uniform(shape=(4, 4)) out = mx.real(x) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c4f319d462..5e2bae5a01 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3859,6 +3859,9 @@ TEST_CASE("test roll") { y = roll(x, {1, 2}, {0, 1}); CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) .item()); + + y = roll(array({}), 0, 0); + CHECK(array_equal(y, array({})).item()); } TEST_CASE("test contiguous") {