From 490c0c4fdc4b5873b1b5f3807abdcca5691f5acf Mon Sep 17 00:00:00 2001 From: Jacket <44538064+PRESIDENT810@users.noreply.github.com> Date: Mon, 29 Apr 2024 09:57:28 -0500 Subject: [PATCH] [Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035) * Not sure if this is correct * Format * Edit tests * Add negative test * Format * add one more test --------- Co-authored-by: Awni Hannun --- python/src/indexing.cpp | 4 +--- python/tests/test_array.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 446f17930..f89a421f4 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -845,9 +845,7 @@ auto mlx_slice_update( st = (st < 0) ? st + src.shape(ax) : st; starts[ax] = st; stops[ax] = st + 1; - if (src.ndim() - ax < up.ndim()) { - upd_expand_dims.push_back(ax - src.ndim()); - } + upd_expand_dims.push_back(ax); ax++; } } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4084a2693..bc44b7e6d 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1194,6 +1194,10 @@ class TestArray(mlx_tests.MLXTestCase): a = mx.zeros((2, 2)) a[0, 0, 0] = 1 + with self.assertRaises(ValueError): + a = mx.zeros((5, 4, 3)) + a[:, 0] = mx.ones((5, 1, 3)) + check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None) check_slices( np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1]) @@ -1251,6 +1255,13 @@ class TestArray(mlx_tests.MLXTestCase): slice(None, None, 2), ) + check_slices(np.zeros((5, 4, 3)), np.ones((5, 3)), slice(None), 0) + + check_slices(np.zeros((5, 4, 3)), np.ones((5, 1, 3)), slice(None), slice(0, 1)) + check_slices( + np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4) + ) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)