mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-01 07:56:44 +08:00
fix fp error in roll + add test
This commit is contained in:
parent
9cfe0b1533
commit
6389b9d37a
@ -5025,8 +5025,11 @@ array roll(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto sh = shift[i];
|
auto sh = shift[i];
|
||||||
auto split_index =
|
auto size = a.shape(ax);
|
||||||
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
|
if (size == 0) {
|
||||||
|
continue; // skip rolling this axis if it has size 0
|
||||||
|
}
|
||||||
|
auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;
|
||||||
|
|
||||||
auto parts = split(result, Shape{split_index}, ax, s);
|
auto parts = split(result, Shape{split_index}, ax, s);
|
||||||
std::swap(parts[0], parts[1]);
|
std::swap(parts[0], parts[1]);
|
||||||
|
@ -2961,6 +2961,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
y2 = mx.roll(x, s, a)
|
y2 = mx.roll(x, s, a)
|
||||||
self.assertTrue(mx.array_equal(y1, y2).item())
|
self.assertTrue(mx.array_equal(y1, y2).item())
|
||||||
|
|
||||||
|
def test_roll_errors(self):
|
||||||
|
x = mx.array([])
|
||||||
|
result = mx.roll(x, [0], [0])
|
||||||
|
self.assertTrue(mx.array_equal(result, x))
|
||||||
|
|
||||||
def test_real_imag(self):
|
def test_real_imag(self):
|
||||||
x = mx.random.uniform(shape=(4, 4))
|
x = mx.random.uniform(shape=(4, 4))
|
||||||
out = mx.real(x)
|
out = mx.real(x)
|
||||||
|
@ -3859,6 +3859,9 @@ TEST_CASE("test roll") {
|
|||||||
y = roll(x, {1, 2}, {0, 1});
|
y = roll(x, {1, 2}, {0, 1});
|
||||||
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
|
|
||||||
|
y = roll(array({}), 0, 0);
|
||||||
|
CHECK(array_equal(y, array({})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test contiguous") {
|
TEST_CASE("test contiguous") {
|
||||||
|
Loading…
Reference in New Issue
Block a user