Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -2816,6 +2816,12 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(a.shape, (3, 4, 2))
self.assertEqual(b.shape, (3, 4, 2))
def test_slice_update_reversed(self):
a = mx.array([1, 2, 3, 4])
b = a[::-1]
b[::2] = 0
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
if __name__ == "__main__":
unittest.main()