mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -764,7 +764,7 @@ auto mlx_slice_update(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
// Can't route to slice update if not slice or tuple
|
||||
// Can't route to slice update if not slice, tuple, or int
|
||||
if (src.ndim() == 0 ||
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
|
||||
!nb::isinstance<nb::int_>(obj))) {
|
||||
@@ -845,20 +845,14 @@ auto mlx_slice_update(
|
||||
return std::make_pair(true, broadcast_to(up, src.shape()));
|
||||
}
|
||||
|
||||
// Process entries
|
||||
mx::Shape up_reshape(src.ndim());
|
||||
int ax = src.ndim() - 1;
|
||||
int up_ax = up.ndim() - 1;
|
||||
for (; ax >= non_none_indices; ax--) {
|
||||
if (up_ax >= 0) {
|
||||
up_reshape[ax] = up.shape(up_ax);
|
||||
up_ax--;
|
||||
} else {
|
||||
up_reshape[ax] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = indices.size() - 1; i >= 0; --i) {
|
||||
int unspecified = src.ndim() - non_none_indices;
|
||||
std::vector<int> squeeze_dims;
|
||||
std::vector<int> expand_dims;
|
||||
for (int i = indices.size() - 1,
|
||||
ax = non_none_indices - 1,
|
||||
upd_ax = upd.ndim() - unspecified - 1;
|
||||
i >= 0;
|
||||
--i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
get_slice_params(
|
||||
@@ -867,19 +861,26 @@ auto mlx_slice_update(
|
||||
strides[ax],
|
||||
nb::cast<nb::slice>(pyidx),
|
||||
src.shape(ax));
|
||||
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
|
||||
ax--;
|
||||
upd_ax--;
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
int st = nb::cast<int>(pyidx);
|
||||
st = (st < 0) ? st + src.shape(ax) : st;
|
||||
st = (st < 0) ? st + src.shape(i) : st;
|
||||
starts[ax] = st;
|
||||
stops[ax] = st + 1;
|
||||
up_reshape[ax] = 1;
|
||||
if (upd_ax >= 0) {
|
||||
expand_dims.push_back(i - indices.size() - unspecified);
|
||||
}
|
||||
ax--;
|
||||
} else if (pyidx.is_none()) {
|
||||
if (upd_ax-- >= 0) {
|
||||
squeeze_dims.push_back(i - indices.size() - unspecified);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
up = reshape(up, std::move(up_reshape));
|
||||
up = mx::squeeze(
|
||||
mx::expand_dims(up, std::move(expand_dims)), std::move(squeeze_dims));
|
||||
auto out = slice_update(src, up, starts, stops, strides);
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
@@ -5004,4 +5004,81 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The imaginary part of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"slice",
|
||||
[](const mx::array& a,
|
||||
const mx::array& start_indices,
|
||||
std::vector<int> axes,
|
||||
mx::Shape slice_size,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::slice(
|
||||
a, start_indices, std::move(axes), std::move(slice_size), s);
|
||||
},
|
||||
nb::arg(),
|
||||
"start_indices"_a,
|
||||
"axes"_a,
|
||||
"slice_size"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Extract a sub-array from the input array.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
start_indices (array): The index location to start the slice at.
|
||||
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
|
||||
slice_size (tuple(int)): The size of the slice.
|
||||
|
||||
Returns:
|
||||
array: The sliced output array.
|
||||
|
||||
Example:
|
||||
|
||||
>>> a = mx.array([[1, 2, 3], [4, 5, 6]])
|
||||
>>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))
|
||||
array([[4, 5]], dtype=int32)
|
||||
>>>
|
||||
>>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))
|
||||
array([[2],
|
||||
[5]], dtype=int32)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"slice_update",
|
||||
[](const mx::array& src,
|
||||
const mx::array& update,
|
||||
const mx::array& start_indices,
|
||||
std::vector<int> axes,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::slice_update(src, update, start_indices, axes, s);
|
||||
},
|
||||
nb::arg(),
|
||||
"update"_a,
|
||||
"start_indices"_a,
|
||||
"axes"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Update a sub-array of the input array.
|
||||
|
||||
Args:
|
||||
a (array): The input array to update
|
||||
update (array): The update array.
|
||||
start_indices (array): The index location to start the slice at.
|
||||
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
|
||||
|
||||
Returns:
|
||||
array: The output array with the same shape and type as the input.
|
||||
|
||||
Example:
|
||||
|
||||
>>> a = mx.zeros((3, 3))
|
||||
>>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1))
|
||||
array([[0, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 1, 0]], dtype=float32)
|
||||
)pbdoc");
|
||||
}
|
||||
|
Reference in New Issue
Block a user