allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun
2024-09-26 15:58:03 -07:00
committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
4 changed files with 74 additions and 17 deletions

View File

@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
m.def(
"take",
[](const array& a,
const array& indices,
const std::variant<int, array>& indices,
const std::optional<int>& axis,
StreamOrDevice s) {
if (axis.has_value()) {
return take(a, indices, axis.value(), s);
if (auto pv = std::get_if<int>(&indices); pv) {
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
} else {
return take(a, indices, s);
auto indices_ = std::get<array>(indices);
return axis ? take(a, indices_, axis.value(), s)
: take(a, indices_, s);
}
},
nb::arg(),
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Take elements along an axis.
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array.
indices (array): Input array with integral type.
indices (int or array): Integer index or input array with integral type.
axis (int, optional): Axis along which to perform the take. If unspecified
the array is treated as a flattened 1-D vector.

View File

@@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
# Take with integer index
a = mx.arange(8).reshape(2, 4)
out = mx.take(a, 1, axis=0)
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
out = mx.take(a, 1, axis=1)
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
def test_take_along_axis(self):
a_np = np.arange(8).reshape(2, 2, 2)
a_mlx = mx.array(a_np)