Fix slice data size (#1913)

* fix slice data size

* add test
This commit is contained in:
Awni Hannun
2025-03-02 21:50:42 -08:00
committed by GitHub
parent 5e6c130d93
commit 4e7cd31d12
2 changed files with 13 additions and 3 deletions

View File

@@ -2846,6 +2846,11 @@ class TestOps(mlx_tests.MLXTestCase):
b[::2] = 0
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
def test_slice_with_negative_stride(self):
a = mx.random.uniform(shape=(128, 4))
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
if __name__ == "__main__":
unittest.main()