allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun
2024-09-26 15:58:03 -07:00
committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
4 changed files with 74 additions and 17 deletions

View File

@@ -144,9 +144,7 @@ array squeeze(
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
return squeeze(a, std::vector<int>{axis}, s);
}
array squeeze(const array& a, int axis, StreamOrDevice s = {});
/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});
@@ -171,8 +169,8 @@ array slice(
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Update a slice from the source array */
@@ -936,9 +934,11 @@ array take(
const array& indices,
int axis,
StreamOrDevice s = {});
array take(const array& a, int index, int axis, StreamOrDevice s = {});
/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});
array take(const array& a, int index, StreamOrDevice s = {});
/** Take array entries given indices along the axis */
array take_along_axis(