allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun
2024-09-26 15:58:03 -07:00
committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
4 changed files with 74 additions and 17 deletions

View File

@@ -504,7 +504,20 @@ array squeeze(
shape.push_back(a.shape(i));
}
}
return reshape(a, shape, s);
return reshape(a, std::move(shape), s);
}
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.erase(shape.begin() + ax);
return reshape(a, std::move(shape), s);
}
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@@ -657,10 +670,15 @@ array slice(
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s /* = {} */) {
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
return slice(
a,
std::move(start),
std::move(stop),
std::vector<int>(a.ndim(), 1),
to_stream(s));
}
/** Update a slice from the source array */
@@ -2715,13 +2733,43 @@ array take(
// Squeeze the axis we take over
std::vector<int> out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, out_shape, s);
return reshape(out, std::move(out_shape), s);
}
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
return take(reshape(a, {-1}, s), indices, 0, s);
}
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
// Check for valid axis
if (axis + static_cast<int>(a.ndim()) < 0 ||
axis >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[take] Received invalid axis " << axis << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
// Check for valid take
if (a.size() == 0) {
throw std::invalid_argument(
"[take] Cannot do a non-empty take from an array with zero elements.");
}
// Handle negative axis
axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<int> starts(a.ndim(), 0);
std::vector<int> stops = a.shape();
starts[axis] = index;
stops[axis] = index + 1;
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
}
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
return take(reshape(a, {-1}, s), index, 0, s);
}
array take_along_axis(
const array& a,
const array& indices,
@@ -2764,7 +2812,7 @@ array take_along_axis(
// Squeeze out the slice shape
std::vector<int> out_shape(
out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, out_shape, s);
return reshape(out, std::move(out_shape), s);
}
array put_along_axis(

View File

@@ -144,9 +144,7 @@ array squeeze(
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
return squeeze(a, std::vector<int>{axis}, s);
}
array squeeze(const array& a, int axis, StreamOrDevice s = {});
/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});
@@ -171,8 +169,8 @@ array slice(
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Update a slice from the source array */
@@ -936,9 +934,11 @@ array take(
const array& indices,
int axis,
StreamOrDevice s = {});
array take(const array& a, int index, int axis, StreamOrDevice s = {});
/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});
array take(const array& a, int index, StreamOrDevice s = {});
/** Take array entries given indices along the axis */
array take_along_axis(