Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -164,11 +164,27 @@ array slice(
Shape stop,
Shape strides,
StreamOrDevice s = {});
inline array slice(
const array& a,
std::initializer_list<int> start,
Shape stop,
Shape strides,
StreamOrDevice s = {}) {
return slice(a, Shape(start), std::move(stop), std::move(strides), s);
}
/** Slice an array with a stride of 1 in each dimension. */
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
/** Update a slice from the source array */
/** Slice an array with dynamic starting indices. */
array slice(
const array& a,
const array& start,
std::vector<int> axes,
Shape slice_size,
StreamOrDevice s = {});
/** Update a slice from the source array. */
array slice_update(
const array& src,
const array& update,
@@ -177,7 +193,7 @@ array slice_update(
Shape strides,
StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */
/** Update a slice from the source array with stride 1 in each dimension. */
array slice_update(
const array& src,
const array& update,
@@ -185,6 +201,14 @@ array slice_update(
Shape stop,
StreamOrDevice s = {});
/** Update a slice from the source array with dynamic starting indices. */
array slice_update(
const array& src,
const array& update,
const array& start,
std::vector<int> axes,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});