Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -5004,4 +5004,81 @@ void init_ops(nb::module_& m) {
Returns:
array: The imaginary part of ``a``.
)pbdoc");
m.def(
"slice",
[](const mx::array& a,
const mx::array& start_indices,
std::vector<int> axes,
mx::Shape slice_size,
mx::StreamOrDevice s) {
return mx::slice(
a, start_indices, std::move(axes), std::move(slice_size), s);
},
nb::arg(),
"start_indices"_a,
"axes"_a,
"slice_size"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Extract a sub-array from the input array.
Args:
a (array): Input array
start_indices (array): The index location to start the slice at.
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
slice_size (tuple(int)): The size of the slice.
Returns:
array: The sliced output array.
Example:
>>> a = mx.array([[1, 2, 3], [4, 5, 6]])
>>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))
array([[4, 5]], dtype=int32)
>>>
>>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))
array([[2],
[5]], dtype=int32)
)pbdoc");
m.def(
"slice_update",
[](const mx::array& src,
const mx::array& update,
const mx::array& start_indices,
std::vector<int> axes,
mx::StreamOrDevice s) {
return mx::slice_update(src, update, start_indices, axes, s);
},
nb::arg(),
"update"_a,
"start_indices"_a,
"axes"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Update a sub-array of the input array.
Args:
a (array): The input array to update
update (array): The update array.
start_indices (array): The index location to start the slice at.
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
Returns:
array: The output array with the same shape and type as the input.
Example:
>>> a = mx.zeros((3, 3))
>>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1))
array([[0, 0, 0],
[0, 1, 0],
[0, 1, 0]], dtype=float32)
)pbdoc");
}