mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
110
mlx/ops.cpp
110
mlx/ops.cpp
@@ -647,6 +647,52 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
return std::make_pair(has_neg_strides, out_shape);
|
||||
}
|
||||
|
||||
void normalize_dynamic_slice_inputs(
|
||||
const array& a,
|
||||
const array& start,
|
||||
std::vector<int>& axes,
|
||||
const std::string prefix) {
|
||||
if (start.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Invalid number of starting positions for "
|
||||
<< "array with dimension " << a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (start.ndim() > 1) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " array of starting indices "
|
||||
<< "must be zero or one dimensional but has dimension " << start.ndim()
|
||||
<< ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (start.size() != axes.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Number of starting indices " << start.size()
|
||||
<< " does not match number of axes " << axes.size() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(start.dtype(), integer)) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Start indices must be integers, got type "
|
||||
<< start.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (auto& ax : axes) {
|
||||
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
|
||||
if (new_ax < 0 || new_ax >= a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Invalid axis " << ax << " for array with dimension "
|
||||
<< a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
ax = new_ax;
|
||||
}
|
||||
std::set dims(axes.begin(), axes.end());
|
||||
if (dims.size() != axes.size()) {
|
||||
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array slice(
|
||||
@@ -687,6 +733,38 @@ array slice(
|
||||
a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
|
||||
}
|
||||
|
||||
array slice(
|
||||
const array& a,
|
||||
const array& start,
|
||||
std::vector<int> axes,
|
||||
Shape slice_size,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
normalize_dynamic_slice_inputs(a, start, axes, "[slice]");
|
||||
|
||||
// Check the slice_size
|
||||
if (slice_size.size() != a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Invalid slice size for array with " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (slice_size[i] > a.shape(i)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Invalid slice size " << slice_size
|
||||
<< " for array with shape " << a.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
auto out_shape = slice_size;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<DynamicSlice>(
|
||||
to_stream(s), std::move(axes), std::move(slice_size)),
|
||||
{a, start});
|
||||
}
|
||||
|
||||
/** Update a slice from the source array */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
@@ -699,7 +777,7 @@ array slice_update(
|
||||
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
|
||||
strides.size() != src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Invalid number of indices or strides for "
|
||||
msg << "[slice_update] Invalid number of indices or strides for "
|
||||
<< "array with dimension " << src.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -734,6 +812,36 @@ array slice_update(
|
||||
src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
|
||||
}
|
||||
|
||||
/** Update a slice from the source array */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
const array& start,
|
||||
std::vector<int> axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
normalize_dynamic_slice_inputs(src, start, axes, "[slice_update]");
|
||||
|
||||
// Broadcast update with unspecified axes
|
||||
auto up_shape = update.shape();
|
||||
auto dim_diff = std::max(src.ndim() - update.ndim(), 0ul);
|
||||
up_shape.insert(
|
||||
up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff);
|
||||
for (int d = dim_diff; d < src.ndim(); ++d) {
|
||||
up_shape[d] = std::min(up_shape[d], src.shape(d));
|
||||
}
|
||||
for (auto ax : axes) {
|
||||
if (ax < dim_diff) {
|
||||
up_shape[ax] = 1;
|
||||
}
|
||||
}
|
||||
auto upd = broadcast_to(astype(update, src.dtype(), s), up_shape, s);
|
||||
return array(
|
||||
src.shape(),
|
||||
src.dtype(),
|
||||
std::make_shared<DynamicSliceUpdate>(to_stream(s), std::move(axes)),
|
||||
{src, upd, start});
|
||||
}
|
||||
|
||||
std::vector<array> split(
|
||||
const array& a,
|
||||
const Shape& indices,
|
||||
|
||||
Reference in New Issue
Block a user