mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
28
mlx/ops.h
28
mlx/ops.h
@@ -164,11 +164,27 @@ array slice(
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
inline array slice(
|
||||
const array& a,
|
||||
std::initializer_list<int> start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {}) {
|
||||
return slice(a, Shape(start), std::move(stop), std::move(strides), s);
|
||||
}
|
||||
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
/** Slice an array with dynamic starting indices. */
|
||||
array slice(
|
||||
const array& a,
|
||||
const array& start,
|
||||
std::vector<int> axes,
|
||||
Shape slice_size,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array. */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
@@ -177,7 +193,7 @@ array slice_update(
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
/** Update a slice from the source array with stride 1 in each dimension. */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
@@ -185,6 +201,14 @@ array slice_update(
|
||||
Shape stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array with dynamic starting indices. */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
const array& start,
|
||||
std::vector<int> axes,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Split an array into sub-arrays along a given axis. */
|
||||
std::vector<array>
|
||||
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
|
||||
|
||||
Reference in New Issue
Block a user