Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-31 07:58:14 +08:00
			
		
		
		
	fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
		| @@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) { | ||||
|  | ||||
|   // Get axis, shape and stride info | ||||
|   axis = axis < 0 ? axis + in.ndim() : axis; | ||||
|   size_t n_rows = in.size() / in.shape(axis); | ||||
|   size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); | ||||
|   size_t n_rows = in_size / in.shape(axis); | ||||
|  | ||||
|   auto remaining_shape = out.shape(); | ||||
|   remaining_shape.erase(remaining_shape.begin() + axis); | ||||
| @@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) { | ||||
|   int axis_size = out.shape(axis); | ||||
|  | ||||
|   // Perform sorting in place | ||||
|   ContiguousIterator<size_t> src_it( | ||||
|       remaining_shape, remaining_strides, remaining_shape.size()); | ||||
|   for (int i = 0; i < n_rows; i++) { | ||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); | ||||
|     T* data_ptr = out.data<T>() + loc; | ||||
|     T* data_ptr = out.data<T>() + src_it.loc; | ||||
|  | ||||
|     StridedIterator st(data_ptr, axis_stride, 0); | ||||
|     StridedIterator ed(data_ptr, axis_stride, axis_size); | ||||
|  | ||||
|     std::stable_sort(st, ed); | ||||
|     src_it.step(); | ||||
|   } | ||||
| } | ||||
|  | ||||
| @@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) { | ||||
|   int axis_size = in.shape(axis); | ||||
|  | ||||
|   // Perform sorting | ||||
|   ContiguousIterator<size_t> in_it( | ||||
|       in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); | ||||
|   ContiguousIterator<size_t> out_it( | ||||
|       out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); | ||||
|   for (int i = 0; i < n_rows; i++) { | ||||
|     size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides); | ||||
|     size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides); | ||||
|     const T* data_ptr = in.data<T>() + in_loc; | ||||
|     IdxT* idx_ptr = out.data<IdxT>() + out_loc; | ||||
|     const T* data_ptr = in.data<T>() + in_it.loc; | ||||
|     IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; | ||||
|     in_it.step(); | ||||
|     out_it.step(); | ||||
|  | ||||
|     StridedIterator st_(idx_ptr, out_stride, 0); | ||||
|     StridedIterator ed_(idx_ptr, out_stride, axis_size); | ||||
| @@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) { | ||||
|  | ||||
|   // Get axis, shape and stride info | ||||
|   axis = axis < 0 ? axis + in.ndim() : axis; | ||||
|   size_t n_rows = in.size() / in.shape(axis); | ||||
|   size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); | ||||
|   size_t n_rows = in_size / in.shape(axis); | ||||
|  | ||||
|   auto remaining_shape = in.shape(); | ||||
|   remaining_shape.erase(remaining_shape.begin() + axis); | ||||
| @@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) { | ||||
|   kth = kth < 0 ? kth + axis_size : kth; | ||||
|  | ||||
|   // Perform partition in place | ||||
|   ContiguousIterator<size_t> src_it( | ||||
|       remaining_shape, remaining_strides, remaining_shape.size()); | ||||
|   for (int i = 0; i < n_rows; i++) { | ||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); | ||||
|     T* data_ptr = out.data<T>() + loc; | ||||
|     T* data_ptr = out.data<T>() + src_it.loc; | ||||
|     src_it.step(); | ||||
|  | ||||
|     StridedIterator st(data_ptr, axis_stride, 0); | ||||
|     StridedIterator md(data_ptr, axis_stride, kth); | ||||
| @@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) { | ||||
|   axis = axis < 0 ? axis + in.ndim() : axis; | ||||
|   size_t n_rows = in.size() / in.shape(axis); | ||||
|  | ||||
|   auto remaining_shape = in.shape(); | ||||
|   remaining_shape.erase(remaining_shape.begin() + axis); | ||||
|   auto in_remaining_shape = in.shape(); | ||||
|   in_remaining_shape.erase(in_remaining_shape.begin() + axis); | ||||
|  | ||||
|   auto remaining_strides = in.strides(); | ||||
|   remaining_strides.erase(remaining_strides.begin() + axis); | ||||
|   auto in_remaining_strides = in.strides(); | ||||
|   in_remaining_strides.erase(in_remaining_strides.begin() + axis); | ||||
|  | ||||
|   size_t axis_stride = in.strides()[axis]; | ||||
|   auto out_remaining_shape = out.shape(); | ||||
|   out_remaining_shape.erase(out_remaining_shape.begin() + axis); | ||||
|  | ||||
|   auto out_remaining_strides = out.strides(); | ||||
|   out_remaining_strides.erase(out_remaining_strides.begin() + axis); | ||||
|  | ||||
|   size_t in_stride = in.strides()[axis]; | ||||
|   size_t out_stride = out.strides()[axis]; | ||||
|   int axis_size = in.shape(axis); | ||||
|  | ||||
|   kth = kth < 0 ? kth + axis_size : kth; | ||||
|  | ||||
|   // Perform partition | ||||
|   ContiguousIterator<size_t> in_it( | ||||
|       in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); | ||||
|   ContiguousIterator<size_t> out_it( | ||||
|       out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); | ||||
|   for (int i = 0; i < n_rows; i++) { | ||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); | ||||
|     const T* data_ptr = in.data<T>() + loc; | ||||
|     IdxT* idx_ptr = out.data<IdxT>() + loc; | ||||
|     const T* data_ptr = in.data<T>() + in_it.loc; | ||||
|     IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; | ||||
|     in_it.step(); | ||||
|     out_it.step(); | ||||
|  | ||||
|     StridedIterator st_(idx_ptr, axis_stride, 0); | ||||
|     StridedIterator ed_(idx_ptr, axis_stride, axis_size); | ||||
|     StridedIterator st_(idx_ptr, out_stride, 0); | ||||
|     StridedIterator ed_(idx_ptr, out_stride, axis_size); | ||||
|  | ||||
|     // Initialize with iota | ||||
|     std::iota(st_, ed_, IdxT(0)); | ||||
|  | ||||
|     // Sort according to vals | ||||
|     StridedIterator st(idx_ptr, axis_stride, 0); | ||||
|     StridedIterator md(idx_ptr, axis_stride, kth); | ||||
|     StridedIterator ed(idx_ptr, axis_stride, axis_size); | ||||
|     StridedIterator st(idx_ptr, out_stride, 0); | ||||
|     StridedIterator md(idx_ptr, out_stride, kth); | ||||
|     StridedIterator ed(idx_ptr, out_stride, axis_size); | ||||
|  | ||||
|     std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { | ||||
|       auto v1 = data_ptr[a * axis_stride]; | ||||
|       auto v2 = data_ptr[b * axis_stride]; | ||||
|     std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { | ||||
|       auto v1 = data_ptr[a * in_stride]; | ||||
|       auto v2 = data_ptr[b * in_stride]; | ||||
|       return v1 < v2 || (v1 == v2 && a < b); | ||||
|     }); | ||||
|   } | ||||
|   | ||||
| @@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                             M = top_k_mx.shape[axis or 0] | ||||
|                             self.assertEqual(M, (kth + N) % N) | ||||
|  | ||||
|     def test_argpartition(self): | ||||
|         x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3)) | ||||
|         out = mx.argpartition(x, kth=1, axis=0) | ||||
|         expected = mx.array([[0, 0, 0], [1, 1, 1]]) | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|         x = mx.array([[1, 2], [3, 4]]).T | ||||
|         out = mx.argpartition(x, kth=1, axis=0) | ||||
|         expected = mx.array([[0, 0], [1, 1]]) | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|     @unittest.skipIf( | ||||
|         os.getenv("LOW_MEMORY", None) is not None, | ||||
|         "This test requires a lot of memory", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun