diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp index d7f4895bf..1d3d80218 100644 --- a/mlx/backend/common/sort.cpp +++ b/mlx/backend/common/sort.cpp @@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; - size_t n_rows = in.size() / in.shape(axis); + size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); + size_t n_rows = in_size / in.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); @@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) { int axis_size = out.shape(axis); // Perform sorting in place + ContiguousIterator src_it( + remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { - size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); - T* data_ptr = out.data() + loc; + T* data_ptr = out.data() + src_it.loc; StridedIterator st(data_ptr, axis_stride, 0); StridedIterator ed(data_ptr, axis_stride, axis_size); std::stable_sort(st, ed); + src_it.step(); } } @@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) { int axis_size = in.shape(axis); // Perform sorting + ContiguousIterator in_it( + in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); + ContiguousIterator out_it( + out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { - size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides); - size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides); - const T* data_ptr = in.data() + in_loc; - IdxT* idx_ptr = out.data() + out_loc; + const T* data_ptr = in.data() + in_it.loc; + IdxT* idx_ptr = out.data() + out_it.loc; + in_it.step(); + out_it.step(); StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator ed_(idx_ptr, out_stride, axis_size); @@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; - size_t n_rows = in.size() / in.shape(axis); + size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); + size_t n_rows = in_size / in.shape(axis); auto remaining_shape = in.shape(); remaining_shape.erase(remaining_shape.begin() + axis); @@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) { kth = kth < 0 ? kth + axis_size : kth; // Perform partition in place + ContiguousIterator src_it( + remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { - size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); - T* data_ptr = out.data() + loc; + T* data_ptr = out.data() + src_it.loc; + src_it.step(); StridedIterator st(data_ptr, axis_stride, 0); StridedIterator md(data_ptr, axis_stride, kth); @@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) { axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); - auto remaining_shape = in.shape(); - remaining_shape.erase(remaining_shape.begin() + axis); + auto in_remaining_shape = in.shape(); + in_remaining_shape.erase(in_remaining_shape.begin() + axis); - auto remaining_strides = in.strides(); - remaining_strides.erase(remaining_strides.begin() + axis); + auto in_remaining_strides = in.strides(); + in_remaining_strides.erase(in_remaining_strides.begin() + axis); - size_t axis_stride = in.strides()[axis]; + auto out_remaining_shape = out.shape(); + out_remaining_shape.erase(out_remaining_shape.begin() + axis); + + auto out_remaining_strides = out.strides(); + out_remaining_strides.erase(out_remaining_strides.begin() + axis); + + size_t in_stride = in.strides()[axis]; + size_t out_stride = out.strides()[axis]; int axis_size = in.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition + ContiguousIterator in_it( + in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); + ContiguousIterator out_it( + out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { - size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); - const T* data_ptr = in.data() + loc; - IdxT* idx_ptr = out.data() + loc; + const T* data_ptr = in.data() + in_it.loc; + IdxT* idx_ptr = out.data() + out_it.loc; + in_it.step(); + out_it.step(); - StridedIterator st_(idx_ptr, axis_stride, 0); - StridedIterator ed_(idx_ptr, axis_stride, axis_size); + StridedIterator st_(idx_ptr, out_stride, 0); + StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Sort according to vals - StridedIterator st(idx_ptr, axis_stride, 0); - StridedIterator md(idx_ptr, axis_stride, kth); - StridedIterator ed(idx_ptr, axis_stride, axis_size); + StridedIterator st(idx_ptr, out_stride, 0); + StridedIterator md(idx_ptr, out_stride, kth); + StridedIterator ed(idx_ptr, out_stride, axis_size); - std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { - auto v1 = data_ptr[a * axis_stride]; - auto v2 = data_ptr[b * axis_stride]; + std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * in_stride]; + auto v2 = data_ptr[b * in_stride]; return v1 < v2 || (v1 == v2 && a < b); }); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2327fcff6..381f0e8ca 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase): M = top_k_mx.shape[axis or 0] self.assertEqual(M, (kth + N) % N) + def test_argpartition(self): + x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3)) + out = mx.argpartition(x, kth=1, axis=0) + expected = mx.array([[0, 0, 0], [1, 1, 1]]) + self.assertTrue(mx.array_equal(out, expected)) + + x = mx.array([[1, 2], [3, 4]]).T + out = mx.argpartition(x, kth=1, axis=0) + expected = mx.array([[0, 0], [1, 1]]) + self.assertTrue(mx.array_equal(out, expected)) + @unittest.skipIf( os.getenv("LOW_MEMORY", None) is not None, "This test requires a lot of memory",