mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 22:01:17 +08:00
fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
parent
5523d9c426
commit
1bdc038bf9
@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
|
||||
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + in_loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator<size_t> src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
T* data_ptr = out.data<T>() + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
auto in_remaining_strides = in.strides();
|
||||
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
auto out_remaining_shape = out.shape();
|
||||
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
|
||||
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition
|
||||
ContiguousIterator<size_t> in_it(
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator<size_t> out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||
const T* data_ptr = in.data<T>() + in_it.loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||
StridedIterator md(idx_ptr, axis_stride, kth);
|
||||
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * axis_stride];
|
||||
auto v2 = data_ptr[b * axis_stride];
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
|
@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
M = top_k_mx.shape[axis or 0]
|
||||
self.assertEqual(M, (kth + N) % N)
|
||||
|
||||
def test_argpartition(self):
|
||||
x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3))
|
||||
out = mx.argpartition(x, kth=1, axis=0)
|
||||
expected = mx.array([[0, 0, 0], [1, 1, 1]])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
x = mx.array([[1, 2], [3, 4]]).T
|
||||
out = mx.argpartition(x, kth=1, axis=0)
|
||||
expected = mx.array([[0, 0], [1, 1]])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
@unittest.skipIf(
|
||||
os.getenv("LOW_MEMORY", None) is not None,
|
||||
"This test requires a lot of memory",
|
||||
|
Loading…
Reference in New Issue
Block a user