diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 91c4ed9e3..a38323797 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -803,9 +803,10 @@ auto mlx_slice_update( // Pre process tuple auto upd = to_array(v, src.dtype()); - // Remove leading singletons dimensions from the update + // Remove extra leading singletons dimensions from the update int s = 0; - for (; s < upd.ndim() && upd.shape(s) == 1; s++) { + for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim(); + s++) { }; auto up_shape = std::vector(upd.shape().begin() + s, upd.shape().end()); up_shape = up_shape.empty() ? std::vector{1} : up_shape;