mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding * Refactor ellipsis handling * Route mlx_set_item to slice_update where possible * Update mlx_scatter_args_slice * Don't route to gather if no array indices
This commit is contained in:
parent
e849b3424a
commit
a5681ebc52
@ -558,8 +558,7 @@ array slice_update(
|
||||
normalize_slice(src.shape(), start, stop, strides);
|
||||
|
||||
// Broadcast update shape to slice shape
|
||||
auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape());
|
||||
auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s);
|
||||
auto update_broadcasted = broadcast_to(update, upd_shape, s);
|
||||
|
||||
// If the entire src is the slice, just return the update
|
||||
if (!has_neg_strides && upd_shape == src.shape()) {
|
||||
@ -571,7 +570,7 @@ array slice_update(
|
||||
src.dtype(),
|
||||
std::make_unique<SliceUpdate>(
|
||||
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
||||
{src, update});
|
||||
{src, update_broadcasted});
|
||||
}
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
|
@ -186,6 +186,64 @@ array mlx_gather_nd(
|
||||
return src;
|
||||
}
|
||||
|
||||
auto mlx_expand_ellipsis(
|
||||
const std::vector<int>& shape,
|
||||
const nb::tuple& entries) {
|
||||
std::vector<nb::object> indices;
|
||||
|
||||
// Go over all entries and note the position of ellipsis
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
|
||||
// Start from dimension 0 till we hit an ellipsis
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!nb::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If we do hit an ellipsis, collect indices from the back
|
||||
for (int j = entries.size() - 1; j > i; j--) {
|
||||
auto idx = entries[j];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (nb::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
r_indices.push_back(idx);
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
|
||||
// Count up the number of non none indices
|
||||
int non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
|
||||
// Expand ellipsis
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
non_none_indices++;
|
||||
}
|
||||
|
||||
// Insert indices collected after the ellipsis
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
|
||||
return std::make_pair(non_none_indices, indices);
|
||||
}
|
||||
|
||||
array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
@ -198,59 +256,13 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
std::vector<nb::object> indices;
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!nb::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int j = entries.size() - 1; j > i; j--) {
|
||||
auto idx = entries[j];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (nb::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
r_indices.push_back(idx);
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
}
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
// Check for the number of indices passed
|
||||
{
|
||||
int cnt = src.ndim();
|
||||
for (auto& idx : indices) {
|
||||
if (!idx.is_none()) {
|
||||
cnt--;
|
||||
}
|
||||
}
|
||||
if (cnt < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Gather handling
|
||||
@ -265,7 +277,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || (nb::isinstance<nb::int_>(idx))) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
@ -276,6 +288,13 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
}
|
||||
}
|
||||
|
||||
int n_arr = 0;
|
||||
for (auto& idx : indices) {
|
||||
n_arr += nb::isinstance<array>(idx);
|
||||
}
|
||||
|
||||
have_array &= n_arr > 0;
|
||||
|
||||
if (have_array) {
|
||||
int last_array;
|
||||
// Then find the last array
|
||||
@ -343,6 +362,8 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
remaining_indices = indices;
|
||||
}
|
||||
|
||||
bool squeeze_needed = false;
|
||||
|
||||
// Slice handling
|
||||
{
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
@ -351,12 +372,24 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
get_slice_params(
|
||||
starts[axis],
|
||||
ends[axis],
|
||||
strides[axis],
|
||||
nb::cast<nb::slice>(idx),
|
||||
ends[axis]);
|
||||
if (!have_array && nb::isinstance<nb::int_>(idx)) {
|
||||
int st = nb::cast<int>(idx);
|
||||
st = (st < 0) ? st + src.shape(axis) : st;
|
||||
|
||||
starts[axis] = st;
|
||||
ends[axis] = st + 1;
|
||||
|
||||
squeeze_needed = true;
|
||||
|
||||
} else {
|
||||
get_slice_params(
|
||||
starts[axis],
|
||||
ends[axis],
|
||||
strides[axis],
|
||||
nb::cast<nb::slice>(idx),
|
||||
ends[axis]);
|
||||
}
|
||||
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
@ -364,12 +397,14 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
}
|
||||
|
||||
// Unsqueeze handling
|
||||
if (remaining_indices.size() > src.ndim()) {
|
||||
if (remaining_indices.size() > src.ndim() || squeeze_needed) {
|
||||
std::vector<int> out_shape;
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (idx.is_none()) {
|
||||
out_shape.push_back(1);
|
||||
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
|
||||
axis++;
|
||||
} else {
|
||||
out_shape.push_back(src.shape(axis++));
|
||||
}
|
||||
@ -479,6 +514,35 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
// Check and update slice params
|
||||
get_slice_params(start, end, stride, in_slice, end);
|
||||
|
||||
// If simple stride
|
||||
if (stride == 1) {
|
||||
// Squeeze out singleton dims from the start of update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// Build array to mark start of slice
|
||||
auto idx = array({start}, {1}, uint32);
|
||||
|
||||
// Get slice size
|
||||
int slice_size = (end - start);
|
||||
|
||||
// Broadcast update to slide size
|
||||
std::vector<int> up_shape_broadcast = {1, slice_size};
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
|
||||
|
||||
up = broadcast_to(update, up_shape_broadcast);
|
||||
|
||||
auto indices = std::vector<array>{idx};
|
||||
auto axes = std::vector<int>{0};
|
||||
|
||||
return {indices, up, axes};
|
||||
}
|
||||
|
||||
return mlx_scatter_args_array(
|
||||
src, arange(start, end, stride, uint32), update);
|
||||
}
|
||||
@ -487,47 +551,8 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
const array& src,
|
||||
const nb::tuple& entries,
|
||||
const array& update) {
|
||||
std::vector<nb::object> indices;
|
||||
int non_none_indices = 0;
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
bool has_ellipsis = false;
|
||||
int indices_before = 0;
|
||||
for (int i = 0; i < entries.size(); ++i) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
} else if (!nb::ellipsis().is(idx)) {
|
||||
if (!has_ellipsis) {
|
||||
indices_before++;
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
indices.push_back(idx);
|
||||
} else if (has_ellipsis) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
} else {
|
||||
has_ellipsis = true;
|
||||
}
|
||||
}
|
||||
if (has_ellipsis) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.insert(
|
||||
indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
non_none_indices = src.ndim();
|
||||
} else {
|
||||
non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
}
|
||||
}
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
@ -548,17 +573,29 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
return {{}, broadcast_to(up, src.shape()), {}};
|
||||
}
|
||||
|
||||
// Analyse the types of the indices
|
||||
unsigned long max_dim = 0;
|
||||
bool arrays_first = false;
|
||||
int num_none = 0;
|
||||
int num_slices = 0;
|
||||
int num_arrays = 0;
|
||||
int num_strided_slices = 0;
|
||||
{
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
for (auto& idx : indices) {
|
||||
if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
|
||||
if (idx.is_none()) {
|
||||
have_non_array = have_array;
|
||||
num_none++;
|
||||
|
||||
} else if (nb::isinstance<nb::slice>(idx)) {
|
||||
have_non_array = have_array;
|
||||
num_slices++;
|
||||
|
||||
auto slice = nb::cast<nb::slice>(idx);
|
||||
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
||||
num_strided_slices += (stride != 1);
|
||||
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
@ -570,10 +607,23 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
}
|
||||
}
|
||||
|
||||
// We have index dims for the arrays, strided slices (implemented as arrays),
|
||||
// none
|
||||
int idx_ndim = max_dim + num_strided_slices + num_none;
|
||||
|
||||
// If we have simple non-strided slices, we also attach an index for that
|
||||
idx_ndim += (num_slices < num_strided_slices);
|
||||
|
||||
// Go over each index type and translate to the needed scatter args
|
||||
std::vector<array> arr_indices;
|
||||
int slice_num = 0;
|
||||
int array_num = 0;
|
||||
int ax = 0;
|
||||
|
||||
// We collect the shapes of the slices and updates during this process
|
||||
std::vector<int> update_shape(non_none_indices, 1);
|
||||
std::vector<int> slice_shapes;
|
||||
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
@ -586,48 +636,79 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
start = (start < 0) ? start + axis_size : start;
|
||||
end = (end < 0) ? end + axis_size : end;
|
||||
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
|
||||
// If it's a simple slice, we only need to add the start index
|
||||
if (stride == 1) {
|
||||
auto idx = array({start}, idx_shape, uint32);
|
||||
slice_shapes.push_back(end - start);
|
||||
arr_indices.push_back(idx);
|
||||
}
|
||||
// Otherwise we expand the slice into indices using arange
|
||||
else {
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
slice_shapes.push_back(idx.size());
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
}
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = slice_shapes.back();
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
// Add index to arrays
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = 1;
|
||||
} else if (pyidx.is_none()) {
|
||||
// We only use the None's for bookeeping dimensions
|
||||
slice_num++;
|
||||
} else if (nb::isinstance<array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = nb::cast<array>(pyidx);
|
||||
std::vector<int> idx_shape;
|
||||
if (!arrays_first) {
|
||||
idx_shape.insert(idx_shape.end(), slice_num, 1);
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
|
||||
// Place the arrays in the correct dimension
|
||||
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
|
||||
for (int j = 0; j < idx.ndim(); j++) {
|
||||
idx_shape[st + j] = idx.shape()[j];
|
||||
}
|
||||
idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1);
|
||||
idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end());
|
||||
idx_shape.insert(
|
||||
idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1);
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
if (!arrays_first && ++array_num == num_arrays) {
|
||||
slice_num += max_dim;
|
||||
}
|
||||
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = 1;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast the update to the indices and slices
|
||||
arr_indices = broadcast_arrays(arr_indices);
|
||||
up_shape = arr_indices[0].shape();
|
||||
up_shape.insert(
|
||||
up_shape.end(),
|
||||
auto up_shape_broadcast = arr_indices[0].shape();
|
||||
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), slice_shapes.begin(), slice_shapes.end());
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(),
|
||||
src.shape().begin() + non_none_indices,
|
||||
src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(
|
||||
up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1);
|
||||
up = reshape(up, up_shape);
|
||||
up = broadcast_to(up, up_shape_broadcast);
|
||||
|
||||
// Reshape the update with the size-1 dims for the int and array indices
|
||||
auto up_reshape = arr_indices[0].shape();
|
||||
up_reshape.insert(up_reshape.end(), update_shape.begin(), update_shape.end());
|
||||
up_reshape.insert(
|
||||
up_reshape.end(),
|
||||
src.shape().begin() + non_none_indices,
|
||||
src.shape().end());
|
||||
|
||||
up = reshape(up, up_reshape);
|
||||
|
||||
// Collect axes
|
||||
std::vector<int> axes(arr_indices.size(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
|
||||
@ -654,7 +735,112 @@ mlx_compute_scatter_args(
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
auto mlx_slice_update(
|
||||
const array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
// Can't route to slice update if not slice or tuple
|
||||
if (src.ndim() == 0 ||
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
|
||||
// Should be able to route to slice update
|
||||
|
||||
// Pre process tuple
|
||||
auto upd = to_array(v, src.dtype());
|
||||
|
||||
// Remove leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < upd.ndim() && upd.shape(s) == 1; s++) {
|
||||
};
|
||||
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
|
||||
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
|
||||
auto up = reshape(upd, up_shape);
|
||||
|
||||
// Build slice update params
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> stops = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
|
||||
// If it's just a simple slice, just do a slice update and return
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
// Read slice arguments
|
||||
get_slice_params(
|
||||
starts[0],
|
||||
stops[0],
|
||||
strides[0],
|
||||
nb::cast<nb::slice>(obj),
|
||||
src.shape(0));
|
||||
|
||||
// Do slice update
|
||||
auto out = slice_update(src, up, starts, stops, strides);
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
||||
// It must be a tuple
|
||||
auto entries = nb::cast<nb::tuple>(obj);
|
||||
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
// Dimension check
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
return std::make_pair(true, broadcast_to(up, src.shape()));
|
||||
}
|
||||
|
||||
// Process entries
|
||||
std::vector<int> upd_expand_dims;
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
get_slice_params(
|
||||
starts[ax],
|
||||
stops[ax],
|
||||
strides[ax],
|
||||
nb::cast<nb::slice>(pyidx),
|
||||
src.shape(ax));
|
||||
ax++;
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
int st = nb::cast<int>(pyidx);
|
||||
st = (st < 0) ? st + src.shape(ax) : st;
|
||||
starts[ax] = st;
|
||||
stops[ax] = st + 1;
|
||||
if (src.ndim() - ax < up.ndim()) {
|
||||
upd_expand_dims.push_back(ax - src.ndim());
|
||||
}
|
||||
ax++;
|
||||
}
|
||||
}
|
||||
|
||||
up = expand_dims(up, upd_expand_dims);
|
||||
auto out = slice_update(src, up, starts, stops, strides);
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
||||
auto [success, out] = mlx_slice_update(src, obj, v);
|
||||
if (success) {
|
||||
src.overwrite_descriptor(out);
|
||||
return;
|
||||
}
|
||||
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
auto out = scatter(src, indices, updates, axes);
|
||||
|
Loading…
Reference in New Issue
Block a user