Add the roll op (#1455)

This commit is contained in:
Angelos Katharopoulos
2024-10-07 17:21:42 -07:00
committed by GitHub
parent f374b6ca4d
commit 9b12093739
6 changed files with 231 additions and 0 deletions

View File

@@ -2641,6 +2641,40 @@ class TestOps(mlx_tests.MLXTestCase):
out_t = vht_t(xb)
np.testing.assert_allclose(out, out_t, atol=1e-4)
def test_roll(self):
x = mx.arange(10).reshape(2, 5)
for s in [-2, -1, 0, 1, 2]:
y1 = np.roll(x, s)
y2 = mx.roll(x, s)
self.assertTrue(mx.array_equal(y1, y2).item())
y1 = np.roll(x, (s, s, s))
y2 = mx.roll(x, (s, s, s))
self.assertTrue(mx.array_equal(y1, y2).item())
shifts = [
1,
2,
-1,
-2,
(1, 1),
(-1, 2),
(33, 33),
]
axes = [
0,
1,
(1, 0),
(0, 1),
(0, 0),
(1, 1),
]
for s, a in product(shifts, axes):
y1 = np.roll(x, s, a)
y2 = mx.roll(x, s, a)
self.assertTrue(mx.array_equal(y1, y2).item())
if __name__ == "__main__":
unittest.main()