mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
const array& indices,
|
||||
const std::variant<int, array>& indices,
|
||||
const std::optional<int>& axis,
|
||||
StreamOrDevice s) {
|
||||
if (axis.has_value()) {
|
||||
return take(a, indices, axis.value(), s);
|
||||
if (auto pv = std::get_if<int>(&indices); pv) {
|
||||
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
|
||||
} else {
|
||||
return take(a, indices, s);
|
||||
auto indices_ = std::get<array>(indices);
|
||||
return axis ? take(a, indices_, axis.value(), s)
|
||||
: take(a, indices_, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Take elements along an axis.
|
||||
|
||||
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
indices (array): Input array with integral type.
|
||||
indices (int or array): Integer index or input array with integral type.
|
||||
axis (int, optional): Axis along which to perform the take. If unspecified
|
||||
the array is treated as a flattened 1-D vector.
|
||||
|
||||
|
@@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
# Take with integer index
|
||||
a = mx.arange(8).reshape(2, 4)
|
||||
out = mx.take(a, 1, axis=0)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
|
||||
out = mx.take(a, 1, axis=1)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
|
||||
|
||||
def test_take_along_axis(self):
|
||||
a_np = np.arange(8).reshape(2, 2, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
Reference in New Issue
Block a user