mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
		@@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
 | 
			
		||||
 | 
			
		||||
  // Get axis, shape and stride info
 | 
			
		||||
  axis = axis < 0 ? axis + in.ndim() : axis;
 | 
			
		||||
  size_t n_rows = in.size() / in.shape(axis);
 | 
			
		||||
  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
 | 
			
		||||
  size_t n_rows = in_size / in.shape(axis);
 | 
			
		||||
 | 
			
		||||
  auto remaining_shape = out.shape();
 | 
			
		||||
  remaining_shape.erase(remaining_shape.begin() + axis);
 | 
			
		||||
@@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
 | 
			
		||||
  int axis_size = out.shape(axis);
 | 
			
		||||
 | 
			
		||||
  // Perform sorting in place
 | 
			
		||||
  ContiguousIterator<size_t> src_it(
 | 
			
		||||
      remaining_shape, remaining_strides, remaining_shape.size());
 | 
			
		||||
  for (int i = 0; i < n_rows; i++) {
 | 
			
		||||
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
 | 
			
		||||
    T* data_ptr = out.data<T>() + loc;
 | 
			
		||||
    T* data_ptr = out.data<T>() + src_it.loc;
 | 
			
		||||
 | 
			
		||||
    StridedIterator st(data_ptr, axis_stride, 0);
 | 
			
		||||
    StridedIterator ed(data_ptr, axis_stride, axis_size);
 | 
			
		||||
 | 
			
		||||
    std::stable_sort(st, ed);
 | 
			
		||||
    src_it.step();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
 | 
			
		||||
  int axis_size = in.shape(axis);
 | 
			
		||||
 | 
			
		||||
  // Perform sorting
 | 
			
		||||
  ContiguousIterator<size_t> in_it(
 | 
			
		||||
      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
 | 
			
		||||
  ContiguousIterator<size_t> out_it(
 | 
			
		||||
      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
 | 
			
		||||
  for (int i = 0; i < n_rows; i++) {
 | 
			
		||||
    size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
 | 
			
		||||
    size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
 | 
			
		||||
    const T* data_ptr = in.data<T>() + in_loc;
 | 
			
		||||
    IdxT* idx_ptr = out.data<IdxT>() + out_loc;
 | 
			
		||||
    const T* data_ptr = in.data<T>() + in_it.loc;
 | 
			
		||||
    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
 | 
			
		||||
    in_it.step();
 | 
			
		||||
    out_it.step();
 | 
			
		||||
 | 
			
		||||
    StridedIterator st_(idx_ptr, out_stride, 0);
 | 
			
		||||
    StridedIterator ed_(idx_ptr, out_stride, axis_size);
 | 
			
		||||
@@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
 | 
			
		||||
 | 
			
		||||
  // Get axis, shape and stride info
 | 
			
		||||
  axis = axis < 0 ? axis + in.ndim() : axis;
 | 
			
		||||
  size_t n_rows = in.size() / in.shape(axis);
 | 
			
		||||
  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
 | 
			
		||||
  size_t n_rows = in_size / in.shape(axis);
 | 
			
		||||
 | 
			
		||||
  auto remaining_shape = in.shape();
 | 
			
		||||
  remaining_shape.erase(remaining_shape.begin() + axis);
 | 
			
		||||
@@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
 | 
			
		||||
  kth = kth < 0 ? kth + axis_size : kth;
 | 
			
		||||
 | 
			
		||||
  // Perform partition in place
 | 
			
		||||
  ContiguousIterator<size_t> src_it(
 | 
			
		||||
      remaining_shape, remaining_strides, remaining_shape.size());
 | 
			
		||||
  for (int i = 0; i < n_rows; i++) {
 | 
			
		||||
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
 | 
			
		||||
    T* data_ptr = out.data<T>() + loc;
 | 
			
		||||
    T* data_ptr = out.data<T>() + src_it.loc;
 | 
			
		||||
    src_it.step();
 | 
			
		||||
 | 
			
		||||
    StridedIterator st(data_ptr, axis_stride, 0);
 | 
			
		||||
    StridedIterator md(data_ptr, axis_stride, kth);
 | 
			
		||||
@@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
 | 
			
		||||
  axis = axis < 0 ? axis + in.ndim() : axis;
 | 
			
		||||
  size_t n_rows = in.size() / in.shape(axis);
 | 
			
		||||
 | 
			
		||||
  auto remaining_shape = in.shape();
 | 
			
		||||
  remaining_shape.erase(remaining_shape.begin() + axis);
 | 
			
		||||
  auto in_remaining_shape = in.shape();
 | 
			
		||||
  in_remaining_shape.erase(in_remaining_shape.begin() + axis);
 | 
			
		||||
 | 
			
		||||
  auto remaining_strides = in.strides();
 | 
			
		||||
  remaining_strides.erase(remaining_strides.begin() + axis);
 | 
			
		||||
  auto in_remaining_strides = in.strides();
 | 
			
		||||
  in_remaining_strides.erase(in_remaining_strides.begin() + axis);
 | 
			
		||||
 | 
			
		||||
  size_t axis_stride = in.strides()[axis];
 | 
			
		||||
  auto out_remaining_shape = out.shape();
 | 
			
		||||
  out_remaining_shape.erase(out_remaining_shape.begin() + axis);
 | 
			
		||||
 | 
			
		||||
  auto out_remaining_strides = out.strides();
 | 
			
		||||
  out_remaining_strides.erase(out_remaining_strides.begin() + axis);
 | 
			
		||||
 | 
			
		||||
  size_t in_stride = in.strides()[axis];
 | 
			
		||||
  size_t out_stride = out.strides()[axis];
 | 
			
		||||
  int axis_size = in.shape(axis);
 | 
			
		||||
 | 
			
		||||
  kth = kth < 0 ? kth + axis_size : kth;
 | 
			
		||||
 | 
			
		||||
  // Perform partition
 | 
			
		||||
  ContiguousIterator<size_t> in_it(
 | 
			
		||||
      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
 | 
			
		||||
  ContiguousIterator<size_t> out_it(
 | 
			
		||||
      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
 | 
			
		||||
  for (int i = 0; i < n_rows; i++) {
 | 
			
		||||
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
 | 
			
		||||
    const T* data_ptr = in.data<T>() + loc;
 | 
			
		||||
    IdxT* idx_ptr = out.data<IdxT>() + loc;
 | 
			
		||||
    const T* data_ptr = in.data<T>() + in_it.loc;
 | 
			
		||||
    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
 | 
			
		||||
    in_it.step();
 | 
			
		||||
    out_it.step();
 | 
			
		||||
 | 
			
		||||
    StridedIterator st_(idx_ptr, axis_stride, 0);
 | 
			
		||||
    StridedIterator ed_(idx_ptr, axis_stride, axis_size);
 | 
			
		||||
    StridedIterator st_(idx_ptr, out_stride, 0);
 | 
			
		||||
    StridedIterator ed_(idx_ptr, out_stride, axis_size);
 | 
			
		||||
 | 
			
		||||
    // Initialize with iota
 | 
			
		||||
    std::iota(st_, ed_, IdxT(0));
 | 
			
		||||
 | 
			
		||||
    // Sort according to vals
 | 
			
		||||
    StridedIterator st(idx_ptr, axis_stride, 0);
 | 
			
		||||
    StridedIterator md(idx_ptr, axis_stride, kth);
 | 
			
		||||
    StridedIterator ed(idx_ptr, axis_stride, axis_size);
 | 
			
		||||
    StridedIterator st(idx_ptr, out_stride, 0);
 | 
			
		||||
    StridedIterator md(idx_ptr, out_stride, kth);
 | 
			
		||||
    StridedIterator ed(idx_ptr, out_stride, axis_size);
 | 
			
		||||
 | 
			
		||||
    std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
 | 
			
		||||
      auto v1 = data_ptr[a * axis_stride];
 | 
			
		||||
      auto v2 = data_ptr[b * axis_stride];
 | 
			
		||||
    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
 | 
			
		||||
      auto v1 = data_ptr[a * in_stride];
 | 
			
		||||
      auto v2 = data_ptr[b * in_stride];
 | 
			
		||||
      return v1 < v2 || (v1 == v2 && a < b);
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
@@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
                            M = top_k_mx.shape[axis or 0]
 | 
			
		||||
                            self.assertEqual(M, (kth + N) % N)
 | 
			
		||||
 | 
			
		||||
    def test_argpartition(self):
 | 
			
		||||
        x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3))
 | 
			
		||||
        out = mx.argpartition(x, kth=1, axis=0)
 | 
			
		||||
        expected = mx.array([[0, 0, 0], [1, 1, 1]])
 | 
			
		||||
        self.assertTrue(mx.array_equal(out, expected))
 | 
			
		||||
 | 
			
		||||
        x = mx.array([[1, 2], [3, 4]]).T
 | 
			
		||||
        out = mx.argpartition(x, kth=1, axis=0)
 | 
			
		||||
        expected = mx.array([[0, 0], [1, 1]])
 | 
			
		||||
        self.assertTrue(mx.array_equal(out, expected))
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(
 | 
			
		||||
        os.getenv("LOW_MEMORY", None) is not None,
 | 
			
		||||
        "This test requires a lot of memory",
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user