mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
		| @@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) { | |||||||
|  |  | ||||||
|   // Get axis, shape and stride info |   // Get axis, shape and stride info | ||||||
|   axis = axis < 0 ? axis + in.ndim() : axis; |   axis = axis < 0 ? axis + in.ndim() : axis; | ||||||
|   size_t n_rows = in.size() / in.shape(axis); |   size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); | ||||||
|  |   size_t n_rows = in_size / in.shape(axis); | ||||||
|  |  | ||||||
|   auto remaining_shape = out.shape(); |   auto remaining_shape = out.shape(); | ||||||
|   remaining_shape.erase(remaining_shape.begin() + axis); |   remaining_shape.erase(remaining_shape.begin() + axis); | ||||||
| @@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) { | |||||||
|   int axis_size = out.shape(axis); |   int axis_size = out.shape(axis); | ||||||
|  |  | ||||||
|   // Perform sorting in place |   // Perform sorting in place | ||||||
|  |   ContiguousIterator<size_t> src_it( | ||||||
|  |       remaining_shape, remaining_strides, remaining_shape.size()); | ||||||
|   for (int i = 0; i < n_rows; i++) { |   for (int i = 0; i < n_rows; i++) { | ||||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); |     T* data_ptr = out.data<T>() + src_it.loc; | ||||||
|     T* data_ptr = out.data<T>() + loc; |  | ||||||
|  |  | ||||||
|     StridedIterator st(data_ptr, axis_stride, 0); |     StridedIterator st(data_ptr, axis_stride, 0); | ||||||
|     StridedIterator ed(data_ptr, axis_stride, axis_size); |     StridedIterator ed(data_ptr, axis_stride, axis_size); | ||||||
|  |  | ||||||
|     std::stable_sort(st, ed); |     std::stable_sort(st, ed); | ||||||
|  |     src_it.step(); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) { | |||||||
|   int axis_size = in.shape(axis); |   int axis_size = in.shape(axis); | ||||||
|  |  | ||||||
|   // Perform sorting |   // Perform sorting | ||||||
|  |   ContiguousIterator<size_t> in_it( | ||||||
|  |       in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); | ||||||
|  |   ContiguousIterator<size_t> out_it( | ||||||
|  |       out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); | ||||||
|   for (int i = 0; i < n_rows; i++) { |   for (int i = 0; i < n_rows; i++) { | ||||||
|     size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides); |     const T* data_ptr = in.data<T>() + in_it.loc; | ||||||
|     size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides); |     IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; | ||||||
|     const T* data_ptr = in.data<T>() + in_loc; |     in_it.step(); | ||||||
|     IdxT* idx_ptr = out.data<IdxT>() + out_loc; |     out_it.step(); | ||||||
|  |  | ||||||
|     StridedIterator st_(idx_ptr, out_stride, 0); |     StridedIterator st_(idx_ptr, out_stride, 0); | ||||||
|     StridedIterator ed_(idx_ptr, out_stride, axis_size); |     StridedIterator ed_(idx_ptr, out_stride, axis_size); | ||||||
| @@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) { | |||||||
|  |  | ||||||
|   // Get axis, shape and stride info |   // Get axis, shape and stride info | ||||||
|   axis = axis < 0 ? axis + in.ndim() : axis; |   axis = axis < 0 ? axis + in.ndim() : axis; | ||||||
|   size_t n_rows = in.size() / in.shape(axis); |   size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); | ||||||
|  |   size_t n_rows = in_size / in.shape(axis); | ||||||
|  |  | ||||||
|   auto remaining_shape = in.shape(); |   auto remaining_shape = in.shape(); | ||||||
|   remaining_shape.erase(remaining_shape.begin() + axis); |   remaining_shape.erase(remaining_shape.begin() + axis); | ||||||
| @@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) { | |||||||
|   kth = kth < 0 ? kth + axis_size : kth; |   kth = kth < 0 ? kth + axis_size : kth; | ||||||
|  |  | ||||||
|   // Perform partition in place |   // Perform partition in place | ||||||
|  |   ContiguousIterator<size_t> src_it( | ||||||
|  |       remaining_shape, remaining_strides, remaining_shape.size()); | ||||||
|   for (int i = 0; i < n_rows; i++) { |   for (int i = 0; i < n_rows; i++) { | ||||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); |     T* data_ptr = out.data<T>() + src_it.loc; | ||||||
|     T* data_ptr = out.data<T>() + loc; |     src_it.step(); | ||||||
|  |  | ||||||
|     StridedIterator st(data_ptr, axis_stride, 0); |     StridedIterator st(data_ptr, axis_stride, 0); | ||||||
|     StridedIterator md(data_ptr, axis_stride, kth); |     StridedIterator md(data_ptr, axis_stride, kth); | ||||||
| @@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) { | |||||||
|   axis = axis < 0 ? axis + in.ndim() : axis; |   axis = axis < 0 ? axis + in.ndim() : axis; | ||||||
|   size_t n_rows = in.size() / in.shape(axis); |   size_t n_rows = in.size() / in.shape(axis); | ||||||
|  |  | ||||||
|   auto remaining_shape = in.shape(); |   auto in_remaining_shape = in.shape(); | ||||||
|   remaining_shape.erase(remaining_shape.begin() + axis); |   in_remaining_shape.erase(in_remaining_shape.begin() + axis); | ||||||
|  |  | ||||||
|   auto remaining_strides = in.strides(); |   auto in_remaining_strides = in.strides(); | ||||||
|   remaining_strides.erase(remaining_strides.begin() + axis); |   in_remaining_strides.erase(in_remaining_strides.begin() + axis); | ||||||
|  |  | ||||||
|   size_t axis_stride = in.strides()[axis]; |   auto out_remaining_shape = out.shape(); | ||||||
|  |   out_remaining_shape.erase(out_remaining_shape.begin() + axis); | ||||||
|  |  | ||||||
|  |   auto out_remaining_strides = out.strides(); | ||||||
|  |   out_remaining_strides.erase(out_remaining_strides.begin() + axis); | ||||||
|  |  | ||||||
|  |   size_t in_stride = in.strides()[axis]; | ||||||
|  |   size_t out_stride = out.strides()[axis]; | ||||||
|   int axis_size = in.shape(axis); |   int axis_size = in.shape(axis); | ||||||
|  |  | ||||||
|   kth = kth < 0 ? kth + axis_size : kth; |   kth = kth < 0 ? kth + axis_size : kth; | ||||||
|  |  | ||||||
|   // Perform partition |   // Perform partition | ||||||
|  |   ContiguousIterator<size_t> in_it( | ||||||
|  |       in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); | ||||||
|  |   ContiguousIterator<size_t> out_it( | ||||||
|  |       out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); | ||||||
|   for (int i = 0; i < n_rows; i++) { |   for (int i = 0; i < n_rows; i++) { | ||||||
|     size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); |     const T* data_ptr = in.data<T>() + in_it.loc; | ||||||
|     const T* data_ptr = in.data<T>() + loc; |     IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; | ||||||
|     IdxT* idx_ptr = out.data<IdxT>() + loc; |     in_it.step(); | ||||||
|  |     out_it.step(); | ||||||
|  |  | ||||||
|     StridedIterator st_(idx_ptr, axis_stride, 0); |     StridedIterator st_(idx_ptr, out_stride, 0); | ||||||
|     StridedIterator ed_(idx_ptr, axis_stride, axis_size); |     StridedIterator ed_(idx_ptr, out_stride, axis_size); | ||||||
|  |  | ||||||
|     // Initialize with iota |     // Initialize with iota | ||||||
|     std::iota(st_, ed_, IdxT(0)); |     std::iota(st_, ed_, IdxT(0)); | ||||||
|  |  | ||||||
|     // Sort according to vals |     // Sort according to vals | ||||||
|     StridedIterator st(idx_ptr, axis_stride, 0); |     StridedIterator st(idx_ptr, out_stride, 0); | ||||||
|     StridedIterator md(idx_ptr, axis_stride, kth); |     StridedIterator md(idx_ptr, out_stride, kth); | ||||||
|     StridedIterator ed(idx_ptr, axis_stride, axis_size); |     StridedIterator ed(idx_ptr, out_stride, axis_size); | ||||||
|  |  | ||||||
|     std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { |     std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { | ||||||
|       auto v1 = data_ptr[a * axis_stride]; |       auto v1 = data_ptr[a * in_stride]; | ||||||
|       auto v2 = data_ptr[b * axis_stride]; |       auto v2 = data_ptr[b * in_stride]; | ||||||
|       return v1 < v2 || (v1 == v2 && a < b); |       return v1 < v2 || (v1 == v2 && a < b); | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase): | |||||||
|                             M = top_k_mx.shape[axis or 0] |                             M = top_k_mx.shape[axis or 0] | ||||||
|                             self.assertEqual(M, (kth + N) % N) |                             self.assertEqual(M, (kth + N) % N) | ||||||
|  |  | ||||||
|  |     def test_argpartition(self): | ||||||
|  |         x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3)) | ||||||
|  |         out = mx.argpartition(x, kth=1, axis=0) | ||||||
|  |         expected = mx.array([[0, 0, 0], [1, 1, 1]]) | ||||||
|  |         self.assertTrue(mx.array_equal(out, expected)) | ||||||
|  |  | ||||||
|  |         x = mx.array([[1, 2], [3, 4]]).T | ||||||
|  |         out = mx.argpartition(x, kth=1, axis=0) | ||||||
|  |         expected = mx.array([[0, 0], [1, 1]]) | ||||||
|  |         self.assertTrue(mx.array_equal(out, expected)) | ||||||
|  |  | ||||||
|     @unittest.skipIf( |     @unittest.skipIf( | ||||||
|         os.getenv("LOW_MEMORY", None) is not None, |         os.getenv("LOW_MEMORY", None) is not None, | ||||||
|         "This test requires a lot of memory", |         "This test requires a lot of memory", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun