allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun
2024-09-26 15:58:03 -07:00
committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
4 changed files with 74 additions and 17 deletions

View File

@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
m.def(
"take",
[](const array& a,
const array& indices,
const std::variant<int, array>& indices,
const std::optional<int>& axis,
StreamOrDevice s) {
if (axis.has_value()) {
return take(a, indices, axis.value(), s);
if (auto pv = std::get_if<int>(&indices); pv) {
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
} else {
return take(a, indices, s);
auto indices_ = std::get<array>(indices);
return axis ? take(a, indices_, axis.value(), s)
: take(a, indices_, s);
}
},
nb::arg(),
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Take elements along an axis.
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array.
indices (array): Input array with integral type.
indices (int or array): Integer index or input array with integral type.
axis (int, optional): Axis along which to perform the take. If unspecified
the array is treated as a flattened 1-D vector.