mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
10
mlx/ops.h
10
mlx/ops.h
@@ -144,9 +144,7 @@ array squeeze(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Remove singleton dimensions at the given axis. */
|
||||
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
|
||||
return squeeze(a, std::vector<int>{axis}, s);
|
||||
}
|
||||
array squeeze(const array& a, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Remove all singleton dimensions. */
|
||||
array squeeze(const array& a, StreamOrDevice s = {});
|
||||
@@ -171,8 +169,8 @@ array slice(
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(
|
||||
const array& a,
|
||||
const std::vector<int>& start,
|
||||
const std::vector<int>& stop,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
@@ -936,9 +934,11 @@ array take(
|
||||
const array& indices,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
array take(const array& a, int index, int axis, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries at the given indices treating the array as flattened. */
|
||||
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
||||
array take(const array& a, int index, StreamOrDevice s = {});
|
||||
|
||||
/** Take array entries given indices along the axis */
|
||||
array take_along_axis(
|
||||
|
||||
Reference in New Issue
Block a user