From 8e5a5a1ccd244a29159f374c21be64f286418172 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 22 Mar 2024 12:11:17 -0700 Subject: [PATCH] Set item bug fix (#879) * set item shaping bug fix * Add extra tests --- python/src/indexing.cpp | 40 ++++++++++++++++++++++++++------------ python/tests/test_array.py | 23 ++++++++++++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 4f4f0c80f..59e1ff194 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -196,6 +196,7 @@ auto mlx_expand_ellipsis( int non_none_indices_after = 0; std::vector r_indices; int i = 0; + bool has_ellipsis = false; // Start from dimension 0 till we hit an ellipsis for (; i < entries.size(); i++) { @@ -208,6 +209,7 @@ auto mlx_expand_ellipsis( indices.push_back(idx); non_none_indices_before += !idx.is_none(); } else { + has_ellipsis = true; break; } } @@ -231,11 +233,13 @@ auto mlx_expand_ellipsis( int non_none_indices = non_none_indices_before + non_none_indices_after; // Expand ellipsis - for (int axis = non_none_indices_before; - axis < shape.size() - non_none_indices_after; - axis++) { - indices.push_back(nb::slice(0, shape[axis], 1)); - non_none_indices++; + if (has_ellipsis) { + for (int axis = non_none_indices_before; + axis < shape.size() - non_none_indices_after; + axis++) { + indices.push_back(nb::slice(0, shape[axis], 1)); + non_none_indices++; + } } // Insert indices collected after the ellipsis @@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { out_shape.push_back(src.shape(axis++)); } } + + out_shape.insert( + out_shape.end(), src.shape().begin() + axis, src.shape().end()); + src = reshape(src, out_shape); } @@ -580,6 +588,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( int num_slices = 0; int num_arrays = 0; int num_strided_slices = 0; + int num_simple_slices_post = 0; { bool have_array = false; bool have_non_array = false; @@ -595,6 +604,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( auto slice = nb::cast(idx); int stride = get_slice_int(nb::getattr(slice, "step"), 1); num_strided_slices += (stride != 1); + num_simple_slices_post += (stride == 1); } else if (nb::isinstance(idx)) { have_array = true; @@ -603,16 +613,17 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( } max_dim = std::max(nb::cast(idx).ndim(), max_dim); num_arrays++; + num_simple_slices_post = 0; } } } // We have index dims for the arrays, strided slices (implemented as arrays), // none - int idx_ndim = max_dim + num_strided_slices + num_none; + int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post; // If we have simple non-strided slices, we also attach an index for that - idx_ndim += (num_slices < num_strided_slices); + idx_ndim = idx_ndim == 0 ? 1 : idx_ndim; // Go over each index type and translate to the needed scatter args std::vector arr_indices; @@ -639,22 +650,27 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( std::vector idx_shape(idx_ndim, 1); // If it's a simple slice, we only need to add the start index - if (stride == 1) { + if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) { auto idx = array({start}, idx_shape, uint32); slice_shapes.push_back(end - start); arr_indices.push_back(idx); + + // Add the shape to the update + update_shape[ax - 1] = slice_shapes.back(); } // Otherwise we expand the slice into indices using arange else { auto idx = arange(start, end, stride, uint32); auto loc = slice_num + (arrays_first ? max_dim : 0); - slice_num++; idx_shape[loc] = idx.size(); - slice_shapes.push_back(idx.size()); arr_indices.push_back(reshape(idx, idx_shape)); + + slice_num++; + num_strided_slices--; + + // Add the shape to the update + update_shape[ax - 1] = 1; } - // Add the shape to the update - update_shape[ax - 1] = slice_shapes.back(); } else if (nb::isinstance(pyidx)) { // Add index to arrays arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index ebbe25806..9b06d3e57 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase): a[2:-2, 2:-2] = 4 self.assertEqual(a[2, 2].item(), 4) + # Check slice array slice + check_slices( + np.zeros((5, 4, 4)), + np.arange(4 * 2 * 3).reshape(4, 2, 3), + slice(0, 4), + np.array([1, 3]), + slice(None, -1), + ) + check_slices( + np.zeros((5, 4, 4)), + np.arange(4 * 2 * 2).reshape(4, 2, 2), + slice(0, 4), + np.array([1, 3]), + slice(0, 4, 2), + ) + + check_slices( + np.zeros((1, 10, 4)), + np.arange(2 * 4).reshape(1, 2, 4), + slice(None, None, None), + np.array([1, 3]), + ) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)