diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e643170a6..175fe6f89 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -558,8 +558,7 @@ array slice_update( normalize_slice(src.shape(), start, stop, strides); // Broadcast update shape to slice shape - auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape()); - auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s); + auto update_broadcasted = broadcast_to(update, upd_shape, s); // If the entire src is the slice, just return the update if (!has_neg_strides && upd_shape == src.shape()) { @@ -571,7 +570,7 @@ array slice_update( src.dtype(), std::make_unique( to_stream(s), std::move(start), std::move(stop), std::move(strides)), - {src, update}); + {src, update_broadcasted}); } /** Update a slice from the source array with stride 1 in each dimension */ diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index a0682afed..4f4f0c80f 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -186,6 +186,64 @@ array mlx_gather_nd( return src; } +auto mlx_expand_ellipsis( + const std::vector& shape, + const nb::tuple& entries) { + std::vector indices; + + // Go over all entries and note the position of ellipsis + int non_none_indices_before = 0; + int non_none_indices_after = 0; + std::vector r_indices; + int i = 0; + + // Start from dimension 0 till we hit an ellipsis + for (; i < entries.size(); i++) { + auto idx = entries[i]; + if (!is_valid_index_type(idx)) { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } + if (!nb::ellipsis().is(idx)) { + indices.push_back(idx); + non_none_indices_before += !idx.is_none(); + } else { + break; + } + } + + // If we do hit an ellipsis, collect indices from the back + for (int j = entries.size() - 1; j > i; j--) { + auto idx = entries[j]; + if (!is_valid_index_type(idx)) { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } + if (nb::ellipsis().is(idx)) { + throw std::invalid_argument( + "An index can only have a single ellipsis (...)"); + } + r_indices.push_back(idx); + non_none_indices_after += !idx.is_none(); + } + + // Count up the number of non none indices + int non_none_indices = non_none_indices_before + non_none_indices_after; + + // Expand ellipsis + for (int axis = non_none_indices_before; + axis < shape.size() - non_none_indices_after; + axis++) { + indices.push_back(nb::slice(0, shape[axis], 1)); + non_none_indices++; + } + + // Insert indices collected after the ellipsis + indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); + + return std::make_pair(non_none_indices, indices); +} + array mlx_get_item_nd(array src, const nb::tuple& entries) { // No indices make this a noop if (entries.size() == 0) { @@ -198,59 +256,13 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { // 3. Calculate the remaining slices and reshapes // Ellipsis handling - std::vector indices; - { - int non_none_indices_before = 0; - int non_none_indices_after = 0; - std::vector r_indices; - int i = 0; - for (; i < entries.size(); i++) { - auto idx = entries[i]; - if (!is_valid_index_type(idx)) { - throw std::invalid_argument( - "Cannot index mlx array using the given type yet"); - } - if (!nb::ellipsis().is(idx)) { - indices.push_back(idx); - non_none_indices_before += !idx.is_none(); - } else { - break; - } - } - for (int j = entries.size() - 1; j > i; j--) { - auto idx = entries[j]; - if (!is_valid_index_type(idx)) { - throw std::invalid_argument( - "Cannot index mlx array using the given type yet"); - } - if (nb::ellipsis().is(idx)) { - throw std::invalid_argument( - "An index can only have a single ellipsis (...)"); - } - r_indices.push_back(idx); - non_none_indices_after += !idx.is_none(); - } - for (int axis = non_none_indices_before; - axis < src.ndim() - non_none_indices_after; - axis++) { - indices.push_back(nb::slice(0, src.shape(axis), 1)); - } - indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); - } + auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); // Check for the number of indices passed - { - int cnt = src.ndim(); - for (auto& idx : indices) { - if (!idx.is_none()) { - cnt--; - } - } - if (cnt < 0) { - std::ostringstream msg; - msg << "Too many indices for array with " << src.ndim() << "dimensions."; - throw std::invalid_argument(msg.str()); - } + if (non_none_indices > src.ndim()) { + std::ostringstream msg; + msg << "Too many indices for array with " << src.ndim() << "dimensions."; + throw std::invalid_argument(msg.str()); } // Gather handling @@ -265,7 +277,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || (nb::isinstance(idx))) { if (have_array && have_non_array) { gather_first = true; break; @@ -276,6 +288,13 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } } + int n_arr = 0; + for (auto& idx : indices) { + n_arr += nb::isinstance(idx); + } + + have_array &= n_arr > 0; + if (have_array) { int last_array; // Then find the last array @@ -343,6 +362,8 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { remaining_indices = indices; } + bool squeeze_needed = false; + // Slice handling { std::vector starts(src.ndim(), 0); @@ -351,12 +372,24 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { int axis = 0; for (auto& idx : remaining_indices) { if (!idx.is_none()) { - get_slice_params( - starts[axis], - ends[axis], - strides[axis], - nb::cast(idx), - ends[axis]); + if (!have_array && nb::isinstance(idx)) { + int st = nb::cast(idx); + st = (st < 0) ? st + src.shape(axis) : st; + + starts[axis] = st; + ends[axis] = st + 1; + + squeeze_needed = true; + + } else { + get_slice_params( + starts[axis], + ends[axis], + strides[axis], + nb::cast(idx), + ends[axis]); + } + axis++; } } @@ -364,12 +397,14 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } // Unsqueeze handling - if (remaining_indices.size() > src.ndim()) { + if (remaining_indices.size() > src.ndim() || squeeze_needed) { std::vector out_shape; int axis = 0; for (auto& idx : remaining_indices) { if (idx.is_none()) { out_shape.push_back(1); + } else if (squeeze_needed && nb::isinstance(idx)) { + axis++; } else { out_shape.push_back(src.shape(axis++)); } @@ -479,6 +514,35 @@ std::tuple, array, std::vector> mlx_scatter_args_slice( // Check and update slice params get_slice_params(start, end, stride, in_slice, end); + // If simple stride + if (stride == 1) { + // Squeeze out singleton dims from the start of update + int s = 0; + for (; s < update.ndim() && update.shape(s) == 1; s++) + ; + auto up_shape = + std::vector(update.shape().begin() + s, update.shape().end()); + auto up = reshape(update, up_shape); + + // Build array to mark start of slice + auto idx = array({start}, {1}, uint32); + + // Get slice size + int slice_size = (end - start); + + // Broadcast update to slide size + std::vector up_shape_broadcast = {1, slice_size}; + up_shape_broadcast.insert( + up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end()); + + up = broadcast_to(update, up_shape_broadcast); + + auto indices = std::vector{idx}; + auto axes = std::vector{0}; + + return {indices, up, axes}; + } + return mlx_scatter_args_array( src, arange(start, end, stride, uint32), update); } @@ -487,47 +551,8 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( const array& src, const nb::tuple& entries, const array& update) { - std::vector indices; - int non_none_indices = 0; - // Expand ellipses into a series of ':' slices - { - int non_none_indices_before = 0; - int non_none_indices_after = 0; - bool has_ellipsis = false; - int indices_before = 0; - for (int i = 0; i < entries.size(); ++i) { - auto idx = entries[i]; - if (!is_valid_index_type(idx)) { - throw std::invalid_argument( - "Cannot index mlx array using the given type yet"); - } else if (!nb::ellipsis().is(idx)) { - if (!has_ellipsis) { - indices_before++; - non_none_indices_before += !idx.is_none(); - } else { - non_none_indices_after += !idx.is_none(); - } - indices.push_back(idx); - } else if (has_ellipsis) { - throw std::invalid_argument( - "An index can only have a single ellipsis (...)"); - } else { - has_ellipsis = true; - } - } - if (has_ellipsis) { - for (int axis = non_none_indices_before; - axis < src.ndim() - non_none_indices_after; - axis++) { - indices.insert( - indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1)); - } - non_none_indices = src.ndim(); - } else { - non_none_indices = non_none_indices_before + non_none_indices_after; - } - } + auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); if (non_none_indices > src.ndim()) { std::ostringstream msg; @@ -548,17 +573,29 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( return {{}, broadcast_to(up, src.shape()), {}}; } + // Analyse the types of the indices unsigned long max_dim = 0; bool arrays_first = false; + int num_none = 0; int num_slices = 0; int num_arrays = 0; + int num_strided_slices = 0; { bool have_array = false; bool have_non_array = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || idx.is_none()) { + if (idx.is_none()) { + have_non_array = have_array; + num_none++; + + } else if (nb::isinstance(idx)) { have_non_array = have_array; num_slices++; + + auto slice = nb::cast(idx); + int stride = get_slice_int(nb::getattr(slice, "step"), 1); + num_strided_slices += (stride != 1); + } else if (nb::isinstance(idx)) { have_array = true; if (have_array && have_non_array) { @@ -570,10 +607,23 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( } } + // We have index dims for the arrays, strided slices (implemented as arrays), + // none + int idx_ndim = max_dim + num_strided_slices + num_none; + + // If we have simple non-strided slices, we also attach an index for that + idx_ndim += (num_slices < num_strided_slices); + + // Go over each index type and translate to the needed scatter args std::vector arr_indices; int slice_num = 0; int array_num = 0; int ax = 0; + + // We collect the shapes of the slices and updates during this process + std::vector update_shape(non_none_indices, 1); + std::vector slice_shapes; + for (int i = 0; i < indices.size(); ++i) { auto& pyidx = indices[i]; if (nb::isinstance(pyidx)) { @@ -586,48 +636,79 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( start = (start < 0) ? start + axis_size : start; end = (end < 0) ? end + axis_size : end; - auto idx = arange(start, end, stride, uint32); - std::vector idx_shape(max_dim + num_slices, 1); - auto loc = slice_num + (arrays_first ? max_dim : 0); - slice_num++; - idx_shape[loc] = idx.size(); - arr_indices.push_back(reshape(idx, idx_shape)); + std::vector idx_shape(idx_ndim, 1); + + // If it's a simple slice, we only need to add the start index + if (stride == 1) { + auto idx = array({start}, idx_shape, uint32); + slice_shapes.push_back(end - start); + arr_indices.push_back(idx); + } + // Otherwise we expand the slice into indices using arange + else { + auto idx = arange(start, end, stride, uint32); + auto loc = slice_num + (arrays_first ? max_dim : 0); + slice_num++; + idx_shape[loc] = idx.size(); + slice_shapes.push_back(idx.size()); + arr_indices.push_back(reshape(idx, idx_shape)); + } + // Add the shape to the update + update_shape[ax - 1] = slice_shapes.back(); } else if (nb::isinstance(pyidx)) { + // Add index to arrays arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); + // Add the shape to the update + update_shape[ax - 1] = 1; } else if (pyidx.is_none()) { + // We only use the None's for bookeeping dimensions slice_num++; } else if (nb::isinstance(pyidx)) { ax++; auto idx = nb::cast(pyidx); - std::vector idx_shape; - if (!arrays_first) { - idx_shape.insert(idx_shape.end(), slice_num, 1); + std::vector idx_shape(idx_ndim, 1); + + // Place the arrays in the correct dimension + int st = (!arrays_first) * slice_num + max_dim - idx.ndim(); + for (int j = 0; j < idx.ndim(); j++) { + idx_shape[st + j] = idx.shape()[j]; } - idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1); - idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end()); - idx_shape.insert( - idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1); arr_indices.push_back(reshape(idx, idx_shape)); if (!arrays_first && ++array_num == num_arrays) { slice_num += max_dim; } + + // Add the shape to the update + update_shape[ax - 1] = 1; } else { throw std::invalid_argument( "Cannot index mlx array using the given type yet"); } } + // Broadcast the update to the indices and slices arr_indices = broadcast_arrays(arr_indices); - up_shape = arr_indices[0].shape(); - up_shape.insert( - up_shape.end(), + auto up_shape_broadcast = arr_indices[0].shape(); + + up_shape_broadcast.insert( + up_shape_broadcast.end(), slice_shapes.begin(), slice_shapes.end()); + up_shape_broadcast.insert( + up_shape_broadcast.end(), src.shape().begin() + non_none_indices, src.shape().end()); - up = broadcast_to(up, up_shape); - up_shape.insert( - up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1); - up = reshape(up, up_shape); + up = broadcast_to(up, up_shape_broadcast); + // Reshape the update with the size-1 dims for the int and array indices + auto up_reshape = arr_indices[0].shape(); + up_reshape.insert(up_reshape.end(), update_shape.begin(), update_shape.end()); + up_reshape.insert( + up_reshape.end(), + src.shape().begin() + non_none_indices, + src.shape().end()); + + up = reshape(up, up_reshape); + + // Collect axes std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); @@ -654,7 +735,112 @@ mlx_compute_scatter_args( throw std::invalid_argument("Cannot index mlx array using the given type."); } +auto mlx_slice_update( + const array& src, + const nb::object& obj, + const ScalarOrArray& v) { + // Can't route to slice update if not slice or tuple + if (src.ndim() == 0 || + (!nb::isinstance(obj) && !nb::isinstance(obj))) { + return std::make_pair(false, src); + } + + // Should be able to route to slice update + + // Pre process tuple + auto upd = to_array(v, src.dtype()); + + // Remove leading singletons dimensions from the update + int s = 0; + for (; s < upd.ndim() && upd.shape(s) == 1; s++) { + }; + auto up_shape = std::vector(upd.shape().begin() + s, upd.shape().end()); + up_shape = up_shape.empty() ? std::vector{1} : up_shape; + auto up = reshape(upd, up_shape); + + // Build slice update params + std::vector starts(src.ndim(), 0); + std::vector stops = src.shape(); + std::vector strides(src.ndim(), 1); + + // If it's just a simple slice, just do a slice update and return + if (nb::isinstance(obj)) { + // Read slice arguments + get_slice_params( + starts[0], + stops[0], + strides[0], + nb::cast(obj), + src.shape(0)); + + // Do slice update + auto out = slice_update(src, up, starts, stops, strides); + return std::make_pair(true, out); + } + + // It must be a tuple + auto entries = nb::cast(obj); + + // Can't route to slice update if any arrays are present + for (int i = 0; i < entries.size(); i++) { + auto idx = entries[i]; + if (nb::isinstance(idx)) { + return std::make_pair(false, src); + } + } + + // Expand ellipses into a series of ':' slices + auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); + + // Dimension check + if (non_none_indices > src.ndim()) { + std::ostringstream msg; + msg << "Too many indices for array with " << src.ndim() << "dimensions."; + throw std::invalid_argument(msg.str()); + } + + // If no non-None indices return the broadcasted update + if (non_none_indices == 0) { + return std::make_pair(true, broadcast_to(up, src.shape())); + } + + // Process entries + std::vector upd_expand_dims; + int ax = 0; + for (int i = 0; i < indices.size(); ++i) { + auto& pyidx = indices[i]; + if (nb::isinstance(pyidx)) { + get_slice_params( + starts[ax], + stops[ax], + strides[ax], + nb::cast(pyidx), + src.shape(ax)); + ax++; + } else if (nb::isinstance(pyidx)) { + int st = nb::cast(pyidx); + st = (st < 0) ? st + src.shape(ax) : st; + starts[ax] = st; + stops[ax] = st + 1; + if (src.ndim() - ax < up.ndim()) { + upd_expand_dims.push_back(ax - src.ndim()); + } + ax++; + } + } + + up = expand_dims(up, upd_expand_dims); + auto out = slice_update(src, up, starts, stops, strides); + return std::make_pair(true, out); +} + void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) { + auto [success, out] = mlx_slice_update(src, obj, v); + if (success) { + src.overwrite_descriptor(out); + return; + } + auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { auto out = scatter(src, indices, updates, axes);