From 718aea3f1d049227fe06fd6e33738191dd3cac7d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 26 Sep 2024 15:58:03 -0700 Subject: [PATCH] allow take to work with integer index (#1440) --- mlx/ops.cpp | 60 ++++++++++++++++++++++++++++++++++++---- mlx/ops.h | 10 +++---- python/src/ops.cpp | 14 ++++++---- python/tests/test_ops.py | 7 +++++ 4 files changed, 74 insertions(+), 17 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f69943cd8..7bb9efe8c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -504,7 +504,20 @@ array squeeze( shape.push_back(a.shape(i)); } } - return reshape(a, shape, s); + return reshape(a, std::move(shape), s); +} + +array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) { + int ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + auto shape = a.shape(); + shape.erase(shape.begin() + ax); + return reshape(a, std::move(shape), s); } array squeeze(const array& a, StreamOrDevice s /* = {} */) { @@ -657,10 +670,15 @@ array slice( array slice( const array& a, - const std::vector& start, - const std::vector& stop, + std::vector start, + std::vector stop, StreamOrDevice s /* = {} */) { - return slice(a, start, stop, std::vector(a.ndim(), 1), to_stream(s)); + return slice( + a, + std::move(start), + std::move(stop), + std::vector(a.ndim(), 1), + to_stream(s)); } /** Update a slice from the source array */ @@ -2715,13 +2733,43 @@ array take( // Squeeze the axis we take over std::vector out_shape = out.shape(); out_shape.erase(out_shape.begin() + indices.ndim() + axis); - return reshape(out, out_shape, s); + return reshape(out, std::move(out_shape), s); } array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) { return take(reshape(a, {-1}, s), indices, 0, s); } +array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[take] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // Check for valid take + if (a.size() == 0) { + throw std::invalid_argument( + "[take] Cannot do a non-empty take from an array with zero elements."); + } + + // Handle negative axis + axis = axis < 0 ? a.ndim() + axis : axis; + + std::vector starts(a.ndim(), 0); + std::vector stops = a.shape(); + starts[axis] = index; + stops[axis] = index + 1; + return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s); +} + +array take(const array& a, int index, StreamOrDevice s /* = {} */) { + return take(reshape(a, {-1}, s), index, 0, s); +} + array take_along_axis( const array& a, const array& indices, @@ -2764,7 +2812,7 @@ array take_along_axis( // Squeeze out the slice shape std::vector out_shape( out.shape().begin(), out.shape().begin() + a.ndim()); - return reshape(out, out_shape, s); + return reshape(out, std::move(out_shape), s); } array put_along_axis( diff --git a/mlx/ops.h b/mlx/ops.h index b0c093f47..723e73e1a 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -144,9 +144,7 @@ array squeeze( StreamOrDevice s = {}); /** Remove singleton dimensions at the given axis. */ -inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) { - return squeeze(a, std::vector{axis}, s); -} +array squeeze(const array& a, int axis, StreamOrDevice s = {}); /** Remove all singleton dimensions. */ array squeeze(const array& a, StreamOrDevice s = {}); @@ -171,8 +169,8 @@ array slice( /** Slice an array with a stride of 1 in each dimension. */ array slice( const array& a, - const std::vector& start, - const std::vector& stop, + std::vector start, + std::vector stop, StreamOrDevice s = {}); /** Update a slice from the source array */ @@ -936,9 +934,11 @@ array take( const array& indices, int axis, StreamOrDevice s = {}); +array take(const array& a, int index, int axis, StreamOrDevice s = {}); /** Take array entries at the given indices treating the array as flattened. */ array take(const array& a, const array& indices, StreamOrDevice s = {}); +array take(const array& a, int index, StreamOrDevice s = {}); /** Take array entries given indices along the axis */ array take_along_axis( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e1117786b..8e74907b0 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) { m.def( "take", [](const array& a, - const array& indices, + const std::variant& indices, const std::optional& axis, StreamOrDevice s) { - if (axis.has_value()) { - return take(a, indices, axis.value(), s); + if (auto pv = std::get_if(&indices); pv) { + return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s); } else { - return take(a, indices, s); + auto indices_ = std::get(indices); + return axis ? take(a, indices_, axis.value(), s) + : take(a, indices_, s); } }, nb::arg(), @@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Take elements along an axis. @@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) { Args: a (array): Input array. - indices (array): Input array with integral type. + indices (int or array): Integer index or input array with integral type. axis (int, optional): Axis along which to perform the take. If unspecified the array is treated as a flattened 1-D vector. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 7a9404f27..2327fcff6 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape) self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist()) + # Take with integer index + a = mx.arange(8).reshape(2, 4) + out = mx.take(a, 1, axis=0) + self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7]))) + out = mx.take(a, 1, axis=1) + self.assertTrue(mx.array_equal(out, mx.array([1, 5]))) + def test_take_along_axis(self): a_np = np.arange(8).reshape(2, 2, 2) a_mlx = mx.array(a_np)