mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
const array& indices,
|
||||
const std::variant<int, array>& indices,
|
||||
const std::optional<int>& axis,
|
||||
StreamOrDevice s) {
|
||||
if (axis.has_value()) {
|
||||
return take(a, indices, axis.value(), s);
|
||||
if (auto pv = std::get_if<int>(&indices); pv) {
|
||||
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
|
||||
} else {
|
||||
return take(a, indices, s);
|
||||
auto indices_ = std::get<array>(indices);
|
||||
return axis ? take(a, indices_, axis.value(), s)
|
||||
: take(a, indices_, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Take elements along an axis.
|
||||
|
||||
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
indices (array): Input array with integral type.
|
||||
indices (int or array): Integer index or input array with integral type.
|
||||
axis (int, optional): Axis along which to perform the take. If unspecified
|
||||
the array is treated as a flattened 1-D vector.
|
||||
|
||||
|
Reference in New Issue
Block a user