Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -647,6 +647,52 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
return std::make_pair(has_neg_strides, out_shape);
}
void normalize_dynamic_slice_inputs(
const array& a,
const array& start,
std::vector<int>& axes,
const std::string prefix) {
if (start.size() > a.ndim()) {
std::ostringstream msg;
msg << prefix << " Invalid number of starting positions for "
<< "array with dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if (start.ndim() > 1) {
std::ostringstream msg;
msg << prefix << " array of starting indices "
<< "must be zero or one dimensional but has dimension " << start.ndim()
<< ".";
throw std::invalid_argument(msg.str());
}
if (start.size() != axes.size()) {
std::ostringstream msg;
msg << prefix << " Number of starting indices " << start.size()
<< " does not match number of axes " << axes.size() << ".";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(start.dtype(), integer)) {
std::ostringstream msg;
msg << prefix << " Start indices must be integers, got type "
<< start.dtype() << ".";
throw std::invalid_argument(msg.str());
}
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
if (new_ax < 0 || new_ax >= a.ndim()) {
std::ostringstream msg;
msg << prefix << " Invalid axis " << ax << " for array with dimension "
<< a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
std::set dims(axes.begin(), axes.end());
if (dims.size() != axes.size()) {
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
}
}
} // namespace
array slice(
@@ -687,6 +733,38 @@ array slice(
a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
}
array slice(
const array& a,
const array& start,
std::vector<int> axes,
Shape slice_size,
StreamOrDevice s /* = {} */) {
normalize_dynamic_slice_inputs(a, start, axes, "[slice]");
// Check the slice_size
if (slice_size.size() != a.ndim()) {
std::ostringstream msg;
msg << "[slice] Invalid slice size for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); ++i) {
if (slice_size[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[slice] Invalid slice size " << slice_size
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
auto out_shape = slice_size;
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<DynamicSlice>(
to_stream(s), std::move(axes), std::move(slice_size)),
{a, start});
}
/** Update a slice from the source array */
array slice_update(
const array& src,
@@ -699,7 +777,7 @@ array slice_update(
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
strides.size() != src.ndim()) {
std::ostringstream msg;
msg << "[slice] Invalid number of indices or strides for "
msg << "[slice_update] Invalid number of indices or strides for "
<< "array with dimension " << src.ndim() << ".";
throw std::invalid_argument(msg.str());
}
@@ -734,6 +812,36 @@ array slice_update(
src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
}
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
const array& start,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
normalize_dynamic_slice_inputs(src, start, axes, "[slice_update]");
// Broadcast update with unspecified axes
auto up_shape = update.shape();
auto dim_diff = std::max(src.ndim() - update.ndim(), 0ul);
up_shape.insert(
up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff);
for (int d = dim_diff; d < src.ndim(); ++d) {
up_shape[d] = std::min(up_shape[d], src.shape(d));
}
for (auto ax : axes) {
if (ax < dim_diff) {
up_shape[ax] = 1;
}
}
auto upd = broadcast_to(astype(update, src.dtype(), s), up_shape, s);
return array(
src.shape(),
src.dtype(),
std::make_shared<DynamicSliceUpdate>(to_stream(s), std::move(axes)),
{src, upd, start});
}
std::vector<array> split(
const array& a,
const Shape& indices,